package de.jungblut.math.minimize;

import com.google.common.base.Preconditions;
import de.jungblut.math.DoubleVector;
import java.util.Arrays;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Gradient descent minimizer.
 *
 * <p>Optional extensions, configured through {@link GradientDescentBuilder}:
 * <ul>
 *   <li>heavy-ball momentum,</li>
 *   <li>"bold driver" learning-rate adaptation (grow the rate after an improving step, undo the
 *       step and shrink the rate otherwise),</li>
 *   <li>step-wise learning-rate annealing,</li>
 *   <li>early stopping when the cost change drops below a threshold or the recent cost history is
 *       strictly increasing.</li>
 * </ul>
 */
public final class GradientDescent extends AbstractMinimizer {
    private static final Logger LOG = LogManager.getLogger(GradientDescent.class);

    /** Number of most recent costs kept for the convergence and divergence checks. */
    private static final int COST_HISTORY = 3;

    private final boolean breakOnDivergence;
    private final double breakDifference;
    private final double momentum;
    private final double alpha;
    private final boolean boldDriver;
    private final double boldIncreasePercentage;
    private final double boldDecreasePercentage;
    private final int annealingIteration;

    /**
     * Fluent builder for {@link GradientDescent}; obtain one via {@link #create(double)}.
     */
    public static class GradientDescentBuilder {
        private final double alpha;
        private double breakDifference;
        private double momentum;
        private boolean breakOnDivergence;
        private boolean boldDriver;
        private double boldIncreasePercentage;
        private double boldDecreasePercentage;
        /** Values {@code <= 0} disable annealing. */
        private int annealingIteration = -1;

        private GradientDescentBuilder(double alpha) {
            this.alpha = alpha;
        }

        /**
         * @return a new minimizer configured from this builder.
         */
        public GradientDescent build() {
            return new GradientDescent(this);
        }

        /**
         * Enables heavy-ball momentum: every update additionally moves by {@code momentum} times
         * the previous update.
         *
         * @param momentum weight of the previous update, in [0, 1]; 0 disables momentum.
         * @return this builder.
         * @throws IllegalArgumentException if {@code momentum} is outside [0, 1].
         */
        public GradientDescentBuilder momentum(double momentum) {
            Preconditions.checkArgument(momentum >= 0.0d && momentum <= 1.0d, "Momentum must be between 0 and 1.");
            this.momentum = momentum;
            return this;
        }

        /**
         * Enables the bold driver with defaults: shrink the rate by 50% after a step that did not
         * lower the cost, and grow it by 5% after a step that did.
         *
         * @return this builder.
         */
        public GradientDescentBuilder boldDriver() {
            return boldDriver(0.5d, 0.05d);
        }

        /**
         * Enables bold-driver learning-rate adaptation.
         *
         * @param increasedCostPercentage fraction by which the learning rate shrinks after a step
         *     that did not decrease the cost; that step is also undone. Must be in [0, 1].
         * @param decreasedCostPercentage fraction by which the learning rate grows after a step that
         *     decreased the cost. Must be in [0, 1].
         * @return this builder.
         * @throws IllegalArgumentException if either fraction is outside [0, 1].
         */
        public GradientDescentBuilder boldDriver(double increasedCostPercentage, double decreasedCostPercentage) {
            Preconditions.checkArgument(increasedCostPercentage >= 0.0d && increasedCostPercentage <= 1.0d,
                    "increasedCostPercentage must be between 0 and 1.");
            Preconditions.checkArgument(decreasedCostPercentage >= 0.0d && decreasedCostPercentage <= 1.0d,
                    "decreasedCostPercentage must be between 0 and 1.");
            this.boldDriver = true;
            this.boldIncreasePercentage = increasedCostPercentage;
            this.boldDecreasePercentage = decreasedCostPercentage;
            return this;
        }

        /**
         * Stops the minimization once the recent cost history is strictly increasing.
         *
         * @return this builder.
         */
        public GradientDescentBuilder breakOnDivergence() {
            this.breakOnDivergence = true;
            return this;
        }

        /**
         * Stops the minimization once the absolute cost change between two consecutive iterations
         * is strictly below {@code difference}.
         *
         * @param difference the convergence threshold; 0 effectively disables this check.
         * @return this builder.
         */
        public GradientDescentBuilder breakOnDifference(double difference) {
            this.breakDifference = difference;
            return this;
        }

        /**
         * Enables annealing: at iteration {@code t} the learning rate becomes
         * {@code alpha / (1 + floor(t / iteration))}. The rate therefore drops in steps every
         * {@code iteration} iterations. It overrides any bold-driver adjustment.
         *
         * @param iteration annealing period, must be positive.
         * @return this builder.
         * @throws IllegalArgumentException if {@code iteration <= 0}.
         */
        public GradientDescentBuilder annealingAfter(int iteration) {
            Preconditions.checkArgument(iteration > 0,
                    "Annealing can only kick in after the first iteration! Given: " + iteration);
            this.annealingIteration = iteration;
            return this;
        }

        /**
         * @param alpha the initial learning rate.
         * @return a new builder.
         */
        public static GradientDescentBuilder create(double alpha) {
            return new GradientDescentBuilder(alpha);
        }
    }

    private GradientDescent(GradientDescentBuilder builder) {
        this.alpha = builder.alpha;
        this.breakDifference = builder.breakDifference;
        this.momentum = builder.momentum;
        this.breakOnDivergence = builder.breakOnDivergence;
        this.boldDriver = builder.boldDriver;
        this.boldIncreasePercentage = builder.boldIncreasePercentage;
        this.boldDecreasePercentage = builder.boldDecreasePercentage;
        this.annealingIteration = builder.annealingIteration;
    }

    /**
     * Creates a plain gradient descent minimizer without momentum, bold driver or annealing.
     *
     * @param alpha the learning rate.
     * @param limit stop once the absolute cost change between two iterations is below this value.
     */
    public GradientDescent(double alpha, double limit) {
        this(GradientDescentBuilder.create(alpha).breakOnDifference(limit));
    }

    /**
     * Runs gradient descent starting at {@code theta}.
     *
     * <p>Each iteration evaluates the cost and gradient, checks the stopping criteria, optionally
     * adapts the learning rate, and then takes the step
     * {@code theta' = theta - alpha * gradient + momentum * (theta - previousTheta)}.
     *
     * @param f the cost function to minimize.
     * @param theta the starting parameters.
     * @param maxIterations upper bound on the number of iterations.
     * @param verbose whether to log cost (and, with the bold driver, learning rate) per iteration.
     * @return the parameters after the last step taken.
     */
    @Override
    public final DoubleVector minimize(CostFunction f, DoubleVector theta, int maxIterations, boolean verbose) {
        // Sliding window of recent costs, newest at the last index. Seeding with MAX_VALUE keeps
        // the convergence/divergence checks from firing before real costs have been recorded.
        double[] lastCosts = new double[COST_HISTORY];
        Arrays.fill(lastCosts, Double.MAX_VALUE);
        final int newest = lastCosts.length - 1;

        DoubleVector currentTheta = theta;
        // Parameters and gradient of the most recent step, kept so the bold driver can undo it.
        DoubleVector lastTheta = null;
        DoubleVector lastGradient = null;
        // Most recent update (theta_t - theta_{t-1}); only tracked when momentum is enabled.
        DoubleVector lastUpdate = null;
        double currentAlpha = this.alpha;

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            CostGradientTuple evaluation = f.evaluateCost(currentTheta);
            double cost = evaluation.getCost();
            if (verbose) {
                LOG.info("Iteration {} | Cost: {}", iteration, cost);
            }
            shiftLeft(lastCosts);
            lastCosts[newest] = cost;
            if (converged(lastCosts, breakDifference) || (breakOnDivergence && ascending(lastCosts))) {
                break;
            }

            DoubleVector gradient = evaluation.getGradient();
            if (boldDriver) {
                if (lastGradient != null) {
                    if (getCostDifference(lastCosts) < 0.0d) {
                        currentAlpha += currentAlpha * boldDecreasePercentage;
                    } else {
                        // The last step did not help: roll it back and retry with a smaller rate.
                        currentTheta = lastTheta;
                        gradient = lastGradient;
                        currentAlpha -= currentAlpha * boldIncreasePercentage;
                        // The remembered update is the rejected step; replaying it via momentum
                        // would push straight back toward the point we just discarded.
                        lastUpdate = null;
                    }
                    if (verbose) {
                        LOG.info("Iteration {} | Alpha: {}\n", iteration, currentAlpha);
                    }
                }
                lastGradient = gradient;
            }
            if (annealingIteration > 0) {
                // Integer division: the rate stays constant within each annealing period and is
                // always derived from the initial rate, overriding bold-driver adjustments.
                currentAlpha = this.alpha / (1.0d + (iteration / annealingIteration));
            }

            lastTheta = currentTheta;
            DoubleVector nextTheta = currentTheta.subtract(gradient.multiply(currentAlpha));
            if (momentum != 0.0d) {
                if (lastUpdate != null) {
                    nextTheta = nextTheta.add(lastUpdate.multiply(momentum));
                }
                lastUpdate = nextTheta.subtract(currentTheta);
            }
            currentTheta = nextTheta;
            onIterationFinished(iteration, cost, currentTheta);
        }
        return currentTheta;
    }

    /**
     * Convenience wrapper: builds a plain {@link GradientDescent} and runs it.
     *
     * @param f the cost function to minimize.
     * @param theta the starting parameters.
     * @param alpha the learning rate.
     * @param limit the convergence threshold on the absolute cost change.
     * @param maxIterations upper bound on the number of iterations.
     * @param verbose whether to log progress.
     * @return the minimized parameters.
     */
    public static DoubleVector minimizeFunction(CostFunction f, DoubleVector theta, double alpha, double limit,
            int maxIterations, boolean verbose) {
        return new GradientDescent(alpha, limit).minimize(f, theta, maxIterations, verbose);
    }

    /**
     * Shifts every entry one slot to the left, dropping the first one and setting the last one to
     * {@link Double#MAX_VALUE}. An empty array is left untouched.
     */
    static void shiftLeft(double[] costs) {
        if (costs.length == 0) {
            return;
        }
        int last = costs.length - 1;
        // arraycopy handles overlapping source and destination ranges correctly.
        System.arraycopy(costs, 1, costs, 0, last);
        costs[last] = Double.MAX_VALUE;
    }

    /**
     * @return true if the absolute difference between the two newest costs is strictly below
     *     {@code limit}.
     */
    static boolean converged(double[] costs, double limit) {
        return Math.abs(getCostDifference(costs)) < limit;
    }

    /**
     * @return true if every cost is strictly greater than its predecessor, meaning the cost rose
     *     over the whole history window. Arrays shorter than two entries are never ascending.
     */
    static boolean ascending(double[] costs) {
        if (costs.length < 2) {
            return false;
        }
        for (int i = 1; i < costs.length; i++) {
            // Negated comparison so NaN also counts as "not ascending".
            if (!(costs[i - 1] < costs[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * @return newest cost minus the one before it; negative means the cost decreased.
     */
    private static double getCostDifference(double[] costs) {
        return costs[costs.length - 1] - costs[costs.length - 2];
    }
}
