package org.nd4j.linalg.solvers;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.solvers.api.IterationListener;
import org.nd4j.linalg.solvers.api.OptimizableByGradientValueMatrix;
import org.nd4j.linalg.solvers.api.OptimizerMatrix;
import org.nd4j.linalg.solvers.api.TrainingEvaluator;
import org.nd4j.linalg.solvers.exception.InvalidStepException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Gradient-ascent optimizer that, on each iteration, rescales the current gradient if its 2-norm
 * exceeds {@link #getStpmax()}, asks a {@link VectorizedBackTrackLineSearch} to take a step along
 * it, and declares convergence once the relative change in the objective value drops below the
 * configured tolerance.
 *
 * <p>Instances carry mutable state across calls (the current step size and the convergence flag)
 * and perform no synchronization.
 */
public class VectorizedDeepLearningGradientAscent implements OptimizerMatrix {
    /** Default starting step size for the line search (numerically {@code (double) 0.2f}). */
    static final double initialStepSize = 0.20000000298023224d;

    private static final Logger logger = LoggerFactory.getLogger(VectorizedDeepLearningGradientAscent.class);

    /** Added to the convergence-test denominator so it is strictly positive even when both values are 0. */
    final double eps = 1.000000013351432E-10d;

    private IterationListener listener;
    boolean converged = false;
    OptimizableByGradientValueMatrix optimizable;
    /** Only reported in the per-iteration log line; it does not constrain steps within this class. */
    private double maxStep = 1.0d;
    double tolerance = 9.999999747378752E-6d;
    int maxIterations = 200;
    VectorizedBackTrackLineSearch lineMaximizer;
    /** Gradients whose 2-norm exceeds this value are rescaled in place down to this norm. */
    double stpmax = 100.0d;
    /** Current line-search step size; updated after every successful line search. */
    double step = initialStepSize;
    /** Optional early-stopping hook consulted after each completed iteration; may be null. */
    TrainingEvaluator eval;

    /**
     * Creates an optimizer for the given objective, backed by a fresh backtracking line search whose
     * absolute x-tolerance is set to the current {@code tolerance}.
     *
     * @param optimizableByGradientValueMatrix the objective to maximize
     * @param d currently unused. NOTE(review): the other constructors pass a step-size-like value
     *     here, but the initial step always starts at {@link #initialStepSize} — confirm intent.
     */
    public VectorizedDeepLearningGradientAscent(OptimizableByGradientValueMatrix optimizableByGradientValueMatrix, double d) {
        this.optimizable = optimizableByGradientValueMatrix;
        this.lineMaximizer = new VectorizedBackTrackLineSearch(optimizableByGradientValueMatrix);
        this.lineMaximizer.setAbsTolx(this.tolerance);
    }

    /**
     * Creates an optimizer that reports each completed iteration to {@code iterationListener}.
     *
     * @param optimizableByGradientValueMatrix the objective to maximize
     * @param iterationListener listener notified via {@code iterationDone}; may be null
     */
    public VectorizedDeepLearningGradientAscent(OptimizableByGradientValueMatrix optimizableByGradientValueMatrix, IterationListener iterationListener) {
        this(optimizableByGradientValueMatrix, 0.009999999776482582d);
        this.listener = iterationListener;
    }

    /**
     * Creates an optimizer with an iteration listener.
     *
     * @param optimizableByGradientValueMatrix the objective to maximize
     * @param d forwarded to the primary constructor (currently unused there)
     * @param iterationListener listener notified via {@code iterationDone}; may be null
     */
    public VectorizedDeepLearningGradientAscent(OptimizableByGradientValueMatrix optimizableByGradientValueMatrix, double d, IterationListener iterationListener) {
        this(optimizableByGradientValueMatrix, d);
        this.listener = iterationListener;
    }

    /**
     * Creates an optimizer with no iteration listener.
     *
     * @param optimizableByGradientValueMatrix the objective to maximize
     */
    public VectorizedDeepLearningGradientAscent(OptimizableByGradientValueMatrix optimizableByGradientValueMatrix) {
        this(optimizableByGradientValueMatrix, 0.009999999776482582d);
    }

    /** Sets the iteration budget used by {@link #optimize()}. */
    @Override // org.nd4j.linalg.solvers.api.OptimizerMatrix
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** Returns the objective being optimized. */
    public OptimizableByGradientValueMatrix getOptimizable() {
        return this.optimizable;
    }

    /** Returns true once an {@code optimize} call has met the relative-value convergence test. */
    @Override // org.nd4j.linalg.solvers.api.OptimizerMatrix
    public boolean isConverged() {
        return this.converged;
    }

    /** Returns the line search used to take each step. */
    public VectorizedBackTrackLineSearch getLineMaximizer() {
        return this.lineMaximizer;
    }

    /**
     * Sets the relative tolerance for the convergence test in {@link #optimize(int)}.
     *
     * <p>NOTE(review): the line search's absolute x-tolerance is only set from {@code tolerance} at
     * construction time and is not updated here — confirm whether it should be.
     */
    @Override // org.nd4j.linalg.solvers.api.OptimizerMatrix
    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }

    /**
     * Returns the static default starting step size.
     *
     * <p>NOTE(review): this does not reflect values passed to {@link #setInitialStepSize(double)}.
     */
    public double getInitialStepSize() {
        return initialStepSize;
    }

    /** Overrides the current step size used by the next line search. */
    public void setInitialStepSize(double stepSize) {
        this.step = stepSize;
    }

    /** Returns the gradient 2-norm above which gradients are rescaled. */
    public double getStpmax() {
        return this.stpmax;
    }

    /** Sets the gradient 2-norm above which gradients are rescaled. */
    public void setStpmax(double stpmax) {
        this.stpmax = stpmax;
    }

    /** Runs {@link #optimize(int)} with the configured {@code maxIterations}. */
    @Override // org.nd4j.linalg.solvers.api.OptimizerMatrix
    public boolean optimize() {
        return optimize(this.maxIterations);
    }

    /**
     * Performs up to {@code numIterations} gradient-ascent iterations.
     *
     * @param numIterations maximum number of iterations to run
     * @return true if the value change fell below tolerance (also sets {@link #isConverged()}) or the
     *     training evaluator requested a stop; false if the iteration budget was exhausted
     */
    @Override // org.nd4j.linalg.solvers.api.OptimizerMatrix
    public boolean optimize(int numIterations) {
        double value = this.optimizable.getValue();
        INDArray gradient = this.optimizable.getValueGradient(0);
        for (int iteration = 0; iteration < numIterations; iteration++) {
            logger.info("At iteration " + iteration + ", cost = " + value + ", scaled = " + this.maxStep + " step = " + this.step + ", gradient infty-norm = " + gradient.normmax(Integer.MAX_VALUE));
            this.optimizable.setCurrentIteration(iteration);

            // Keep the step bounded: rescale (in place) any gradient whose 2-norm exceeds stpmax.
            double gradientNorm = ((Double) gradient.norm2(Integer.MAX_VALUE).element()).doubleValue();
            if (gradientNorm > this.stpmax) {
                logger.info("*** Step 2-norm " + gradientNorm + " greater than max " + this.stpmax + "  Scaling...");
                gradient.muli(Double.valueOf(this.stpmax / gradientNorm));
            }

            double newValue;
            try {
                this.step = this.lineMaximizer.optimize(gradient, iteration, this.step);
                newValue = this.optimizable.getValue();
            } catch (InvalidStepException e) {
                // No new value was produced, so there is nothing to test for convergence; previously
                // the test ran against an unassigned value. Skip to the next iteration instead.
                logger.warn("Error during computation", e);
                continue;
            }

            // Relative change test: 2|new - old| <= tol * (|new| + |old| + eps).
            if (2.0d * Math.abs(newValue - value) <= this.tolerance * (Math.abs(newValue) + Math.abs(value) + this.eps)) {
                logger.info("Gradient Ascent: Value difference " + Math.abs(newValue - value) + " below tolerance; saying converged.");
                this.converged = true;
                notifyListener(iteration);
                return true;
            }

            value = newValue;
            gradient = this.optimizable.getValueGradient(iteration);
            notifyListener(iteration);
            if (this.eval != null && this.eval.shouldStop(iteration)) {
                return true;
            }
        }
        return false;
    }

    /** Reports a finished iteration to the listener, if one is registered. */
    private void notifyListener(int iteration) {
        if (this.listener != null) {
            this.listener.iterationDone(iteration);
        }
    }

    /**
     * Sets the value reported as "scaled" in the per-iteration log line. It is not otherwise read
     * by this class.
     */
    public void setMaxStepSize(double maxStepSize) {
        this.maxStep = maxStepSize;
    }
}
