package cc.mallet.fst;

import cc.mallet.fst.TransducerTrainer;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.BitSet;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/fst/CRFTrainerByValueGradients.class */
public class CRFTrainerByValueGradients extends TransducerTrainer implements TransducerTrainer.ByOptimization {
    private static Logger logger;
    CRF crf;
    Optimizable.ByGradientValue[] optimizableByValueGradientObjects;
    OptimizableCRF ocrf;
    Optimizer opt;
    boolean converged;
    public static final int DEFAULT_MAX_RESETS = 3;
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;
    static final int NULL_INTEGER = -1;
    static final /* synthetic */ boolean $assertionsDisabled;
    int iterationCount = 0;
    private int cachedValueWeightsStamp = -1;
    private int cachedGradientWeightsStamp = -1;
    int maxResets = 3;

    /* loaded from: input_file:cc/mallet/fst/CRFTrainerByValueGradients$OptimizableCRF.class */
    public class OptimizableCRF implements Optimizable.ByGradientValue, Serializable {
        InstanceList trainingSet;
        double[] cachedGradie;
        CRF crf;
        Optimizable.ByGradientValue[] opts;
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 0;
        double cachedValue = -1.23456789E8d;
        BitSet infiniteValues = null;

        protected OptimizableCRF(CRF crf, InstanceList instanceList) {
            this.crf = crf;
            this.trainingSet = instanceList;
            this.opts = CRFTrainerByValueGradients.this.optimizableByValueGradientObjects;
            this.cachedGradie = new double[crf.parameters.getNumFactors()];
            CRFTrainerByValueGradients.this.cachedValueWeightsStamp = -1;
            CRFTrainerByValueGradients.this.cachedGradientWeightsStamp = -1;
        }

        @Override // cc.mallet.optimize.Optimizable
        public int getNumParameters() {
            return this.crf.parameters.getNumFactors();
        }

        @Override // cc.mallet.optimize.Optimizable
        public void getParameters(double[] dArr) {
            this.crf.parameters.getParameters(dArr);
        }

        @Override // cc.mallet.optimize.Optimizable
        public double getParameter(int i) {
            return this.crf.parameters.getParameter(i);
        }

        @Override // cc.mallet.optimize.Optimizable
        public void setParameters(double[] dArr) {
            this.crf.parameters.setParameters(dArr);
            this.crf.weightsValueChanged();
        }

        @Override // cc.mallet.optimize.Optimizable
        public void setParameter(int i, double d) {
            this.crf.parameters.setParameter(i, d);
            this.crf.weightsValueChanged();
        }

        @Override // cc.mallet.optimize.Optimizable.ByGradientValue
        public double getValue() {
            if (this.crf.weightsValueChangeStamp != CRFTrainerByValueGradients.this.cachedValueWeightsStamp) {
                long currentTimeMillis = System.currentTimeMillis();
                this.cachedValue = 0.0d;
                for (int i = 0; i < this.opts.length; i++) {
                    this.cachedValue += this.opts[i].getValue();
                }
                CRFTrainerByValueGradients.this.cachedValueWeightsStamp = this.crf.weightsValueChangeStamp;
                CRFTrainerByValueGradients.logger.info("getValue() (loglikelihood) = " + this.cachedValue);
                CRFTrainerByValueGradients.logger.fine("Inference milliseconds = " + (System.currentTimeMillis() - currentTimeMillis));
            }
            return this.cachedValue;
        }

        @Override // cc.mallet.optimize.Optimizable.ByGradientValue
        public void getValueGradient(double[] dArr) {
            if (CRFTrainerByValueGradients.this.cachedGradientWeightsStamp != this.crf.weightsValueChangeStamp) {
                getValue();
                MatrixOps.setAll(this.cachedGradie, 0.0d);
                double[] dArr2 = new double[dArr.length];
                for (int i = 0; i < this.opts.length; i++) {
                    MatrixOps.setAll(dArr2, 0.0d);
                    this.opts[i].getValueGradient(dArr2);
                    MatrixOps.plusEquals(this.cachedGradie, dArr2);
                }
                CRFTrainerByValueGradients.this.cachedGradientWeightsStamp = this.crf.weightsValueChangeStamp;
            }
            System.arraycopy(this.cachedGradie, 0, dArr, 0, this.cachedGradie.length);
        }

        private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
            objectOutputStream.writeInt(0);
            objectOutputStream.writeObject(this.trainingSet);
            objectOutputStream.writeDouble(this.cachedValue);
            objectOutputStream.writeObject(this.cachedGradie);
            objectOutputStream.writeObject(this.infiniteValues);
            objectOutputStream.writeObject(this.crf);
        }

        private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
            objectInputStream.readInt();
            this.trainingSet = (InstanceList) objectInputStream.readObject();
            this.cachedValue = objectInputStream.readDouble();
            this.cachedGradie = (double[]) objectInputStream.readObject();
            this.infiniteValues = (BitSet) objectInputStream.readObject();
            this.crf = (CRF) objectInputStream.readObject();
        }
    }

    static {
        $assertionsDisabled = !CRFTrainerByValueGradients.class.desiredAssertionStatus();
        logger = MalletLogger.getLogger(CRFTrainerByLabelLikelihood.class.getName());
    }

    public CRFTrainerByValueGradients(CRF crf, Optimizable.ByGradientValue[] byGradientValueArr) {
        this.crf = crf;
        this.optimizableByValueGradientObjects = byGradientValueArr;
    }

    @Override // cc.mallet.fst.TransducerTrainer
    public Transducer getTransducer() {
        return this.crf;
    }

    public CRF getCRF() {
        return this.crf;
    }

    @Override // cc.mallet.fst.TransducerTrainer.ByOptimization
    public Optimizer getOptimizer() {
        return this.opt;
    }

    public boolean isConverged() {
        return this.converged;
    }

    @Override // cc.mallet.fst.TransducerTrainer
    public boolean isFinishedTraining() {
        return this.converged;
    }

    @Override // cc.mallet.fst.TransducerTrainer
    public int getIteration() {
        return this.iterationCount;
    }

    public Optimizable.ByGradientValue[] getOptimizableByGradientValueObjects() {
        return this.optimizableByValueGradientObjects;
    }

    public OptimizableCRF getOptimizableCRF(InstanceList instanceList) {
        if (this.ocrf == null || this.ocrf.trainingSet != instanceList) {
            this.ocrf = new OptimizableCRF(this.crf, instanceList);
            this.opt = null;
        }
        return this.ocrf;
    }

    public Optimizer getOptimizer(InstanceList instanceList) {
        getOptimizableCRF(instanceList);
        if (this.opt == null || this.ocrf != this.opt.getOptimizable()) {
            this.opt = new LimitedMemoryBFGS(this.ocrf);
        }
        return this.opt;
    }

    public boolean trainIncremental(InstanceList instanceList) {
        return train(instanceList, Integer.MAX_VALUE);
    }

    @Override // cc.mallet.fst.TransducerTrainer
    public boolean train(InstanceList instanceList, int i) {
        if (i <= 0) {
            return false;
        }
        if (!$assertionsDisabled && instanceList.size() <= 0) {
            throw new AssertionError();
        }
        getOptimizableCRF(instanceList);
        getOptimizer(instanceList);
        int i2 = 0;
        boolean z = false;
        logger.info("CRF about to train with " + i + " iterations");
        int i3 = 0;
        while (true) {
            if (i3 >= i) {
                break;
            }
            try {
                long currentTimeMillis = System.currentTimeMillis();
                z = this.opt.optimize(1);
                logger.info("CRF finished one iteration of maximizer, i=" + i3 + ", " + ((System.currentTimeMillis() - currentTimeMillis) / 1000) + " secs.");
                this.iterationCount++;
                runEvaluators();
            } catch (OptimizationException e) {
                e.printStackTrace();
                logger.info("Catching exception.");
                if (i2 < this.maxResets) {
                    logger.info("Resetting optimizer.");
                    i2++;
                    this.opt = null;
                    getOptimizer(instanceList);
                } else {
                    logger.info("Saying converged.");
                    z = true;
                }
            }
            if (z) {
                logger.info("CRF training has converged, i=" + i3);
                break;
            }
            i3++;
        }
        return z;
    }

    public boolean train(InstanceList instanceList, int i, double[] dArr) {
        int i2 = 0;
        if (!$assertionsDisabled && dArr.length <= 0) {
            throw new AssertionError();
        }
        boolean z = false;
        for (int i3 = 0; i3 < dArr.length; i3++) {
            if (!$assertionsDisabled && dArr[i3] > 1.0d) {
                throw new AssertionError();
            }
            logger.info("Training on " + dArr[i3] + "% of the data this round.");
            z = dArr[i3] == 1.0d ? train(instanceList, i) : train(instanceList.split(new Random(serialVersionUID), new double[]{dArr[i3], 1.0d - dArr[i3]})[0], i);
            i2 += i;
        }
        return z;
    }

    public void setMaxResets(int i) {
        this.maxResets = i;
    }

    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeInt(1);
        objectOutputStream.writeInt(this.cachedGradientWeightsStamp);
        objectOutputStream.writeInt(this.cachedValueWeightsStamp);
        throw new IllegalStateException("Implementation not yet complete.");
    }

    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.readInt();
        throw new IllegalStateException("Implementation not yet complete.");
    }
}
