package com.github.chen0040.glm.search.methods.cgs;

import com.github.chen0040.glm.search.CostEvaluationMethod;
import com.github.chen0040.glm.search.GradientEvaluationMethod;
import com.github.chen0040.glm.search.LineSearch;
import com.github.chen0040.glm.search.LineSearchResult;
import com.github.chen0040.glm.search.LocalSearch;
import com.github.chen0040.glm.search.TerminationEvaluationMethod;
import com.github.chen0040.glm.search.solutions.NumericSolution;
import com.github.chen0040.glm.search.solutions.NumericSolutionFactory;
import com.github.chen0040.glm.search.solutions.NumericSolutionUpdateResult;

/* loaded from: input_file:com/github/chen0040/glm/search/methods/cgs/NonlinearCGSearch.class */
/**
 * Local search that minimizes a cost function with the nonlinear conjugate-gradient method.
 *
 * <p>Each iteration evaluates the gradient at the current point, mixes the new steepest-descent
 * direction ({@code -gradient}) with the previous search direction using the coefficient returned
 * by {@link #computeBeta(double[], double[], double[])}, and runs {@link LineSearch#search} along
 * the resulting direction. The coefficient formula is chosen with
 * {@link #setBetaFormula(BetaFormula)} and defaults to {@link BetaFormula#FletcherReeves}.
 */
public class NonlinearCGSearch extends LocalSearch {
    /** Formula used by {@link #computeBeta} to weight the previous search direction. */
    private BetaFormula betaFormula = BetaFormula.FletcherReeves;

    /**
     * Copies the configuration of {@code localSearch} into this instance.
     *
     * <p>The base-class state is always copied via {@code super.copy}; the beta formula is copied
     * only when {@code localSearch} is itself a {@code NonlinearCGSearch}, so copying from another
     * {@link LocalSearch} subtype no longer fails with a {@link ClassCastException}.
     *
     * @param localSearch the search whose configuration is copied
     */
    @Override
    public void copy(LocalSearch localSearch) {
        super.copy(localSearch);
        if (localSearch instanceof NonlinearCGSearch) {
            this.betaFormula = ((NonlinearCGSearch) localSearch).betaFormula;
        }
    }

    /**
     * Creates a new {@code NonlinearCGSearch} configured like this one.
     *
     * @return a fresh instance populated through {@link #copy(LocalSearch)}
     */
    @Override
    public LocalSearch makeCopy() {
        NonlinearCGSearch duplicate = new NonlinearCGSearch();
        duplicate.copy(this);
        return duplicate;
    }

    /**
     * Selects the formula used to compute the conjugate-gradient coefficient.
     *
     * @param betaFormula the formula to use; a value not handled by
     *     {@link #computeBeta(double[], double[], double[])} (including {@code null}) makes every
     *     coefficient 0, i.e. plain steepest descent
     */
    public void setBetaFormula(BetaFormula betaFormula) {
        this.betaFormula = betaFormula;
    }

    /**
     * Returns the formula currently used to compute the conjugate-gradient coefficient.
     *
     * @return the configured beta formula
     */
    public BetaFormula getBetaFormula() {
        return this.betaFormula;
    }

    /**
     * Minimizes the cost function starting from {@code dArr}.
     *
     * <p>The first step is a steepest-descent line search; every following iteration builds a
     * conjugate direction, restarting with steepest descent whenever that direction is not a
     * descent direction. Every line search is given the cost at the point it starts from.
     *
     * @param dArr starting point; it is cloned and never modified
     * @param costEvaluationMethod evaluates the cost at a point; receives this search's lower and
     *     upper bounds and {@code obj}
     * @param gradientEvaluationMethod fills its second argument with the gradient at the point
     *     given as its first argument
     * @param terminationEvaluationMethod decides from the latest update result and the iteration
     *     index when to stop; it sees a {@code null} update result before the first iteration
     * @param obj opaque context forwarded to the cost, gradient and line-search calls
     * @return the solution accumulated through {@link NumericSolution#tryUpdateSolution} over the
     *     loop iterations
     */
    @Override
    public NumericSolution minimize(double[] dArr, CostEvaluationMethod costEvaluationMethod,
            GradientEvaluationMethod gradientEvaluationMethod,
            TerminationEvaluationMethod terminationEvaluationMethod, Object obj) {
        NumericSolution bestSolution = new NumericSolution();
        int dimension = dArr.length;

        // Current point and the cost at that point. They are advanced together so that each line
        // search receives the cost of the point it actually starts from (not the initial cost).
        double[] x = dArr.clone();
        double fx = costEvaluationMethod.apply(x, getLowerBounds(), getUpperBounds(), obj);

        double[] gradient = new double[dimension];
        gradientEvaluationMethod.apply(x, gradient, getLowerBounds(), getUpperBounds(), obj);

        // Steepest-descent directions (-gradient) at the current and at the previous point.
        double[] steepest = new double[dimension];
        double[] previousSteepest = new double[dimension];
        for (int i = 0; i < dimension; i++) {
            steepest[i] = -gradient[i];
        }

        // First step: plain steepest-descent line search.
        LineSearchResult lineResult = LineSearch.search(x, fx, steepest, costEvaluationMethod,
                gradientEvaluationMethod, getLowerBounds(), getUpperBounds(), obj);
        double[] nextX = lineResult.x();
        fx = lineResult.fx();
        double[] direction = steepest.clone();

        NumericSolutionUpdateResult updateResult = null;
        for (int iteration = 0;
                !terminationEvaluationMethod.shouldTerminate(updateResult, iteration);
                iteration++) {
            // Remember the old steepest-descent direction, then move to the last line-search point.
            System.arraycopy(steepest, 0, previousSteepest, 0, dimension);
            System.arraycopy(nextX, 0, x, 0, dimension);

            gradientEvaluationMethod.apply(x, gradient, getLowerBounds(), getUpperBounds(), obj);
            for (int i = 0; i < dimension; i++) {
                steepest[i] = -gradient[i];
            }

            // computeBeta must see the previous search direction, so call it before updating.
            double beta = computeBeta(steepest, previousSteepest, direction);
            double slope = 0.0;
            for (int i = 0; i < dimension; i++) {
                direction[i] = steepest[i] + beta * direction[i];
                slope += direction[i] * gradient[i];
            }
            // Restart: PR/HS (and FR with inexact line searches) can produce a direction that is
            // not downhill; searching along it would climb, so fall back to steepest descent.
            // Written as !(slope < 0) so that a NaN slope also triggers the restart.
            if (!(slope < 0.0)) {
                System.arraycopy(steepest, 0, direction, 0, dimension);
            }

            lineResult = LineSearch.search(x, fx, direction, costEvaluationMethod,
                    gradientEvaluationMethod, getLowerBounds(), getUpperBounds(), obj);
            nextX = lineResult.x();
            fx = lineResult.fx();

            updateResult = bestSolution.tryUpdateSolution(nextX, fx);
            if (updateResult.improved()) {
                notifySolutionUpdated(bestSolution, updateResult, iteration);
            }
            step(new NumericSolution(nextX, fx), updateResult, iteration);
        }
        return bestSolution;
    }

    /**
     * Returns {@code NumericSolutionFactory.mutate(dArr, 5.0)}.
     *
     * <p>Not used by {@link #minimize}. The meaning of the constant 5.0 (presumably a perturbation
     * magnitude) is defined by {@code NumericSolutionFactory.mutate} — NOTE(review): confirm there.
     *
     * @param dArr the point to perturb
     * @return whatever {@code NumericSolutionFactory.mutate} returns for {@code dArr}
     */
    public double[] randomize(double[] dArr) {
        return NumericSolutionFactory.mutate(dArr, 5.0d);
    }

    /**
     * Computes the conjugate-gradient coefficient beta for the configured formula.
     *
     * <p>With {@code r = -gradient}, {@code r_n = dArr}, {@code r_p = dArr2} and
     * {@code s_p = dArr3}:
     * <ul>
     *   <li>FletcherReeves: {@code (r_n . r_n) / (r_p . r_p)}</li>
     *   <li>PolakRebiere: {@code r_n . (r_n - r_p) / (r_p . r_p)}</li>
     *   <li>HestenesStiefel: {@code r_n . (r_n - r_p) / (-s_p . (r_n - r_p))}</li>
     * </ul>
     * Returns 0 when the denominator is 0 or the formula is not one of the above.
     *
     * @param dArr steepest-descent direction (negative gradient) at the current point
     * @param dArr2 steepest-descent direction at the previous point; at least as long as
     *     {@code dArr}
     * @param dArr3 previous search direction (used only by HestenesStiefel)
     * @return the coefficient beta
     */
    public double computeBeta(double[] dArr, double[] dArr2, double[] dArr3) {
        int length = dArr.length;
        double numerator = 0.0;
        double denominator = 0.0;
        if (this.betaFormula == BetaFormula.FletcherReeves) {
            for (int i = 0; i < length; i++) {
                numerator += dArr[i] * dArr[i];
                denominator += dArr2[i] * dArr2[i];
            }
        } else if (this.betaFormula == BetaFormula.PolakRebiere) {
            for (int i = 0; i < length; i++) {
                numerator += dArr[i] * (dArr[i] - dArr2[i]);
                denominator += dArr2[i] * dArr2[i];
            }
        } else if (this.betaFormula == BetaFormula.HestenesStiefel) {
            for (int i = 0; i < length; i++) {
                double change = dArr[i] - dArr2[i];
                numerator += dArr[i] * change;
                // Minus sign because dArr/dArr2 are negated gradients: -s.(r_n - r_p) = s.(g_n - g_p).
                denominator -= dArr3[i] * change;
            }
        } else {
            return 0.0;
        }
        return denominator != 0.0 ? numerator / denominator : 0.0;
    }
}
