package org.deeplearning4j.nn;

import java.io.Serializable;
import org.deeplearning4j.optimize.LogisticRegressionOptimizer;
import org.deeplearning4j.util.MatrixUtil;
import org.deeplearning4j.util.NonZeroStoppingConjugateGradient;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/* loaded from: input_file:org/deeplearning4j/nn/LogisticRegression.class */
/**
 * Logistic-regression layer backed by jblas {@link DoubleMatrix} parameters.
 *
 * <p>The weight matrix {@code W} is initialised to zeros of shape {@code nIn x nOut} and the bias
 * {@code b} to a zero vector of length {@code nOut}. The most recently supplied training input and
 * labels are kept on the instance and used by the no-argument training/score methods. Every
 * instance method that reads or writes model state is {@code synchronized} on the instance.
 */
public class LogisticRegression implements Serializable {
    private static final long serialVersionUID = -7065564817460914364L;
    private int nIn;
    private int nOut;
    private DoubleMatrix input;
    private DoubleMatrix labels;
    private DoubleMatrix W;
    private DoubleMatrix b;
    private double l2;
    private boolean useRegularization;

    /**
     * Fluent builder for {@link LogisticRegression}. Weights and bias that are not supplied keep the
     * zero-initialised defaults created by the constructor.
     */
    public static class Builder {
        private DoubleMatrix W;
        private DoubleMatrix b;
        private int nIn;
        private int nOut;
        private DoubleMatrix input;

        /** Uses {@code weights} as the initial weight matrix instead of zeros. */
        public Builder withWeights(DoubleMatrix weights) {
            this.W = weights;
            return this;
        }

        /** Uses {@code bias} as the initial bias vector instead of zeros. */
        public Builder withBias(DoubleMatrix bias) {
            this.b = bias;
            return this;
        }

        /** Sets the training input stored on the built model (previously unreachable). */
        public Builder withInput(DoubleMatrix input) {
            this.input = input;
            return this;
        }

        /** Sets the number of input units ({@code nIn}). */
        public Builder numberOfInputs(int nIn) {
            this.nIn = nIn;
            return this;
        }

        /** Sets the number of output units ({@code nOut}). */
        public Builder numberOfOutputs(int nOut) {
            this.nOut = nOut;
            return this;
        }

        /** Creates a new model from the configured values; each call returns a fresh instance. */
        public LogisticRegression build() {
            LogisticRegression model = new LogisticRegression(this.input, this.nIn, this.nOut);
            if (this.W != null) {
                model.W = this.W;
            }
            if (this.b != null) {
                model.b = this.b;
            }
            return model;
        }
    }

    /** Bare instance with default hyper-parameters; used by {@link #m15clone()}. */
    private LogisticRegression() {
        this.l2 = 0.01d;
        this.useRegularization = true;
    }

    /**
     * Creates a model with zero-initialised parameters.
     *
     * @param input training input (may be {@code null})
     * @param labels training labels (may be {@code null})
     * @param nIn number of input units (rows of {@code W})
     * @param nOut number of output units (columns of {@code W}, length of {@code b})
     */
    public LogisticRegression(DoubleMatrix input, DoubleMatrix labels, int nIn, int nOut) {
        this.l2 = 0.01d;
        this.useRegularization = true;
        this.input = input;
        this.labels = labels;
        this.nIn = nIn;
        this.nOut = nOut;
        this.W = DoubleMatrix.zeros(nIn, nOut);
        this.b = DoubleMatrix.zeros(nOut);
    }

    /** Creates a model with the given input and no labels. */
    public LogisticRegression(DoubleMatrix input, int nIn, int nOut) {
        this(input, null, nIn, nOut);
    }

    /** Creates a model with neither input nor labels set. */
    public LogisticRegression(int nIn, int nOut) {
        this(null, null, nIn, nOut);
    }

    /** Performs one gradient step on the stored input and labels. */
    public synchronized void train(double learningRate) {
        train(this.input, this.labels, learningRate);
    }

    /** Performs one gradient step on {@code x} against the stored labels. */
    public synchronized void train(DoubleMatrix x, double learningRate) {
        // The three-argument overload already validates x against the labels.
        train(x, this.labels, learningRate);
    }

    /**
     * Stores {@code x}/{@code y} as the training data and then runs
     * {@link #trainTillConvergence(double, int)}.
     */
    public synchronized void trainTillConvergence(DoubleMatrix x, DoubleMatrix y, double learningRate, int numIterations) {
        MatrixUtil.complainAboutMissMatchedMatrices(x, y);
        this.input = x;
        this.labels = y;
        trainTillConvergence(learningRate, numIterations);
    }

    /**
     * Optimises the stored training data with a {@link NonZeroStoppingConjugateGradient} wrapped
     * around a {@link LogisticRegressionOptimizer}.
     *
     * @param numIterations forwarded to {@code optimize(int)} (presumably an iteration budget)
     */
    public synchronized void trainTillConvergence(double learningRate, int numIterations) {
        new NonZeroStoppingConjugateGradient(new LogisticRegressionOptimizer(this, learningRate)).optimize(numIterations);
    }

    /**
     * Moves this model's parameters toward {@code other}'s by {@code (other - this) / batchSize}.
     * The other model is not modified.
     *
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public synchronized void merge(LogisticRegression other, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive but was " + batchSize);
        }
        // sub (not subi): subi would overwrite the other model's parameters in place.
        this.W.addi(other.W.sub(this.W).divi(batchSize));
        this.b.addi(other.b.sub(this.b).divi(batchSize));
    }

    /**
     * Scores the stored input/labels.
     *
     * <p>The cross-entropy term is {@code -(y*log(p) + (1-y)*log(1-p))}, where {@code p} is the
     * softmax activation. It is summed per column and then averaged over the columns. When
     * regularization is enabled, {@code (l2 / 2) * sum(W^2)} is added.
     */
    public synchronized double negativeLogLikelihood() {
        MatrixUtil.complainAboutMissMatchedMatrices(this.input, this.labels);
        DoubleMatrix activation = MatrixUtil.softmax(this.input.mmul(this.W).addRowVector(this.b));
        double crossEntropy = -this.labels.mul(MatrixUtil.log(activation))
                .add(MatrixUtil.oneMinus(this.labels).mul(MatrixUtil.log(MatrixUtil.oneMinus(activation))))
                .columnSums()
                .mean();
        if (!this.useRegularization) {
            return crossEntropy;
        }
        // Standard L2 weight decay is (lambda / 2) * ||W||^2; the previous (2 / lambda) form
        // produced a penalty that grew as l2 shrank (200x at the default 0.01).
        return crossEntropy + (this.l2 / 2.0d) * MatrixFunctions.pow(this.W, 2.0d).sum();
    }

    /**
     * Stores {@code x}/{@code y} as the training data and applies one gradient step to
     * {@code W} and {@code b}.
     */
    public synchronized void train(DoubleMatrix x, DoubleMatrix y, double learningRate) {
        MatrixUtil.complainAboutMissMatchedMatrices(x, y);
        this.input = x;
        this.labels = y;
        LogisticRegressionGradient gradient = getGradient(learningRate);
        this.W.addi(gradient.getwGradient());
        // getGradient hands back the per-example error matrix (rows x nOut) as the bias gradient.
        // Reduce it over the examples, exactly as x^T * error does for W, and scale by the learning
        // rate. Adding it raw fails on length mismatch for batches of more than one row.
        this.b.addi(gradient.getbGradient().columnSums().mul(learningRate));
    }

    /**
     * Returns a deep copy of this model. Null input, labels or parameters stay null in the copy.
     * (JADX renamed this from the covariant {@code clone()} bridge method.)
     */
    public synchronized LogisticRegression m15clone() {
        LogisticRegression copy = new LogisticRegression();
        copy.nIn = this.nIn;
        copy.nOut = this.nOut;
        copy.l2 = this.l2;
        copy.useRegularization = this.useRegularization;
        copy.W = dupOrNull(this.W);
        copy.b = dupOrNull(this.b);
        copy.input = dupOrNull(this.input);
        copy.labels = dupOrNull(this.labels);
        return copy;
    }

    /** Copies {@code m}, or returns {@code null} when {@code m} is {@code null}. */
    private static DoubleMatrix dupOrNull(DoubleMatrix m) {
        return m == null ? null : m.dup();
    }

    /**
     * Computes the error {@code labels - sigmoid(input * W + b)} on the stored data.
     *
     * <p>Despite its name, {@code useRegularization} only divides that error by the number of
     * input rows here. The returned weight gradient is {@code input^T * error * learningRate}; the
     * bias gradient is the unscaled error matrix itself.
     *
     * <p>NOTE(review): {@link #predict} and {@link #negativeLogLikelihood} use softmax, while this
     * uses sigmoid. Confirm which activation the gradient is meant to follow.
     */
    public synchronized LogisticRegressionGradient getGradient(double learningRate) {
        MatrixUtil.complainAboutMissMatchedMatrices(this.input, this.labels);
        DoubleMatrix error = this.labels.sub(MatrixUtil.sigmoid(this.input.mmul(this.W).addRowVector(this.b)));
        if (this.useRegularization) {
            error.divi(this.input.rows);
        }
        return new LogisticRegressionGradient(this.input.transpose().mmul(error).mul(learningRate), error);
    }

    /** Returns {@code softmax(x * W + b)}, one row per row of {@code x}. */
    public synchronized DoubleMatrix predict(DoubleMatrix x) {
        return MatrixUtil.softmax(x.mmul(this.W).addRowVector(this.b));
    }

    public synchronized int getnIn() {
        return this.nIn;
    }

    public synchronized void setnIn(int nIn) {
        this.nIn = nIn;
    }

    public synchronized int getnOut() {
        return this.nOut;
    }

    public synchronized void setnOut(int nOut) {
        this.nOut = nOut;
    }

    public synchronized DoubleMatrix getInput() {
        return this.input;
    }

    public synchronized void setInput(DoubleMatrix input) {
        this.input = input;
    }

    public synchronized DoubleMatrix getLabels() {
        return this.labels;
    }

    public synchronized void setLabels(DoubleMatrix labels) {
        this.labels = labels;
    }

    /** Returns the live weight matrix (not a copy). */
    public synchronized DoubleMatrix getW() {
        return this.W;
    }

    public synchronized void setW(DoubleMatrix w) {
        this.W = w;
    }

    /** Returns the live bias vector (not a copy). */
    public synchronized DoubleMatrix getB() {
        return this.b;
    }

    public synchronized void setB(DoubleMatrix b) {
        this.b = b;
    }

    public synchronized double getL2() {
        return this.l2;
    }

    public synchronized void setL2(double l2) {
        this.l2 = l2;
    }

    public synchronized boolean isUseRegularization() {
        return this.useRegularization;
    }

    public synchronized void setUseRegularization(boolean useRegularization) {
        this.useRegularization = useRegularization;
    }
}
