package de.jungblut.math.minimize;

import de.jungblut.math.DoubleVector;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Nonlinear conjugate-gradient minimizer.
 *
 * <p>Search directions use the Polak-Ribiere formula. Step sizes come from a line search that
 * combines quadratic and cubic polynomial fits with the Wolfe-Powell conditions (constants
 * {@link #RHO} and {@link #SIG}). The control flow mirrors Carl Edward Rasmussen's Octave routine
 * {@code fmincg} / {@code minimize.m}.
 */
public final class Fmincg extends AbstractMinimizer {

    private static final Logger LOG = LogManager.getLogger(Fmincg.class);

    /**
     * Maximum extrapolation factor.
     *
     * <p>A single extrapolation step may grow the current step to at most {@code EXT} times its
     * value.
     *
     * <p>NOTE(review): this field is public, static and non-final, so an assignment affects every
     * {@code Fmincg} run in the JVM. It is left non-final because external code may assign it.
     */
    public static double EXT = 3.0d;

    /** Sufficient-decrease (Armijo) constant of the Wolfe-Powell conditions. */
    private static final double RHO = 0.01d;

    /** Curvature constant of the Wolfe-Powell conditions. */
    private static final double SIG = 0.5d;

    /** New trial points may not lie within this fraction of the current bracket limits. */
    private static final double INT = 0.1d;

    /** Maximum number of additional cost evaluations per line search. */
    private static final int MAX = 20;

    /** Upper bound on the ratio by which the initial step may grow between line searches. */
    private static final int RATIO = 100;

    /**
     * Approximately the smallest positive normal double (Octave's {@code realmin}, as printed).
     *
     * <p>It keeps the slope-ratio denominator away from zero. The value is kept exactly as in the
     * reference code.
     */
    private static final double REALMIN = 2.2251E-308d;

    /** Sentinel for "no upper bracket limit found yet"; compared against {@code -0.5}. */
    private static final double NO_LIMIT = -1.0d;

    /**
     * Convenience entry point that minimizes {@code costFunction} with a fresh {@code Fmincg}.
     *
     * @param costFunction function returning the cost and gradient at a point
     * @param doubleVector starting point
     * @param i if positive, the maximum number of line searches; if negative, its absolute value
     *     is the maximum number of cost-function evaluations
     * @param z if {@code true}, log the cost after every successful line search
     * @return the last accepted point (the starting point if no line search succeeded)
     */
    public static DoubleVector minimizeFunction(CostFunction costFunction, DoubleVector doubleVector, int i, boolean z) {
        return new Fmincg().minimize(costFunction, doubleVector, i, z);
    }

    /**
     * Minimizes {@code costFunction} starting from {@code doubleVector}.
     *
     * <p>Each outer iteration performs one line search along the current direction.
     * <ul>
     *   <li>A successful search accepts the new point and reports it through
     *       {@link #onIterationFinished}. It then computes the next Polak-Ribiere direction,
     *       falling back to steepest descent if that direction has a non-negative slope.</li>
     *   <li>A failed search restores the point from before the search and retries along the
     *       steepest-descent direction.</li>
     *   <li>Two consecutive failures end the run.</li>
     * </ul>
     *
     * @param costFunction function returning the cost and gradient at a point
     * @param doubleVector starting point
     * @param i run-length budget. If positive, the maximum number of line searches. If negative,
     *     its absolute value is the maximum number of cost-function evaluations, including the
     *     initial one. If zero, the starting point is returned after a single evaluation.
     * @param z if {@code true}, log the iteration counter and cost at INFO after every successful
     *     line search
     * @return the last accepted point (the starting point if no line search succeeded)
     */
    @Override // de.jungblut.math.minimize.Minimizer
    public final DoubleVector minimize(CostFunction costFunction, DoubleVector doubleVector, int i, boolean z) {
        // i > 0 counts line searches; i < 0 counts cost-function evaluations.
        final int iterationTick = i > 0 ? 1 : 0;
        final int evaluationTick = i < 0 ? 1 : 0;
        final int budget = Math.abs(i);

        DoubleVector x = doubleVector;
        CostGradientTuple start = costFunction.evaluateCost(x);
        double f1 = start.getCost();
        DoubleVector df1 = start.getGradient();
        int counter = evaluationTick;

        // Start with steepest descent; d1 is the slope along s, which is non-positive.
        DoubleVector s = df1.multiply(-1.0d);
        double d1 = s.multiply(-1.0d).dot(s);
        // Initial step is red / (|s| + 1) with red = 1.
        double z1 = 1.0d / (1.0d - d1);
        boolean lastLineSearchFailed = false;

        while (counter < budget) {
            counter += iterationTick;

            // Remember where this line search starts so that a failure can be undone.
            final DoubleVector x0 = x.deepCopy();
            final double f0 = f1;

            // First trial step along s.
            x = x.add(s.multiply(z1));
            CostGradientTuple trial = costFunction.evaluateCost(x);
            double f2 = trial.getCost();
            DoubleVector df2 = trial.getGradient();
            counter += evaluationTick;
            double d2 = df2.dot(s);

            // Point 3 starts equal to point 1; z3 is measured relative to the current point.
            double f3 = f1;
            double d3 = d1;
            double z3 = -z1;
            int evaluationsLeft = i > 0 ? MAX : Math.min(MAX, -i - counter);
            boolean success = false;
            double limit = NO_LIMIT;

            while (true) {
                // Tighten the bracket while the Wolfe-Powell conditions are violated.
                while ((f2 > f1 + z1 * RHO * d1 || d2 > -SIG * d1) && evaluationsLeft > 0) {
                    limit = z1;
                    double z2;
                    if (f2 > f1) {
                        // Quadratic fit.
                        z2 = z3 - 0.5d * d3 * z3 * z3 / (d3 * z3 + f2 - f3);
                    } else {
                        // Cubic fit; a NaN from sqrt of a negative value is handled below.
                        double a = 6.0d * (f2 - f3) / z3 + 3.0d * (d2 + d3);
                        double b = 3.0d * (f3 - f2) - z3 * (d3 + 2.0d * d2);
                        z2 = (Math.sqrt(b * b - a * d2 * z3 * z3) - b) / a;
                    }
                    if (Double.isNaN(z2) || Double.isInfinite(z2)) {
                        // Numerical problem: bisect instead.
                        z2 = z3 / 2.0d;
                    }
                    // Do not accept a point too close to either end of the bracket.
                    z2 = Math.max(Math.min(z2, INT * z3), (1.0d - INT) * z3);

                    z1 += z2;
                    x = x.add(s.multiply(z2));
                    CostGradientTuple tightened = costFunction.evaluateCost(x);
                    f2 = tightened.getCost();
                    df2 = tightened.getGradient();
                    evaluationsLeft--;
                    counter += evaluationTick;
                    d2 = df2.dot(s);
                    // z3 is now relative to the new location.
                    z3 -= z2;
                }

                if (f2 > f1 + z1 * RHO * d1 || d2 > -SIG * d1) {
                    // Failure: conditions still violated and no evaluations left.
                    break;
                }
                if (d2 > SIG * d1) {
                    // Success: both Wolfe-Powell conditions hold.
                    success = true;
                    break;
                }
                if (evaluationsLeft == 0) {
                    // Failure: out of evaluations for this line search.
                    break;
                }

                // Cubic extrapolation.
                double a = 6.0d * (f2 - f3) / z3 + 3.0d * (d2 + d3);
                double b = 3.0d * (f3 - f2) - z3 * (d3 + 2.0d * d2);
                double z2 = -d2 * z3 * z3 / (b + Math.sqrt(b * b - a * d2 * z3 * z3));
                if (Double.isNaN(z2) || Double.isInfinite(z2) || z2 < 0.0d) {
                    // Numerical problem or wrong sign. Without an upper limit, extrapolate the
                    // maximum amount; otherwise bisect.
                    z2 = limit < -0.5d ? z1 * (EXT - 1.0d) : (limit - z1) / 2.0d;
                } else if (limit > -0.5d && z2 + z1 > limit) {
                    // Extrapolation beyond the known limit: bisect.
                    z2 = (limit - z1) / 2.0d;
                } else if (limit < -0.5d && z2 + z1 > z1 * EXT) {
                    // Extrapolation beyond EXT times the current step: clamp.
                    z2 = z1 * (EXT - 1.0d);
                } else if (z2 < -z3 * INT) {
                    // Too close to the current point.
                    z2 = -z3 * INT;
                } else if (limit > -0.5d && z2 < (limit - z1) * (1.0d - INT)) {
                    // Too close to the limit.
                    z2 = (limit - z1) * (1.0d - INT);
                }

                // Point 3 becomes point 2, then step to the extrapolated point.
                f3 = f2;
                d3 = d2;
                z3 = -z2;
                z1 += z2;
                x = x.add(s.multiply(z2));
                CostGradientTuple extrapolated = costFunction.evaluateCost(x);
                f2 = extrapolated.getCost();
                df2 = extrapolated.getGradient();
                evaluationsLeft--;
                counter += evaluationTick;
                d2 = df2.dot(s);
            }

            if (success) {
                f1 = f2;
                if (z) {
                    LOG.info("Iteration {} | Cost: {}", counter, f1);
                }
                onIterationFinished(counter, f1, x);

                // Polak-Ribiere: s = (df2'*df2 - df1'*df2) / (df1'*df1) * s - df2
                double beta = (df2.dot(df2) - df1.dot(df2)) / df1.dot(df1);
                s = s.multiply(beta).subtract(df2);
                df1 = df2;
                double newSlope = df1.dot(s);
                if (newSlope > 0.0d) {
                    // The new slope must be negative; otherwise fall back to steepest descent.
                    s = df1.multiply(-1.0d);
                    newSlope = s.multiply(-1.0d).dot(s);
                }
                // Scale the next initial step by the slope ratio, capped at RATIO.
                z1 *= Math.min(RATIO, d1 / (newSlope - REALMIN));
                d1 = newSlope;
                lastLineSearchFailed = false;
            } else {
                // Restore the point from before the failed line search.
                x = x0;
                f1 = f0;
                if (lastLineSearchFailed || counter > budget) {
                    // Two failures in a row, or out of budget: give up.
                    break;
                }
                // NOTE(review): like the reference fmincg, this takes the gradient of the last
                // rejected trial point (df2) even though x was restored to x0. Confirm whether
                // the gradient at x0 was intended.
                df1 = df2;
                s = df1.multiply(-1.0d);
                d1 = s.multiply(-1.0d).dot(s);
                z1 = 1.0d / (1.0d - d1);
                lastLineSearchFailed = true;
            }
        }
        return x;
    }
}
