package org.deeplearning4j.optimize.solvers;

import java.util.Collection;
import java.util.LinkedList;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Limited-memory BFGS (L-BFGS) optimizer.
 *
 * <p>Keeps the {@code m} most recent curvature pairs (parameter difference {@code s} and gradient
 * difference {@code y}, plus {@code rho = 1 / (s . y)}) in the search state and applies the
 * standard two-loop recursion to turn the current gradient into a quasi-Newton search direction
 * before each line search.
 *
 * <p>NOTE(review): the sign convention of the array stored under {@link BaseOptimizer#GRADIENT_KEY}
 * (loss gradient vs. ascent direction) is defined in {@code BaseOptimizer}, which is not visible
 * here. This class keeps the historical conventions ({@code y = oldGradient - gradient}, final
 * direction negated) — confirm against the line search implementation.
 */
public class LBFGS extends BaseOptimizer {
    /** Default number of curvature pairs retained by the two-loop recursion. */
    private static final int DEFAULT_HISTORY_SIZE = 4;

    /** Maximum number of most recent (s, y, rho) triples kept in the search state. */
    private final int m = DEFAULT_HISTORY_SIZE;

    public LBFGS(NeuralNetConfiguration conf, StepFunction stepFunction,
                 Collection<IterationListener> iterationListeners, Model model) {
        super(conf, stepFunction, iterationListeners, model);
    }

    public LBFGS(NeuralNetConfiguration conf, StepFunction stepFunction,
                 Collection<IterationListener> iterationListeners,
                 Collection<TerminationCondition> terminationConditions, Model model) {
        super(conf, stepFunction, iterationListeners, terminationConditions, model);
    }

    /**
     * Stores an L2-normalised copy of {@code gradient} in the search state under
     * {@link BaseOptimizer#GRADIENT_KEY}.
     *
     * @param gradient the gradient to normalise; not modified
     * @return always {@code true}
     */
    @Override
    protected boolean preFirstStepProcess(INDArray gradient) {
        double norm = Nd4j.norm2(gradient).getDouble(0);
        // A zero gradient cannot be normalised: 1 / 0 would fill the stored array with NaN.
        INDArray scaled = norm > 0.0 ? gradient.mul(Double.valueOf(1.0 / norm)) : gradient.dup();
        this.searchState.put(BaseOptimizer.GRADIENT_KEY, scaled);
        return true;
    }

    /**
     * Initialises the base search state, then resets the L-BFGS history (empty {@code s},
     * {@code y}, {@code rho} lists and an {@code alpha} buffer of length {@code m}) and snapshots
     * the current parameters and gradient as {@code oldparams} / {@code oldgradient}.
     */
    @Override
    public void setupSearchState(Pair<Gradient, Double> pair) {
        super.setupSearchState(pair);
        INDArray gradient = (INDArray) this.searchState.get(BaseOptimizer.GRADIENT_KEY);
        INDArray params = (INDArray) this.searchState.get(BaseOptimizer.PARAMS_KEY);
        this.searchState.put("s", new LinkedList<INDArray>());
        this.searchState.put("y", new LinkedList<INDArray>());
        this.searchState.put("rho", new LinkedList<Double>());
        this.searchState.put("alpha", Nd4j.create(this.m));
        this.searchState.put("oldparams", params.dup());
        this.searchState.put("oldgradient", gradient.dup());
    }

    /**
     * After the base-class handling, resets the search state (and sets the step back to 1) when
     * the first step came back as exactly zero.
     */
    @Override
    public void postFirstStep(INDArray gradient) {
        super.postFirstStep(gradient);
        if (this.step == 0.0d) {
            log.info("Unable to step in that direction...resetting");
            setupSearchState(this.model.gradientAndScore());
            this.step = 1.0d;
        }
    }

    /**
     * Records the newest curvature pair and overwrites the search-state gradient, in place, with
     * the L-BFGS search direction computed by the two-loop recursion.
     *
     * @param line unused; the gradient is read from the search state under
     *             {@link BaseOptimizer#GRADIENT_KEY}
     * @throws IllegalStateException if the history lists get out of sync or a stored {@code s}
     *                               vector does not match the gradient length
     */
    @Override
    public void preProcessLine(INDArray line) {
        INDArray oldParams = (INDArray) this.searchState.get("oldparams");
        INDArray params = (INDArray) this.searchState.get(BaseOptimizer.PARAMS_KEY);
        INDArray oldGradient = (INDArray) this.searchState.get("oldgradient");
        INDArray gradient = (INDArray) this.searchState.get(BaseOptimizer.GRADIENT_KEY);

        // Fresh arrays: history entries must not alias the reusable oldparams/oldgradient
        // buffers, which are overwritten below on every call.
        INDArray sK = params.sub(oldParams);
        INDArray yK = oldGradient.sub(gradient);

        double sy = Nd4j.getBlasWrapper().dot(sK, yK) + Nd4j.EPS_THRESHOLD;
        double yy = Nd4j.getBlasWrapper().dot(yK, yK) + Nd4j.EPS_THRESHOLD;
        // Initial inverse-Hessian scaling H0 = gamma * I, taken from the newest pair.
        double gamma = sy / yy;

        // Snapshot the raw iterate and gradient for the next update before the gradient buffer
        // is turned into the search direction.
        oldParams.assign(params);
        oldGradient.assign(gradient);

        @SuppressWarnings("unchecked")
        LinkedList<Double> rho = (LinkedList<Double>) this.searchState.get("rho");
        @SuppressWarnings("unchecked")
        LinkedList<INDArray> s = (LinkedList<INDArray>) this.searchState.get("s");
        @SuppressWarnings("unchecked")
        LinkedList<INDArray> y = (LinkedList<INDArray>) this.searchState.get("y");

        rho.add(1.0d / sy);
        s.add(sK);
        y.add(yK);
        // Keep only the m most recent pairs (oldest at the head, newest at the tail).
        while (s.size() > this.m) {
            s.removeFirst();
            y.removeFirst();
            rho.removeFirst();
        }
        if (s.size() != y.size()) {
            throw new IllegalStateException("S and y mis matched sizes");
        }
        if (s.size() != rho.size()) {
            throw new IllegalStateException("rho and s mis matched sizes");
        }

        INDArray alpha = (INDArray) this.searchState.get("alpha");
        int count = s.size();

        // First loop, newest to oldest: q <- q - alpha_i * y_i, with q held in `gradient`.
        for (int i = count - 1; i >= 0; i--) {
            INDArray si = s.get(i);
            if (si.length() != gradient.length()) {
                throw new IllegalStateException("Gradient and s length not equal");
            }
            double alphaI = rho.get(i) * Nd4j.getBlasWrapper().dot(si, gradient);
            alpha.putScalar(i, alphaI);
            axpy(-alphaI, y.get(i), gradient);
        }

        // r <- H0 * q
        gradient.muli(Double.valueOf(gamma));

        // Second loop, oldest to newest: r <- r + s_i * (alpha_i - beta_i).
        for (int i = 0; i < count; i++) {
            double beta = rho.get(i) * Nd4j.getBlasWrapper().dot(y.get(i), gradient);
            axpy(alpha.getDouble(i) - beta, s.get(i), gradient);
        }

        // Historical sign convention preserved; see the class-level NOTE(review).
        gradient.muli(-1);
    }

    /** No per-step work: the L-BFGS history is updated in {@link #preProcessLine(INDArray)}. */
    @Override
    public void postStep() {
    }

    /**
     * Computes {@code dy <- da * dx + dy} in place, dispatching to the float overload when
     * {@code dy} is not backed by a double buffer.
     */
    private static void axpy(double da, INDArray dx, INDArray dy) {
        if (dy.data().dataType() == DataBuffer.Type.DOUBLE) {
            Nd4j.getBlasWrapper().axpy(da, dx, dy);
        } else {
            Nd4j.getBlasWrapper().axpy((float) da, dx, dy);
        }
    }
}
