package de.jungblut.math.minimize;

import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import gnu.trove.list.array.TDoubleArrayList;
import java.util.ArrayList;
import org.apache.commons.math3.util.FastMath;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Orthant-wise limited-memory quasi-Newton (OWL-QN) minimizer: L-BFGS with an optional L1 penalty
 * {@code l1weight * sum_i |x_i|} added to the cost function (see {@link #setL1Weight(double)}).
 * With an L1 weight of zero no orthant handling is done and the iteration is plain L-BFGS.
 *
 * <p>Each iteration does four things:
 * <ul>
 *   <li>builds the (pseudo-)steepest-descent direction;</li>
 *   <li>maps it through the L-BFGS two-loop recursion;</li>
 *   <li>zeroes components whose sign disagrees with the steepest-descent direction;</li>
 *   <li>runs a backtracking line search in which, with L1, no coordinate may cross zero.</li>
 * </ul>
 *
 * <p>Per-run working state is kept in instance fields, so a single instance must not run
 * {@code minimize} concurrently.
 */
public class OWLQN extends AbstractMinimizer {
    private static final Logger LOG = LogManager.getLogger(OWLQN.class);

    /** Number of past costs over which the relative average improvement is measured. */
    private static final int CONVERGENCE_WINDOW = 5;
    /** Sufficient-decrease (Armijo) constant of the backtracking line search. */
    private static final double ARMIJO_C1 = 1.0E-4d;
    /** Step length (relative to |dir|) of the finite-difference gradient check. */
    private static final double GRAD_CHECK_EPSILON = 1.05E-8d;

    // Current accepted point and the gradient of the smooth cost there.
    private DoubleVector x;
    private DoubleVector grad;
    // Line-search trial point and the smooth gradient there.
    private DoubleVector newX;
    private DoubleVector newGrad;
    // Search direction and a separate snapshot of the steepest-descent direction it came from.
    private DoubleVector dir;
    private DoubleVector steepestDescDir;
    // Scratch space for the two-loop recursion, one slot per stored pair.
    private double[] alphas;
    // History of (s, y, s.y) pairs with s = x_{k+1} - x_k and y = grad_{k+1} - grad_k.
    private TDoubleArrayList roList;
    private ArrayList<DoubleVector> sList;
    private ArrayList<DoubleVector> yList;
    // Recent costs for the relative-improvement termination test.
    private TDoubleArrayList costs;
    // Objective (smooth cost + L1 term) at the most recently evaluated point.
    private double value;
    private int m = 10;
    private double l1weight = 0.0d;
    private double tol = 1.0E-4d;
    private boolean gradCheck = false;

    /**
     * Minimizes {@code costFunction} plus the configured L1 penalty, starting at
     * {@code doubleVector}.
     *
     * @param costFunction the smooth part of the objective.
     * @param doubleVector the start point; it is copied and never modified or returned.
     * @param i the maximum number of iterations.
     * @param z whether to log the cost after each iteration.
     * @return the last point accepted by the line search (the start point if none was).
     * @throws RuntimeException if the computed direction is not a descent direction.
     */
    @Override // de.jungblut.math.minimize.Minimizer
    public DoubleVector minimize(CostFunction costFunction, DoubleVector doubleVector, int i, boolean z) {
        final int dimension = doubleVector.getDimension();
        // Private copies: the caller's vector must not be mutated in place or handed back.
        this.x = doubleVector.deepCopy();
        this.newX = doubleVector.deepCopy();
        // Every working vector gets its own storage; aliasing them silently corrupts the update.
        this.grad = new DenseDoubleVector(dimension);
        this.newGrad = new DenseDoubleVector(dimension);
        this.dir = new DenseDoubleVector(dimension);
        this.steepestDescDir = new DenseDoubleVector(dimension);
        this.alphas = new double[this.m];
        this.roList = new TDoubleArrayList(this.m);
        this.costs = new TDoubleArrayList(this.m);
        this.sList = new ArrayList<>();
        this.yList = new ArrayList<>();

        // newX equals the start point here, so this evaluates the objective at x.
        this.value = evaluateL1(costFunction);
        this.grad = this.newGrad.deepCopy();

        for (int iteration = 0; iteration < i; iteration++) {
            updateDir(costFunction, z);
            if (!backTrackingLineSearch(iteration, costFunction)) {
                // Nothing was accepted: x is still the current point. Do not shift, as that
                // would store a bogus curvature pair and swap in a stale point.
                break;
            }
            // Accept the trial point: afterwards x/grad hold it and newX holds the previous one.
            shift();
            if (hasConverged()) {
                break;
            }
            if (z) {
                LOG.info("Iteration {} | Cost: {}", iteration, this.value);
            }
        }

        final DoubleVector result = this.x;
        releaseState();
        return result;
    }

    /**
     * Appends the current cost to the history and reports whether the relative average
     * improvement over the last {@link #CONVERGENCE_WINDOW} iterations fell below {@code tol}.
     */
    private boolean hasConverged() {
        this.costs.add(this.value);
        if (this.costs.size() <= CONVERGENCE_WINDOW) {
            return false;
        }
        final double oldest = this.costs.get(0);
        while (this.costs.size() > CONVERGENCE_WINDOW) {
            this.costs.removeAt(0);
        }
        final double averageImprovement = (oldest - this.value) / this.costs.size();
        return averageImprovement / Math.abs(this.value) < this.tol;
    }

    /** Drops all per-run working state so it can be garbage collected. */
    private void releaseState() {
        this.x = null;
        this.newX = null;
        this.grad = null;
        this.newGrad = null;
        this.dir = null;
        this.steepestDescDir = null;
        this.alphas = null;
        this.roList = null;
        this.costs = null;
        this.sList = null;
        this.yList = null;
    }

    /** Computes the search direction for the current point into {@code dir}. */
    private void updateDir(CostFunction costFunction, boolean z) {
        makeSteepestDescDir();
        mapDirectionByInverseHessian();
        fixDirectionSigns();
        if (this.gradCheck) {
            testDirectionDerivation(costFunction);
        }
    }

    /**
     * Logs a finite-difference estimate of the directional derivative next to the analytic value.
     * Overwrites {@code newX} and {@code newGrad}; the following line search recomputes both.
     */
    private void testDirectionDerivation(CostFunction costFunction) {
        final double norm = FastMath.sqrt(this.dir.dot(this.dir));
        if (norm == 0.0d) {
            return;
        }
        final double step = GRAD_CHECK_EPSILON / norm;
        getNextPoint(step);
        final double expected = (evaluateL1(costFunction) - this.value) / step;
        final double actual = directionDerivation();
        LOG.info("GradCheck: expected= " + expected + " vs. " + actual + "! AbsDiff= " + Math.abs(expected - actual));
    }

    /**
     * With a positive L1 weight, zeroes every component of {@code dir} whose sign disagrees with
     * (or is zero in) the steepest-descent snapshot. This keeps the step inside the orthant that
     * direction points into.
     */
    private void fixDirectionSigns() {
        if (this.l1weight > 0.0d) {
            for (int i = 0; i < this.dir.getDimension(); i++) {
                if (this.dir.get(i) * this.steepestDescDir.get(i) <= 0.0d) {
                    this.dir.set(i, 0.0d);
                }
            }
        }
    }

    /**
     * Applies the L-BFGS two-loop recursion to {@code dir} in place, using the stored (s, y, s.y)
     * pairs. Does nothing while the history is empty.
     */
    private void mapDirectionByInverseHessian() {
        final int size = this.sList.size();
        if (size == 0) {
            return;
        }
        for (int i = size - 1; i >= 0; i--) {
            this.alphas[i] = (-this.sList.get(i).dot(this.dir)) / this.roList.get(i);
            addMult(this.dir, this.yList.get(i), this.alphas[i]);
        }
        // Initial Hessian scaling from the most recent pair.
        final DoubleVector lastY = this.yList.get(size - 1);
        scale(this.dir, this.roList.get(size - 1) / lastY.dot(lastY));
        for (int i = 0; i < size; i++) {
            final double beta = this.yList.get(i).dot(this.dir) / this.roList.get(i);
            addMult(this.dir, this.sList.get(i), -this.alphas[i] - beta);
        }
    }

    /**
     * Writes the negative (pseudo-)gradient of the regularized objective into {@code dir} and
     * stores an independent copy of it in {@code steepestDescDir}. For coordinates at zero the
     * L1 subgradient closest to zero is used, giving 0 when |grad_i| <= l1weight.
     */
    private void makeSteepestDescDir() {
        if (this.l1weight == 0.0d) {
            scaleInto(this.dir, this.grad, -1.0d);
        } else {
            for (int i = 0; i < this.dir.getDimension(); i++) {
                if (this.x.get(i) < 0.0d) {
                    this.dir.set(i, (-this.grad.get(i)) + this.l1weight);
                } else if (this.x.get(i) > 0.0d) {
                    this.dir.set(i, (-this.grad.get(i)) - this.l1weight);
                } else if (this.grad.get(i) < (-this.l1weight)) {
                    this.dir.set(i, (-this.grad.get(i)) - this.l1weight);
                } else if (this.grad.get(i) > this.l1weight) {
                    this.dir.set(i, (-this.grad.get(i)) + this.l1weight);
                } else {
                    this.dir.set(i, 0.0d);
                }
            }
        }
        // Must be a copy: dir is modified in place afterwards and fixDirectionSigns compares them.
        this.steepestDescDir = this.dir.deepCopy();
    }

    /**
     * Backtracking line search along {@code dir}. On success, {@code newX}/{@code newGrad}/
     * {@code value} describe the accepted point.
     *
     * @return false if there is no descent direction left (nothing was evaluated), else true.
     *         A NaN cost is accepted as well, which ends the search.
     * @throws RuntimeException if {@code dir} is an ascent direction.
     */
    private boolean backTrackingLineSearch(int i, CostFunction costFunction) {
        final double directionDerivation = directionDerivation();
        if (directionDerivation > 0.0d) {
            throw new RuntimeException("L-BFGS chose a non-descent direction: check your gradient!");
        }
        if (directionDerivation == 0.0d || Double.isNaN(directionDerivation)) {
            LOG.info("L-BFGS apparently found the minimum. No direction to descent anymore.");
            return false;
        }
        double alpha = 1.0d;
        double backoff = 0.5d;
        if (i == 0) {
            // First iteration: the direction is unscaled, so start from a unit-length step.
            alpha = 1.0d / FastMath.sqrt(this.dir.dot(this.dir));
            backoff = 0.1d;
        }
        final double startValue = this.value;
        while (true) {
            getNextPoint(alpha);
            this.value = evaluateL1(costFunction);
            if (Double.isNaN(this.value) || this.value <= startValue + (ARMIJO_C1 * directionDerivation * alpha)) {
                return true;
            }
            alpha *= backoff;
        }
    }

    /**
     * Sets {@code newX = x + d * dir}. With a positive L1 weight, coordinates that would change
     * sign relative to {@code x} are clamped to zero.
     */
    private void getNextPoint(double d) {
        addMultInto(this.newX, this.x, this.dir, d);
        if (this.l1weight > 0.0d) {
            for (int i = 0; i < this.x.getDimension(); i++) {
                if (this.x.get(i) * this.newX.get(i) < 0.0d) {
                    this.newX.set(i, 0.0d);
                }
            }
        }
    }

    /** {@code target = a + b * d}, element-wise, over the dimension of {@code target}. */
    private void addMultInto(DoubleVector doubleVector, DoubleVector doubleVector2, DoubleVector doubleVector3, double d) {
        for (int i = 0; i < doubleVector.getDimension(); i++) {
            doubleVector.set(i, doubleVector2.get(i) + (doubleVector3.get(i) * d));
        }
    }

    /** {@code target += other * d}, element-wise, over the dimension of {@code target}. */
    private void addMult(DoubleVector doubleVector, DoubleVector doubleVector2, double d) {
        for (int i = 0; i < doubleVector.getDimension(); i++) {
            doubleVector.set(i, doubleVector.get(i) + (doubleVector2.get(i) * d));
        }
    }

    /** {@code target *= d}, element-wise. */
    private void scale(DoubleVector doubleVector, double d) {
        for (int i = 0; i < doubleVector.getDimension(); i++) {
            doubleVector.set(i, doubleVector.get(i) * d);
        }
    }

    /** {@code target = source * d}, element-wise, over the dimension of {@code target}. */
    void scaleInto(DoubleVector doubleVector, DoubleVector doubleVector2, double d) {
        for (int i = 0; i < doubleVector.getDimension(); i++) {
            doubleVector.set(i, doubleVector2.get(i) * d);
        }
    }

    /**
     * Directional derivative of the regularized objective at {@code x} along {@code dir}. For
     * coordinates at zero, the sign of the L1 term follows the sign of {@code dir}.
     */
    private double directionDerivation() {
        if (this.l1weight == 0.0d) {
            return this.dir.dot(this.grad);
        }
        double derivative = 0.0d;
        for (int i = 0; i < this.dir.getDimension(); i++) {
            final double d = this.dir.get(i);
            if (d == 0.0d) {
                continue;
            }
            final double xi = this.x.get(i);
            if (xi < 0.0d) {
                derivative += d * (this.grad.get(i) - this.l1weight);
            } else if (xi > 0.0d) {
                derivative += d * (this.grad.get(i) + this.l1weight);
            } else if (d < 0.0d) {
                derivative += d * (this.grad.get(i) - this.l1weight);
            } else if (d > 0.0d) {
                derivative += d * (this.grad.get(i) + this.l1weight);
            }
        }
        return derivative;
    }

    /**
     * Evaluates the cost function at {@code newX} and stores its (smooth) gradient in
     * {@code newGrad}.
     *
     * @return the cost plus {@code l1weight * sum_i |newX_i|} when the L1 weight is positive.
     */
    private double evaluateL1(CostFunction costFunction) {
        final CostGradientTuple evaluation = costFunction.evaluateCost(this.newX);
        this.newGrad = evaluation.getGradient();
        double cost = evaluation.getCost();
        if (this.l1weight > 0.0d) {
            for (int i = 0; i < this.newGrad.getDimension(); i++) {
                cost += Math.abs(this.newX.get(i)) * this.l1weight;
            }
        }
        return cost;
    }

    /**
     * Records the curvature pair (s, y) for the accepted step and makes the trial point current.
     * When the history already holds {@code m} pairs, the oldest pair's storage is reused.
     */
    private void shift() {
        final DoubleVector nextS;
        final DoubleVector nextY;
        if (this.sList.size() < this.m) {
            nextS = new DenseDoubleVector(this.x.getDimension());
            nextY = new DenseDoubleVector(this.x.getDimension());
        } else {
            nextS = this.sList.remove(0);
            nextY = this.yList.remove(0);
            this.roList.removeAt(0);
        }
        addMultInto(nextS, this.newX, this.x, -1.0d);
        addMultInto(nextY, this.newGrad, this.grad, -1.0d);
        this.sList.add(nextS);
        this.yList.add(nextY);
        this.roList.add(nextS.dot(nextY));

        // Both point vectors are private to this run, so a reference swap is sufficient.
        final DoubleVector previousX = this.x;
        this.x = this.newX;
        this.newX = previousX;
        // Copy the gradient in case the cost function reuses its gradient buffer across calls.
        final DoubleVector previousGrad = this.grad;
        this.grad = this.newGrad.deepCopy();
        this.newGrad = previousGrad;
    }

    /** Enables logging of a finite-difference directional-derivative check every iteration. */
    public OWLQN doGradChecks() {
        this.gradCheck = true;
        return this;
    }

    /**
     * Sets the number of curvature pairs kept in the L-BFGS history.
     *
     * @throws IllegalArgumentException if {@code i < 1}.
     */
    public OWLQN setM(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("m must be at least 1, but was " + i);
        }
        this.m = i;
        return this;
    }

    /** Sets the L1 penalty weight; 0 disables the penalty. Negative values are not meaningful. */
    public OWLQN setL1Weight(double d) {
        this.l1weight = d;
        return this;
    }

    /** Sets the relative average improvement below which the iteration stops. */
    public OWLQN setTolerance(double d) {
        this.tol = d;
        return this;
    }

    /** Convenience entry point that minimizes with a default-configured {@link OWLQN}. */
    public static DoubleVector minimizeFunction(CostFunction costFunction, DoubleVector doubleVector, int i, boolean z) {
        return new OWLQN().minimize(costFunction, doubleVector, i, z);
    }
}
