package cc.mallet.classify;

import cc.mallet.classify.constraints.ge.MaxEntGEConstraint;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletProgressMessageLogger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/classify/MaxEntOptimizableByGE.class */
/**
 * Optimizable objective that trains a {@link MaxEnt} classifier from a list of
 * {@link MaxEntGEConstraint}s.
 *
 * <p>The value is the sum of the constraints' values, computed from the classifier's
 * (temperature-scaled) scores on the training instances whose target is {@code null},
 * multiplied by the objective weight, plus a Gaussian prior penalty on the parameters.
 * Instances that carry a target are skipped.
 *
 * <p>Parameters and gradient are flat arrays with one row per label; each row holds
 * {@code dataAlphabet.size() + 1} entries, the last one (at {@link #defaultFeatureIndex})
 * being the default (bias) feature.
 *
 * <p>Value and gradient are computed together and cached until a parameter or
 * hyper-parameter changes. No synchronization is performed by this class.
 */
public class MaxEntOptimizableByGE implements Optimizable.ByGradientValue {
    private static final Logger progressLogger =
            MalletProgressMessageLogger.getLogger(MaxEntOptimizableByGE.class.getName() + "-pl");

    /** Column of the default (bias) feature inside each label's row; equals the data alphabet size. */
    protected int defaultFeatureIndex;
    /** Objective value from the most recent full evaluation (constraint term plus prior term). */
    protected double cachedValue;
    /** Gradient from the most recent full evaluation; same layout as {@link #parameters}. */
    protected double[] cachedGradient;
    /** Parameter vector; shared with {@link #classifier} (the same array instance). */
    protected double[] parameters;
    protected InstanceList trainingList;
    protected MaxEnt classifier;
    protected ArrayList<MaxEntGEConstraint> constraints;
    /** True when {@link #cachedValue} / {@link #cachedGradient} must be recomputed. */
    protected boolean cacheStale = true;
    protected double temperature = 1.0;
    protected double objWeight = 1.0;
    protected double gaussianPriorVariance = 1.0;

    /**
     * Creates the objective and lets every constraint pre-process the training list.
     *
     * @param instanceList training instances; only those with a {@code null} target
     *                     contribute to the objective
     * @param arrayList    the constraints to satisfy; each one has
     *                     {@code preProcess(instanceList)} invoked on it here
     * @param maxEnt       classifier to continue training, or {@code null} to start from a
     *                     new classifier with all-zero parameters. When non-null, its
     *                     parameter array is used directly (not copied), so optimization
     *                     updates that classifier in place.
     */
    public MaxEntOptimizableByGE(InstanceList instanceList, ArrayList<MaxEntGEConstraint> arrayList, MaxEnt maxEnt) {
        this.trainingList = instanceList;
        int numFeatures = instanceList.getDataAlphabet().size();
        this.defaultFeatureIndex = numFeatures;
        int numLabels = instanceList.getTargetAlphabet().size();
        this.cachedGradient = new double[(numFeatures + 1) * numLabels];
        this.cachedValue = 0.0;
        if (maxEnt != null) {
            // NOTE(review): assumes maxEnt.parameters has (numFeatures + 1) * numLabels
            // entries, matching cachedGradient - confirm against callers.
            this.parameters = maxEnt.parameters;
            this.classifier = maxEnt;
        } else {
            this.parameters = new double[(numFeatures + 1) * numLabels];
            this.classifier = new MaxEnt(instanceList.getPipe(), this.parameters);
        }
        this.constraints = arrayList;
        for (MaxEntGEConstraint constraint : arrayList) {
            constraint.preProcess(instanceList);
        }
    }

    /**
     * Sets the variance of the Gaussian prior on the parameters and invalidates the cache,
     * since the prior term of the value and gradient depends on it.
     *
     * @param d the new prior variance
     */
    public void setGaussianPriorVariance(double d) {
        this.gaussianPriorVariance = d;
        this.cacheStale = true;
    }

    /**
     * Sets the temperature passed to the classifier when scoring instances and invalidates
     * the cache, since scores and gradient depend on it.
     *
     * @param d the new temperature
     */
    public void setTemperature(double d) {
        this.temperature = d;
        this.cacheStale = true;
    }

    /**
     * Sets the multiplier applied to the constraint term and invalidates the cache.
     * A weight of exactly 0 makes {@link #getValue()} return 0 without evaluating anything.
     *
     * @param d the new objective weight
     */
    public void setWeight(double d) {
        this.objWeight = d;
        this.cacheStale = true;
    }

    /** @return the classifier whose parameters this objective reads and updates */
    public MaxEnt getClassifier() {
        return this.classifier;
    }

    /**
     * Returns the objective value, recomputing value and gradient together if the cache
     * is stale.
     *
     * @return the cached value if still valid; 0 if the objective weight is 0; otherwise
     *         the weighted sum of constraint values plus the Gaussian prior term
     */
    @Override
    public double getValue() {
        if (!this.cacheStale) {
            return this.cachedValue;
        }
        if (this.objWeight == 0.0) {
            // The objective is the constant 0 here, so its gradient is 0 as well. Clear the
            // buffer so getValueGradient() cannot hand out a gradient left over from an
            // earlier evaluation. The cache is deliberately left stale.
            Arrays.fill(this.cachedGradient, 0.0);
            return 0.0;
        }

        for (MaxEntGEConstraint constraint : this.constraints) {
            constraint.zeroExpectations();
        }
        Arrays.fill(this.cachedGradient, 0.0);

        int numInstances = this.trainingList.size();
        // Length of one label's row in the flat parameter/gradient arrays (features + default).
        int rowLength = this.trainingList.getDataAlphabet().size() + 1;
        int numLabels = this.trainingList.getTargetAlphabet().size();
        // Rows of labeled instances are never filled and stay zero.
        double[][] scores = new double[numInstances][numLabels];

        // Pass 1: score every unlabeled instance and let the constraints accumulate
        // their expectations.
        for (int i = 0; i < numInstances; i++) {
            Instance instance = this.trainingList.get(i);
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            if (instance.getTarget() != null) {
                continue;
            }
            FeatureVector fv = (FeatureVector) instance.getData();
            this.classifier.getClassificationScoresWithTemperature(instance, this.temperature, scores[i]);
            for (MaxEntGEConstraint constraint : this.constraints) {
                constraint.computeExpectations(fv, scores[i], instanceWeight);
            }
        }

        double constraintSum = 0.0;
        for (MaxEntGEConstraint constraint : this.constraints) {
            constraintSum += constraint.getValue();
        }
        double geValue = constraintSum * this.objWeight;

        // Pass 2: with the expectations complete, accumulate the gradient.
        double[] constraintValue = new double[numLabels];
        for (int i = 0; i < numInstances; i++) {
            Instance instance = this.trainingList.get(i);
            if (instance.getTarget() != null) {
                continue;
            }
            Arrays.fill(constraintValue, 0.0);
            // Score-weighted sum of the per-label constraint values for this instance.
            double expectedConstraintValue = 0.0;
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            FeatureVector fv = (FeatureVector) instance.getData();

            for (MaxEntGEConstraint constraint : this.constraints) {
                constraint.preProcess(fv);
                for (int label = 0; label < numLabels; label++) {
                    double val = constraint.getCompositeConstraintFeatureValue(fv, label);
                    constraintValue[label] += val;
                    expectedConstraintValue += val * scores[i][label];
                }
            }

            for (int label = 0; label < numLabels; label++) {
                double score = scores[i][label];
                if (score == 0.0) {
                    // A zero score contributes nothing to the gradient.
                    continue;
                }
                assert !Double.isInfinite(score) : "infinite score for instance " + i + ", label " + label;
                double weight = this.objWeight * instanceWeight * score
                        * (constraintValue[label] - expectedConstraintValue) / this.temperature;
                assert !Double.isNaN(weight) : "NaN gradient weight for instance " + i + ", label " + label;
                MatrixOps.rowPlusEquals(this.cachedGradient, rowLength, label, fv, weight);
                // The default feature is not part of fv, so update its slot explicitly.
                this.cachedGradient[rowLength * label + this.defaultFeatureIndex] += weight;
            }
        }

        this.cachedValue = geValue;
        this.cacheStale = false;
        // getRegularization() adds the prior term to cachedValue and cachedGradient as a
        // side effect, so it must run exactly once per evaluation and before cachedValue
        // is read below.
        double regularization = getRegularization();
        progressLogger.info("Value (GE=" + geValue + " Gaussian prior= " + regularization + ") = " + this.cachedValue);
        return this.cachedValue;
    }

    /**
     * Computes the Gaussian prior term, {@code -sum(p * p) / (2 * variance)} over all
     * parameters, and applies it to the cache.
     *
     * <p><b>Side effects:</b> adds the returned term to {@link #cachedValue} and subtracts
     * {@code p / variance} from the matching entry of {@link #cachedGradient}. Calling it
     * more than once per evaluation therefore applies the prior repeatedly.
     *
     * @return the (non-positive, for a positive variance) prior term
     */
    protected double getRegularization() {
        double regularization = 0.0;
        for (int i = 0; i < this.parameters.length; i++) {
            double p = this.parameters[i];
            regularization -= (p * p) / (2.0 * this.gaussianPriorVariance);
            this.cachedGradient[i] -= p / this.gaussianPriorVariance;
        }
        this.cachedValue += regularization;
        return regularization;
    }

    /**
     * Copies the gradient of the objective into {@code dArr}, evaluating the objective
     * first if the cache is stale.
     *
     * @param dArr destination; expected to have as many entries as the gradient
     */
    @Override
    public void getValueGradient(double[] dArr) {
        if (this.cacheStale) {
            getValue();
        }
        assert dArr.length == this.cachedGradient.length
                : "buffer length " + dArr.length + " != gradient length " + this.cachedGradient.length;
        System.arraycopy(this.cachedGradient, 0, dArr, 0, dArr.length);
    }

    /** @return the number of entries in the parameter vector */
    @Override
    public int getNumParameters() {
        return this.parameters.length;
    }

    /**
     * @param i index into the flat parameter vector
     * @return the parameter at {@code i}
     */
    @Override
    public double getParameter(int i) {
        return this.parameters[i];
    }

    /**
     * Copies the current parameters into {@code dArr}.
     *
     * @param dArr destination; expected to have as many entries as the parameter vector
     */
    @Override
    public void getParameters(double[] dArr) {
        assert dArr.length == this.parameters.length
                : "buffer length " + dArr.length + " != parameter count " + this.parameters.length;
        System.arraycopy(this.parameters, 0, dArr, 0, dArr.length);
    }

    /**
     * Sets one parameter and invalidates the cached value and gradient.
     *
     * @param i index into the flat parameter vector
     * @param d the new value
     */
    @Override
    public void setParameter(int i, double d) {
        this.cacheStale = true;
        this.parameters[i] = d;
    }

    /**
     * Overwrites all parameters with the contents of {@code dArr} and invalidates the
     * cached value and gradient. The values are copied into the existing array, so the
     * classifier sharing that array sees the update.
     *
     * @param dArr source; expected to have as many entries as the parameter vector
     */
    @Override
    public void setParameters(double[] dArr) {
        assert dArr.length == this.parameters.length
                : "buffer length " + dArr.length + " != parameter count " + this.parameters.length;
        this.cacheStale = true;
        System.arraycopy(dArr, 0, this.parameters, 0, this.parameters.length);
    }
}
