package cc.mallet.fst.semi_supervised.pr;

import cc.mallet.fst.CRF;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/fst/semi_supervised/pr/CRFOptimizableByKL.class */
/**
 * {@link Optimizable.ByGradientValue} objective for a {@link CRF}, built from per-instance target
 * distributions produced by a {@link SumLatticePR} pass that uses a {@link PRAuxiliaryModel}.
 *
 * <p>At construction time the initial-state, final-state and transition distributions of every
 * training instance are extracted from a {@code SumLatticePR} and stored, and a
 * {@code SumLatticeKL} over those distributions is run with an incrementor bound to
 * {@link #constraints}. Each value evaluation then runs, split over {@link #numThreads} tasks, a
 * {@code SumLatticeKL} and a {@code SumLatticeDefaultCachedDot} per instance and accumulates
 * {@link #expectations}. Judging by the class and lattice names this is presumably the
 * KL-divergence (M-step) objective of posterior-regularization training — confirm against callers.
 *
 * <p>Call {@link #shutdown()} when finished: the thread pool created by the constructor is not shut
 * down otherwise. NOTE(review): {@link ThreadPoolExecutor} does not implement {@link Serializable},
 * so Java serialization of an instance will fail on {@link #executor} — confirm whether
 * serialization is actually used.
 */
public class CRFOptimizableByKL implements Serializable, Optimizable.ByGradientValue {
    private static final Logger logger = MalletLogger.getLogger(CRFOptimizableByKL.class.getName());
    private static final long serialVersionUID = 1;

    /** Number of CRF parameters; length of {@link #cachedGradient}. */
    protected int numParameters;
    /** Number of tasks the training set is split into per value evaluation (and pool size). */
    protected int numThreads;
    /** Positive factor applied to the value; the gradient is scaled by {@code -weight}. */
    protected double weight;
    /** Gradient cached for {@link #cachedGradientWeightsStamp}. */
    protected double[] cachedGradient;
    /** Per instance: exp of the first row of SumLatticePR gammas, normalized with MatrixOps. */
    protected List<double[]> initialProbList;
    /** Per instance: exp of the last row of SumLatticePR gammas, normalized with MatrixOps. */
    protected List<double[]> finalProbList;
    /** Per instance: SumLatticePR xis, exponentiated in place (not normalized). */
    protected List<double[][][]> transitionProbList;
    protected InstanceList trainingSet;
    protected CRF crf;
    /** Accumulated once, in {@link #gatherConstraints}, from SumLatticeKL over the targets. */
    protected CRF.Factors constraints;
    /**
     * Recomputed by {@link #getExpectationValue()}; overwritten in place by
     * {@link #getValueGradient(double[])} while the gradient is being formed.
     */
    protected CRF.Factors expectations;
    /** Fixed-size pool of {@link #numThreads} threads created by the constructor. */
    protected ThreadPoolExecutor executor;
    /** Not assigned anywhere in this class; kept for subclasses/compatibility. */
    protected PRAuxiliaryModel auxModel;
    protected double gaussianPriorVariance = 1.0d;
    /** Sentinel until the first evaluation. */
    protected double cachedValue = -1.23456789E8d;
    protected int cachedValueWeightsStamp = -1;
    protected int cachedGradientWeightsStamp = -1;

    /**
     * Evaluates the training instances with indices in {@code [start, end)}.
     *
     * <p>Expectations are accumulated into a task-private {@link CRF.Factors} copy so tasks can run
     * concurrently on {@link #executor}; {@link #getExpectationValue()} merges the copies afterwards.
     */
    public class ExpectationTask implements Callable<Double> {
        private final int start;
        private final int end;
        private final CRF.Factors expectationsCopy;

        /**
         * @param start first instance index (inclusive)
         * @param end last instance index (exclusive)
         * @param expectationsCopy task-private factors that this task increments
         */
        public ExpectationTask(int start, int end, CRF.Factors expectationsCopy) {
            this.start = start;
            this.end = end;
            this.expectationsCopy = expectationsCopy;
        }

        /** Returns the task-private factors incremented by {@link #call()}. */
        public CRF.Factors getExpectationsCopy() {
            return expectationsCopy;
        }

        /**
         * Returns, summed over this task's instances, the SumLatticeKL total weight minus the
         * SumLatticeDefaultCachedDot total weight; the latter lattice increments the task's
         * expectations copy.
         */
        @Override
        public Double call() throws Exception {
            int numStates = crf.numStates();
            double value = 0.0d;
            for (int ii = start; ii < end; ii++) {
                Sequence input = (Sequence) trainingSet.get(ii).getData();

                // Dot-product cache initialised to -infinity and handed to both lattices below;
                // presumably SumLatticeKL fills it and SumLatticeDefaultCachedDot reuses it — confirm.
                double[][][] cachedDots = new double[input.size()][numStates][numStates];
                for (double[][] position : cachedDots) {
                    for (double[] row : position) {
                        for (int j = 0; j < row.length; j++) {
                            row[j] = Double.NEGATIVE_INFINITY;
                        }
                    }
                }

                double klTotalWeight = new SumLatticeKL(crf, input, initialProbList.get(ii),
                        finalProbList.get(ii), transitionProbList.get(ii), cachedDots, null)
                        .getTotalWeight();
                double defaultTotalWeight = new SumLatticeDefaultCachedDot(crf, input, null,
                        cachedDots, expectationsCopy.new Incrementor(), false, null)
                        .getTotalWeight();
                // Kept as (value + a) - b to reproduce the original floating-point summation order.
                value = value + klTotalWeight - defaultTotalWeight;
            }
            return value;
        }
    }

    /**
     * Builds the objective and precomputes the target distributions and constraints.
     *
     * <p>Calls the overridable {@link #gatherConstraints} before {@link #numThreads} and
     * {@link #executor} are set.
     *
     * @param crf model whose parameters are optimized
     * @param trainingSet instances whose data are {@link Sequence}s
     * @param auxModel auxiliary model passed to {@link SumLatticePR}
     * @param cachedDots per-instance dot caches passed to {@link SumLatticePR}, indexed by instance
     * @param numThreads number of worker threads/tasks; must be positive
     * @param weight scale of the objective; must be positive (checked with {@code assert})
     * @throws IllegalArgumentException if {@code numThreads <= 0}
     */
    public CRFOptimizableByKL(CRF crf, InstanceList trainingSet, PRAuxiliaryModel auxModel,
            double[][][][] cachedDots, int numThreads, double weight) {
        // Fail fast: Executors.newFixedThreadPool would reject this anyway, but only after the
        // expensive constraint gathering below.
        if (numThreads <= 0) {
            throw new IllegalArgumentException("numThreads must be positive: " + numThreads);
        }
        this.crf = crf;
        this.trainingSet = trainingSet;
        this.numParameters = crf.getParameters().getNumFactors();
        this.cachedGradient = new double[this.numParameters];
        assert weight > 0.0d : "weight must be positive: " + weight;
        this.weight = weight;
        gatherConstraints(auxModel, cachedDots);
        this.numThreads = numThreads;
        this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads);
    }

    /**
     * Returns a new array with {@code exp} applied to each entry, normalized with
     * {@link MatrixOps#normalize} (presumably to sum to 1).
     */
    private double[] toProbabilities(double[] logValues) {
        double[] probs = new double[logValues.length];
        for (int i = 0; i < logValues.length; i++) {
            probs[i] = Math.exp(logValues[i]);
        }
        MatrixOps.normalize(probs);
        return probs;
    }

    /** Applies {@code exp} to every entry of {@code values} in place; does not normalize. */
    private void toProbabilities(double[][][] values) {
        for (double[][] plane : values) {
            for (double[] row : plane) {
                for (int k = 0; k < row.length; k++) {
                    row[k] = Math.exp(row[k]);
                }
            }
        }
    }

    /**
     * Runs a {@link SumLatticePR} over every training instance, stores the resulting initial, final
     * and transition distributions, and accumulates {@link #constraints} from a SumLatticeKL over
     * them. Also allocates {@link #expectations}.
     *
     * @param auxModel auxiliary model passed to {@link SumLatticePR}
     * @param cachedDots per-instance dot caches, indexed by instance
     */
    protected void gatherConstraints(PRAuxiliaryModel auxModel, double[][][][] cachedDots) {
        this.initialProbList = new ArrayList<double[]>();
        this.finalProbList = new ArrayList<double[]>();
        this.transitionProbList = new ArrayList<double[][][]>();
        this.constraints = new CRF.Factors(crf.getParameters());
        this.expectations = new CRF.Factors(crf.getParameters());
        this.constraints.zero();

        for (int ii = 0; ii < trainingSet.size(); ii++) {
            Sequence input = (Sequence) trainingSet.get(ii).getData();
            SumLatticePR prLattice = new SumLatticePR(crf, ii, input, null, auxModel,
                    cachedDots[ii], false, null, null, true);

            double[][] gammas = prLattice.getGammas();
            double[] initProbs = toProbabilities(gammas[0]);
            initialProbList.add(initProbs);
            double[] finalProbs = toProbabilities(gammas[gammas.length - 1]);
            finalProbList.add(finalProbs);

            double[][][] xis = prLattice.getXis();
            toProbabilities(xis);
            transitionProbList.add(xis);

            // The lattice object is discarded: it is built only for its side effect through the
            // Incrementor, which is bound to the constraints factors.
            new SumLatticeKL(crf, input, initProbs, finalProbs, xis, null,
                    constraints.new Incrementor());
        }
    }

    /**
     * Recomputes {@link #expectations} and returns the summed task values.
     *
     * <p>The training set is split into {@link #numThreads} contiguous chunks of
     * {@code size / numThreads} instances, the last chunk absorbing the remainder (so with fewer
     * instances than threads, the last task gets all of them). Chunks are evaluated on
     * {@link #executor} and each task's expectations copy is then added into
     * {@link #expectations}.
     *
     * @throws IllegalStateException if a task fails, or if the calling thread is interrupted (its
     *     interrupt status is restored); a partial result is never returned
     */
    protected double getExpectationValue() {
        expectations.zero();
        int numInstances = trainingSet.size();
        int chunkSize = numInstances / numThreads;

        List<ExpectationTask> tasks = new ArrayList<ExpectationTask>(numThreads);
        for (int t = 0; t < numThreads; t++) {
            int start = t * chunkSize;
            int end = (t == numThreads - 1) ? numInstances : start + chunkSize;
            tasks.add(new ExpectationTask(start, end, new CRF.Factors(expectations)));
        }

        double value = 0.0d;
        try {
            for (Future<Double> future : executor.invokeAll(tasks)) {
                value += future.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while computing expectations", e);
        } catch (ExecutionException e) {
            // A failed chunk would leave the value and the merged expectations inconsistent, so
            // surface the failure instead of optimizing against a corrupted objective.
            throw new IllegalStateException("Expectation task failed", e);
        }

        for (ExpectationTask task : tasks) {
            expectations.plusEquals(task.getExpectationsCopy(), 1.0d);
        }
        return value;
    }

    /**
     * Returns {@code weight * (getExpectationValue() + gaussianPrior)}, recomputed only when the
     * CRF's weights-value change stamp differs from the one cached.
     *
     * <p>The stamp is recorded only after a successful evaluation, so a failure triggers a fresh
     * computation on the next call. This also refreshes {@link #expectations}.
     */
    @Override
    public double getValue() {
        int stamp = crf.getWeightsValueChangeStamp();
        if (stamp != cachedValueWeightsStamp) {
            long startMillis = System.currentTimeMillis();
            double value = getExpectationValue();
            double gaussianPrior = crf.getParameters().gaussianPrior(gaussianPriorVariance);
            value += gaussianPrior;
            logger.info("Gaussian prior = " + gaussianPrior);
            value *= weight;
            assert !Double.isNaN(value) && !Double.isInfinite(value)
                    : "Label likelihood is NaN/Infinite";
            cachedValue = value;
            cachedValueWeightsStamp = stamp;
            logger.info("getValue() (loglikelihood, optimizable by klDiv) = " + cachedValue);
            logger.fine("Inference milliseconds = " + (System.currentTimeMillis() - startMillis));
        }
        return cachedValue;
    }

    /**
     * Copies into {@code buffer} the gradient
     * {@code -weight * (expectations - constraints + priorGradient)}, recomputed only when the
     * CRF's weights-value change stamp differs from the one cached.
     *
     * @param buffer destination; must hold at least {@link #getNumParameters()} entries
     */
    @Override
    public void getValueGradient(double[] buffer) {
        int stamp = crf.getWeightsValueChangeStamp();
        if (stamp != cachedGradientWeightsStamp) {
            // Make sure expectations match the current weights. Done before recording the stamp
            // so that a failure here is retried on the next call.
            getValue();
            // expectations is modified in place below; recording the stamp now guarantees the
            // constraints are subtracted at most once per weight setting.
            cachedGradientWeightsStamp = stamp;
            expectations.plusEquals(constraints, -1.0d);
            expectations.plusEqualsGaussianPriorGradient(crf.getParameters(), -gaussianPriorVariance);
            expectations.assertNotNaNOrInfinite();
            expectations.getParameters(cachedGradient);
            MatrixOps.timesEquals(cachedGradient, -weight);
        }
        System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length);
    }

    @Override
    public int getNumParameters() {
        return numParameters;
    }

    /** Copies the CRF's current parameters into {@code buffer}. */
    @Override
    public void getParameters(double[] buffer) {
        crf.getParameters().getParameters(buffer);
    }

    @Override
    public double getParameter(int index) {
        return crf.getParameters().getParameter(index);
    }

    /** Sets all CRF parameters and notifies the CRF so cached values are invalidated. */
    @Override
    public void setParameters(double[] params) {
        crf.getParameters().setParameters(params);
        crf.weightsValueChanged();
    }

    /** Sets one CRF parameter and notifies the CRF so cached values are invalidated. */
    @Override
    public void setParameter(int index, double value) {
        crf.getParameters().setParameter(index, value);
        crf.weightsValueChanged();
    }

    /** Sets the variance passed to the Gaussian prior and its gradient (default 1.0). */
    public void setGaussianPriorVariance(double variance) {
        this.gaussianPriorVariance = variance;
    }

    /** Shuts down the thread pool created by the constructor; no further evaluations are possible. */
    public void shutdown() {
        executor.shutdown();
    }
}
