package cc.mallet.regression;

import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.InvertedIndex;
import java.io.File;
import java.text.NumberFormat;
import java.util.Iterator;

/* loaded from: input_file:cc/mallet/regression/CoordinateDescent.class */
/**
 * Fits a {@link LinearRegression} by cyclic coordinate descent with a
 * soft-threshold (lasso-style) shrinkage step on every feature weight.
 *
 * <p>All of the fitting work happens in the constructor. For every coordinate the
 * class maintains a "scaled residual": for feature {@code j} this is the sum over
 * instances of {@code x_j * (partial residual without feature j)} divided by the sum
 * of {@code x_j^2}; for the intercept it is the mean partial residual. Each sweep
 * first moves the intercept to its scaled residual and then soft-thresholds every
 * feature weight, incrementally patching the scaled residuals of the other
 * coordinates after each change.
 *
 * <p>NOTE(review): {@code scaledThresholds} (tuning constant divided by each
 * feature's sum of squares) is computed but never consulted; the shrinkage step
 * compares against the raw {@code tuningConstant}. Confirm which threshold is
 * intended before changing it, since that would alter every fitted model.
 */
public class CoordinateDescent {
    /** A sweep whose summed absolute parameter change is below this value ends the fit. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-4d;

    /** Number of non-converged sweeps between progress lines printed to standard output. */
    private static final int PROGRESS_INTERVAL = 100;

    /** Model being fitted; also used by {@link #toString()} to compute predictions. */
    LinearRegression regression;

    /**
     * Array returned by {@code regression.getParameters()}, updated in place.
     * Presumably the regression's live parameter storage (toString relies on
     * {@code regression.predict} reflecting the fitted values) - confirm against
     * LinearRegression.
     */
    double[] parameters;

    /** Instances the model is fitted on. */
    InstanceList trainingData;

    /** Per-coordinate scaled residuals; features first, intercept at {@code interceptIndex}. */
    double[] scaledResiduals;

    /** Shrinkage amount applied by the soft-threshold step. */
    double tuningConstant;

    /** Sum of squared values of each feature over the training data. */
    double[] sumSquaredX;

    /** {@code tuningConstant / sumSquaredX[j]} per feature; currently written but never read. */
    double[] scaledThresholds;

    /** Maps a feature index to the instances that contain that feature. */
    InvertedIndex featureIndex;

    /** Position of the intercept in {@code parameters} (second to last slot). */
    int interceptIndex;

    /** Last slot of {@code parameters}; recorded here but not otherwise used by this class. */
    int precisionIndex;

    /** Number of fitted coordinates: all features plus the intercept. */
    int dimension;

    /** Formats numbers in {@link #toString()} with at most three fraction digits. */
    NumberFormat formatter = NumberFormat.getInstance();

    /**
     * Builds the model and runs coordinate descent until a full sweep changes the
     * parameters by less than {@code 1.0E-4} in total absolute value. Every 100th
     * non-converged sweep prints the current total change to standard output.
     *
     * @param instanceList training data; each instance's data is cast to a
     *                     {@code FeatureVector} and its target to a {@code Double}
     * @param d            tuning constant used as the soft-threshold amount
     * @throws IllegalArgumentException if {@code instanceList} is empty (the intercept
     *                                  would be NaN and the loop could never converge)
     */
    public CoordinateDescent(InstanceList instanceList, double d) {
        if (instanceList.size() == 0) {
            throw new IllegalArgumentException("Cannot fit a regression on an empty instance list");
        }
        this.tuningConstant = d;
        this.trainingData = instanceList;
        this.regression = new LinearRegression(this.trainingData.getDataAlphabet());
        this.parameters = this.regression.getParameters();
        this.interceptIndex = this.parameters.length - 2;
        this.precisionIndex = this.parameters.length - 1;
        this.formatter.setMaximumFractionDigits(3);
        this.dimension = this.parameters.length - 1;
        this.scaledResiduals = new double[this.dimension];
        this.sumSquaredX = new double[this.dimension];
        this.scaledThresholds = new double[this.dimension];
        this.featureIndex = new InvertedIndex(instanceList);

        initializeStatistics();
        runUntilConverged();
    }

    /** Computes the starting scaled residuals (all parameters zero) and per-feature sums of squares. */
    private void initializeStatistics() {
        for (Instance instance : this.trainingData) {
            FeatureVector features = (FeatureVector) instance.getData();
            double target = ((Double) instance.getTarget()).doubleValue();
            this.scaledResiduals[this.interceptIndex] += target;
            for (int location = 0; location < features.numLocations(); location++) {
                int feature = features.indexAtLocation(location);
                double value = features.valueAtLocation(location);
                this.scaledResiduals[feature] += target * value;
                this.sumSquaredX[feature] += value * value;
            }
        }
        this.scaledResiduals[this.interceptIndex] /= this.trainingData.size();
        for (int feature = 0; feature < this.dimension - 1; feature++) {
            // A feature whose sum of squares is zero gets a NaN scaled residual here;
            // NaN fails both threshold comparisons, so its weight is held at zero.
            this.scaledResiduals[feature] /= this.sumSquaredX[feature];
            this.scaledThresholds[feature] = this.tuningConstant / this.sumSquaredX[feature];
        }
    }

    /** Repeats full sweeps over the intercept and all features until the total change is below tolerance. */
    private void runUntilConverged() {
        int sweeps = 0;
        while (true) {
            double totalChange = updateIntercept();
            for (int feature = 0; feature < this.dimension - 1; feature++) {
                totalChange += updateFeature(feature);
            }
            if (totalChange < CONVERGENCE_TOLERANCE) {
                return;
            }
            sweeps++;
            if (sweeps % PROGRESS_INTERVAL == 0) {
                System.out.println(totalChange);
            }
        }
    }

    /**
     * Moves the intercept to its scaled residual and propagates the change to every
     * feature's scaled residual.
     *
     * @return absolute change of the intercept
     */
    private double updateIntercept() {
        double delta = this.parameters[this.interceptIndex] - this.scaledResiduals[this.interceptIndex];
        this.parameters[this.interceptIndex] = this.scaledResiduals[this.interceptIndex];
        for (Instance instance : this.trainingData) {
            FeatureVector features = (FeatureVector) instance.getData();
            for (int location = 0; location < features.numLocations(); location++) {
                int feature = features.indexAtLocation(location);
                double value = features.valueAtLocation(location);
                this.scaledResiduals[feature] += (value * delta) / this.sumSquaredX[feature];
            }
        }
        return Math.abs(delta);
    }

    /**
     * Soft-thresholds one feature weight and patches the scaled residuals of the
     * intercept and of every feature that co-occurs with it.
     *
     * @param feature index of the feature weight to update
     * @return absolute change of that weight
     */
    private double updateFeature(int feature) {
        double previous = this.parameters[feature];
        double residual = this.scaledResiduals[feature];
        if (residual > this.tuningConstant) {
            this.parameters[feature] = residual - this.tuningConstant;
        } else if (residual < -this.tuningConstant) {
            this.parameters[feature] = residual + this.tuningConstant;
        } else {
            // Inside the threshold band the coordinate-wise minimiser is exactly zero.
            // Without this branch a weight that was once non-zero could never be
            // shrunk back and the fit could stall at a non-optimal point.
            this.parameters[feature] = 0.0d;
        }
        double delta = previous - this.parameters[feature];

        // An unchanged weight would only add zeros below, so skip the index walk.
        if (delta != 0.0d) {
            int instanceCount = this.trainingData.size();
            // Iterated as Object so this compiles whether or not the index returns a typed list.
            for (Object candidate : this.featureIndex.getInstancesWithFeature(feature)) {
                FeatureVector features = (FeatureVector) ((Instance) candidate).getData();
                double ownValue = valueOfFeature(features, feature);
                this.scaledResiduals[this.interceptIndex] += (ownValue * delta) / instanceCount;
                for (int location = 0; location < features.numLocations(); location++) {
                    int other = features.indexAtLocation(location);
                    double otherValue = features.valueAtLocation(location);
                    // A feature's own scaled residual excludes its own contribution.
                    if (other != feature) {
                        this.scaledResiduals[other] += ((ownValue * otherValue) * delta) / this.sumSquaredX[other];
                    }
                }
            }
        }
        return Math.abs(delta);
    }

    /** Returns the value stored for {@code feature} in {@code features}, or 0 when it has no location. */
    private static double valueOfFeature(FeatureVector features, int feature) {
        for (int location = 0; location < features.numLocations(); location++) {
            if (features.indexAtLocation(location) == feature) {
                return features.valueAtLocation(location);
            }
        }
        return 0.0d;
    }

    /**
     * Renders the fitted model: the intercept, one line per feature with its name
     * and weight, and the sum of squared errors over the training data.
     *
     * @return tab-separated, newline-terminated report
     */
    @Override
    public String toString() {
        double sumSquaredError = 0.0d;
        for (int i = 0; i < this.trainingData.size(); i++) {
            Instance instance = this.trainingData.get(i);
            double error = ((Double) instance.getTarget()).doubleValue() - this.regression.predict(instance);
            sumSquaredError += error * error;
        }
        StringBuilder report = new StringBuilder();
        report.append("(Int)\t")
                .append(this.formatter.format(this.parameters[this.interceptIndex]))
                .append("\n");
        for (int feature = 0; feature < this.dimension - 1; feature++) {
            report.append(this.trainingData.getDataAlphabet().lookupObject(feature))
                    .append("\t")
                    .append(this.formatter.format(this.parameters[feature]))
                    .append("\n");
        }
        report.append("SSE: ").append(this.formatter.format(sumSquaredError)).append("\n");
        return report.toString();
    }

    /**
     * Command-line entry point: loads a serialized instance list, fits the model and
     * prints the report.
     *
     * @param strArr {@code strArr[0]} is the instance-list file, {@code strArr[1]} the tuning constant
     * @throws IllegalArgumentException if fewer than two arguments are given
     * @throws Exception                if loading or parsing fails
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length < 2) {
            throw new IllegalArgumentException(
                    "Usage: CoordinateDescent <instance-list-file> <tuning-constant>");
        }
        InstanceList instances = InstanceList.load(new File(strArr[0]));
        double tuning = Double.parseDouble(strArr[1]);
        System.out.println(new CoordinateDescent(instances, tuning));
    }
}
