package de.jungblut.ner;

import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.ViterbiUtils;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.DenseMatrixFolder;
import de.jungblut.math.minimize.Minimizer;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import java.util.Collections;

/* loaded from: input_file:de/jungblut/ner/MaxEntMarkovModel.class */
/**
 * Maximum-entropy Markov model classifier.
 *
 * <p>{@link #train} fits a weight matrix {@code theta} by minimizing a
 * {@link ConditionalLikelihoodCostFunction} with the configured {@link Minimizer};
 * the prediction methods delegate to {@link ViterbiUtils#decode} with the learned
 * {@code theta} and class count.
 */
public final class MaxEntMarkovModel extends AbstractClassifier {
    /** Optimizer used by {@link #train}; {@code null} for the theta-only constructor. */
    private final Minimizer minimizer;
    /** Verbosity flag forwarded to the minimizer. */
    private final boolean verbose;
    /** Iteration budget forwarded to the minimizer. */
    private final int numIterations;
    /** Learned weights; {@code null} until trained or supplied via the constructor. */
    private DoubleMatrix theta;
    /** Number of output classes; a one-dimensional outcome is treated as 2 classes. */
    private int classes;

    /**
     * Creates a trainable model.
     *
     * @param minimizer the optimizer used to minimize the conditional likelihood cost.
     * @param numIterations the iteration budget passed to {@code minimizer}.
     * @param verbose the verbosity flag passed to {@code minimizer}.
     */
    public MaxEntMarkovModel(Minimizer minimizer, int numIterations, boolean verbose) {
        this.minimizer = minimizer;
        this.numIterations = numIterations;
        this.verbose = verbose;
    }

    /**
     * Creates a model from an existing weight matrix. Such a model has no
     * minimizer, so it can be used for prediction but not for {@link #train}.
     *
     * @param theta the weight matrix used for decoding.
     * @param classes the number of output classes.
     */
    public MaxEntMarkovModel(DenseDoubleMatrix theta, int classes) {
        this(null, -1, false);
        this.theta = theta;
        this.classes = classes;
    }

    /**
     * Fits {@code theta} to the given feature/outcome pairs.
     *
     * @param features one feature vector per example; sparseness of the first
     *     vector decides whether a sparse or dense feature matrix is built.
     * @param outcome one outcome vector per example.
     * @throws IllegalArgumentException if the arrays are empty or differ in length.
     * @throws IllegalStateException if this model was built without a minimizer.
     */
    @Override
    public void train(DoubleVector[] features, DoubleVector[] outcome) {
        Preconditions.checkArgument(features.length == outcome.length && features.length > 0, "There wasn't at least a single featurevector, or the two array didn't match in size.");
        Preconditions.checkState(this.minimizer != null,
                "This model was constructed without a Minimizer and cannot be trained.");
        int outcomeDimension = outcome[0].getDimension();
        // A single output column encodes a binary problem.
        this.classes = outcomeDimension == 1 ? 2 : outcomeDimension;

        // Declared as the common interface so either representation can be held.
        DoubleMatrix featureMatrix;
        if (features[0].isSparse()) {
            featureMatrix = new SparseDoubleRowMatrix(features);
        } else {
            featureMatrix = new DenseDoubleMatrix(features);
        }

        ConditionalLikelihoodCostFunction costFunction =
                new ConditionalLikelihoodCostFunction(featureMatrix, new DenseDoubleMatrix(outcome));
        // One weight per (feature column, class) pair, starting from zero.
        DoubleVector initialWeights = new DenseDoubleVector(featureMatrix.getColumnCount() * this.classes);
        DoubleVector learnedWeights =
                this.minimizer.minimize(costFunction, initialWeights, this.numIterations, this.verbose);
        this.theta = DenseMatrixFolder.unfoldMatrix(learnedWeights, this.classes,
                (int) (learnedWeights.getLength() / this.classes));
    }

    /**
     * Predicts from an {@link UnrollableDoubleVector} by splitting it into its main
     * vector and side vectors.
     *
     * @param features must be an {@link UnrollableDoubleVector}.
     * @return the first row of the decoded result.
     * @throws IllegalArgumentException if {@code features} is not an UnrollableDoubleVector.
     * @throws IllegalStateException if the model has no weights yet.
     */
    @Override
    public DoubleVector predict(DoubleVector features) {
        Preconditions.checkArgument(features instanceof UnrollableDoubleVector, "Features must be an instance of the class UnrollableDoubleVector.");
        UnrollableDoubleVector unrollable = (UnrollableDoubleVector) features;
        return predict(unrollable.getMainVector(), unrollable.getSideVectors());
    }

    /**
     * Returns the current weight matrix (not a copy).
     *
     * @return the learned or supplied weights, or {@code null} if none are set.
     */
    public DoubleMatrix getTheta() {
        return this.theta;
    }

    /**
     * Decodes a single main vector together with its side vectors.
     *
     * @param mainVector the main feature vector, wrapped as a one-row matrix.
     * @param sideVectors the side feature vectors, one matrix row each.
     * @return row 0 of the matrix returned by {@link ViterbiUtils#decode}.
     * @throws IllegalStateException if the model has no weights yet.
     */
    public DoubleVector predict(DoubleVector mainVector, DoubleVector[] sideVectors) {
        checkHasTheta();
        SparseDoubleRowMatrix mainMatrix = new SparseDoubleRowMatrix(Collections.singletonList(mainVector));
        SparseDoubleRowMatrix sideMatrix = new SparseDoubleRowMatrix(sideVectors);
        return ViterbiUtils.decode(this.theta, mainMatrix, sideMatrix, this.classes).getRowVector(0);
    }

    /**
     * Decodes a matrix of main features together with a matrix of side features.
     *
     * @param mainFeatures the main feature matrix.
     * @param sideFeatures the side feature matrix.
     * @return the matrix returned by {@link ViterbiUtils#decode}.
     * @throws IllegalStateException if the model has no weights yet.
     */
    public DoubleMatrix predict(DoubleMatrix mainFeatures, DoubleMatrix sideFeatures) {
        checkHasTheta();
        return ViterbiUtils.decode(this.theta, mainFeatures, sideFeatures, this.classes);
    }

    /** Fails fast instead of handing a null weight matrix to the decoder. */
    private void checkHasTheta() {
        Preconditions.checkState(this.theta != null,
                "The model has no weights; train it or construct it with a theta matrix.");
    }
}
