package org.deeplearning4j.optimize.solvers;

import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.ListIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/LBFGS.class */
public class LBFGS extends BaseOptimizer {
    private static final long serialVersionUID = 9148732140255034888L;
    private int m;

    public LBFGS(NeuralNetConfiguration neuralNetConfiguration, StepFunction stepFunction, Collection<TrainingListener> collection, Model model) {
        super(neuralNetConfiguration, stepFunction, collection, model);
        this.m = 4;
    }

    @Override // org.deeplearning4j.optimize.solvers.BaseOptimizer, org.deeplearning4j.optimize.api.ConvexOptimizer
    public void setupSearchState(Pair<Gradient, Double> pair) {
        super.setupSearchState(pair);
        INDArray iNDArray = (INDArray) this.searchState.get(BaseOptimizer.PARAMS_KEY);
        this.searchState.put("s", new LinkedList());
        this.searchState.put("y", new LinkedList());
        this.searchState.put("rho", new LinkedList());
        this.searchState.put("oldparams", iNDArray.dup());
    }

    @Override // org.deeplearning4j.optimize.solvers.BaseOptimizer, org.deeplearning4j.optimize.api.ConvexOptimizer
    public void preProcessLine() {
        if (this.searchState.containsKey(BaseOptimizer.SEARCH_DIR)) {
            return;
        }
        this.searchState.put(BaseOptimizer.SEARCH_DIR, ((INDArray) this.searchState.get("g")).dup());
    }

    @Override // org.deeplearning4j.optimize.solvers.BaseOptimizer, org.deeplearning4j.optimize.api.ConvexOptimizer
    public void postStep(INDArray iNDArray) {
        INDArray sub;
        INDArray sub2;
        INDArray iNDArray2 = (INDArray) this.searchState.get("oldparams");
        INDArray params = this.model.params();
        INDArray iNDArray3 = (INDArray) this.searchState.get("g");
        LinkedList linkedList = (LinkedList) this.searchState.get("rho");
        LinkedList linkedList2 = (LinkedList) this.searchState.get("s");
        LinkedList linkedList3 = (LinkedList) this.searchState.get("y");
        double dot = Nd4j.getBlasWrapper().dot(iNDArray2, iNDArray3) + Nd4j.EPS_THRESHOLD;
        double dot2 = Nd4j.getBlasWrapper().dot(iNDArray3, iNDArray3) + Nd4j.EPS_THRESHOLD;
        if (linkedList2.size() >= this.m) {
            sub = (INDArray) linkedList2.removeLast();
            sub2 = (INDArray) linkedList3.removeLast();
            linkedList.removeLast();
            sub.assign(params).subi(iNDArray2);
            sub2.assign(iNDArray).subi(iNDArray3);
        } else {
            sub = params.sub(iNDArray2);
            sub2 = iNDArray.sub(iNDArray3);
        }
        linkedList.addFirst(Double.valueOf(1.0d / dot));
        linkedList2.addFirst(sub);
        linkedList3.addFirst(sub2);
        if (linkedList2.size() != linkedList3.size()) {
            throw new IllegalStateException("Gradient and parameter sizes are not equal");
        }
        int min = Math.min(this.m, linkedList2.size());
        double[] dArr = new double[min];
        Iterator it = linkedList2.iterator();
        Iterator it2 = linkedList3.iterator();
        Iterator it3 = linkedList.iterator();
        INDArray iNDArray4 = (INDArray) this.searchState.get(BaseOptimizer.SEARCH_DIR);
        iNDArray4.assign(iNDArray);
        for (int i = 0; i < min; i++) {
            INDArray iNDArray5 = (INDArray) it.next();
            INDArray iNDArray6 = (INDArray) it2.next();
            double doubleValue = ((Double) it3.next()).doubleValue();
            if (iNDArray5.length() != iNDArray4.length()) {
                throw new IllegalStateException("Gradients and parameters length not equal");
            }
            dArr[i] = doubleValue * Nd4j.getBlasWrapper().dot(iNDArray5, iNDArray4);
            Nd4j.getBlasWrapper().level1().axpy(iNDArray4.length(), -dArr[i], iNDArray6, iNDArray4);
        }
        iNDArray4.muli(Double.valueOf(dot / dot2));
        Iterator descendingIterator = linkedList2.descendingIterator();
        Iterator descendingIterator2 = linkedList3.descendingIterator();
        Iterator descendingIterator3 = linkedList.descendingIterator();
        for (int i2 = 0; i2 < min; i2++) {
            Nd4j.getBlasWrapper().level1().axpy(iNDArray.length(), dArr[i2] - (((Double) descendingIterator3.next()).doubleValue() * Nd4j.getBlasWrapper().dot((INDArray) descendingIterator2.next(), iNDArray4)), (INDArray) descendingIterator.next(), iNDArray4);
        }
        iNDArray2.assign(params);
        iNDArray3.assign(iNDArray);
    }
}
