package tri.util.math;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.NoWhenBranchMatchedException;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.IntRange;
import kotlin.ranges.RangesKt;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer;
import org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer;
import org.apache.commons.math3.fitting.leastsquares.MultivariateJacobianFunction;
import org.apache.commons.math3.fitting.leastsquares.ParameterValidator;
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.Pair;
import org.jetbrains.annotations.NotNull;

/* compiled from: Sigmoid.kt */
@Metadata(mv = {1, 6, 0}, k = 1, xi = 48, d1 = {"��D\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\u0006\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\bÆ\u0002\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J \u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0007\u001a\u00020\u00042\u0006\u0010\b\u001a\u00020\tH\u0002J \u0010\n\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0007\u001a\u00020\u00042\u0006\u0010\b\u001a\u00020\tH\u0002J(\u0010\u000b\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0007\u001a\u00020\u00042\u0006\u0010\b\u001a\u00020\t2\u0006\u0010\f\u001a\u00020\tH\u0002J(\u0010\r\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0007\u001a\u00020\u00042\u0006\u0010\b\u001a\u00020\t2\u0006\u0010\f\u001a\u00020\tH\u0002J,\u0010\u000e\u001a\u00020\u000f2\u0006\u0010\u0005\u001a\u00020\u00062\f\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00120\u00112\u0006\u0010\u0013\u001a\u00020\u000f2\u0006\u0010\u0014\u001a\u00020\u0015J,\u0010\u0016\u001a\u00020\u000f2\u0006\u0010\u0005\u001a\u00020\u00062\f\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00120\u00112\u0006\u0010\u0013\u001a\u00020\u000f2\u0006\u0010\u0014\u001a\u00020\u0015J\u001e\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0005\u001a\u00020\u00062\f\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00120\u0011H\u0002J\u001e\u0010\u0019\u001a\u00020\u00182\u0006\u0010\u0005\u001a\u00020\u00062\f\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00120\u0011H\u0002¨\u0006\u001a"}, d2 = {"Ltri/util/math/SigmoidCurveFitting;", "", "()V", "curve", "", "shape", "Ltri/util/math/Sigmoid;", "x", "params", "Lorg/apache/commons/math3/linear/RealVector;", "curveDerivative", "curveDerivativePartial", "delta", 
"curvePartial", "fitCumulative", "Ltri/util/math/SigmoidParameters;", "observedPoints", "", "Lorg/apache/commons/math3/geometry/euclidean/twod/Vector2D;", "initial", "parameterValidator", "Lorg/apache/commons/math3/fitting/leastsquares/ParameterValidator;", "fitIncidence", "solverFunCumulative", "Lorg/apache/commons/math3/fitting/leastsquares/MultivariateJacobianFunction;", "solverFunIncidence", "coda-time-covid"})
/* loaded from: input_file:tri/util/math/SigmoidCurveFitting.class */
public final class SigmoidCurveFitting {

    /** Shared singleton instance; the class holds only static final state. */
    @NotNull
    public static final SigmoidCurveFitting INSTANCE = new SigmoidCurveFitting();

    /** Size of the optimized parameter vector: load, k, x0, v (v is only read by GEN_LOGISTIC). */
    private static final int PARAMETER_COUNT = 4;

    /** Half-width h of the central-difference step used for partials of {@link #curve}. */
    private static final double CURVE_PARTIAL_STEP = 5.0E-4d;

    /** Half-width h of the central-difference step used for partials of {@link #curveDerivative}. */
    private static final double DERIVATIVE_PARTIAL_STEP = 0.005d;

    /** Relative cost and relative parameter tolerance given to the Levenberg-Marquardt optimizer. */
    private static final double OPTIMIZER_TOLERANCE = 1.0E-9d;

    /** Upper bound used for both the optimizer's evaluation count and iteration count. */
    private static final int MAX_OPTIMIZER_STEPS = 100000;

    private SigmoidCurveFitting() {
    }

    /**
     * Fits {@code shape} to observed points with a Levenberg-Marquardt least-squares solver,
     * using each point's y value directly as the target for {@code curve(shape, x, params)}.
     * The fitted parameter vector is printed to standard output.
     *
     * @param shape              curve family to fit
     * @param observedPoints     observed (x, y) points; must be non-empty (Apache Commons Math
     *                           rejects a zero-row Jacobian)
     * @param initial            starting parameters; a {@code null} {@code v} starts at 0
     * @param parameterValidator validator attached to the least-squares problem
     * @return new parameters for {@code shape} built from the optimum (load, k, x0, v)
     */
    @NotNull
    public final SigmoidParameters fitCumulative(@NotNull Sigmoid shape, @NotNull List<? extends Vector2D> observedPoints, @NotNull SigmoidParameters initial, @NotNull ParameterValidator parameterValidator) {
        Intrinsics.checkNotNullParameter(shape, "shape");
        Intrinsics.checkNotNullParameter(observedPoints, "observedPoints");
        Intrinsics.checkNotNullParameter(initial, "initial");
        Intrinsics.checkNotNullParameter(parameterValidator, "parameterValidator");
        return fit("Cumulative", shape, solverFunCumulative(shape, observedPoints), yValues(observedPoints), initial, parameterValidator);
    }

    /**
     * Fits {@code shape} to the per-step increments of cumulative observations.
     *
     * <p>Each consecutive pair of points becomes an increment {@code (x[i], y[i] - y[i-1])}. The
     * model value at an increment is {@code curve(x + 0.5) - curve(x - 0.5)}. The fitted
     * parameter vector is printed to standard output.
     *
     * @param shape              curve family to fit
     * @param observedPoints     observed cumulative (x, y) points; at least two are needed so
     *                           that there is at least one increment
     * @param initial            starting parameters; a {@code null} {@code v} starts at 0
     * @param parameterValidator validator attached to the least-squares problem
     * @return new parameters for {@code shape} built from the optimum (load, k, x0, v)
     */
    @NotNull
    public final SigmoidParameters fitIncidence(@NotNull Sigmoid shape, @NotNull List<? extends Vector2D> observedPoints, @NotNull SigmoidParameters initial, @NotNull ParameterValidator parameterValidator) {
        Intrinsics.checkNotNullParameter(shape, "shape");
        Intrinsics.checkNotNullParameter(observedPoints, "observedPoints");
        Intrinsics.checkNotNullParameter(initial, "initial");
        Intrinsics.checkNotNullParameter(parameterValidator, "parameterValidator");
        List<Vector2D> increments = new ArrayList<>(Math.max(observedPoints.size() - 1, 0));
        for (int i = 1; i < observedPoints.size(); i++) {
            Vector2D current = observedPoints.get(i);
            increments.add(new Vector2D(current.getX(), current.getY() - observedPoints.get(i - 1).getY()));
        }
        return fit("Incidence", shape, solverFunIncidence(shape, increments), yValues(increments), initial, parameterValidator);
    }

    /** Runs the Levenberg-Marquardt optimizer on one model/target pair and wraps the optimum. */
    private SigmoidParameters fit(String label, Sigmoid shape, MultivariateJacobianFunction model, double[] target, SigmoidParameters initial, ParameterValidator parameterValidator) {
        LeastSquaresOptimizer.Optimum optimum = new LevenbergMarquardtOptimizer()
                .withCostRelativeTolerance(OPTIMIZER_TOLERANCE)
                .withParameterRelativeTolerance(OPTIMIZER_TOLERANCE)
                .optimize(new LeastSquaresBuilder()
                        .start(startVector(initial))
                        .model(model)
                        .target(target)
                        .maxEvaluations(MAX_OPTIMIZER_STEPS)
                        .maxIterations(MAX_OPTIMIZER_STEPS)
                        .parameterValidator(parameterValidator)
                        .build());
        System.out.println(label + " fit parameters: " + optimum.getPoint());
        double[] point = optimum.getPoint().toArray();
        return new SigmoidParameters(shape, point[0], point[1], point[2], Double.valueOf(point[3]));
    }

    /** Builds the 4-element start vector (load, k, x0, v), substituting 0 for a missing v. */
    private static RealVector startVector(SigmoidParameters initial) {
        Number v = initial.getV();
        double[] start = {initial.getLoad(), initial.getK(), initial.getX0(), v == null ? 0.0d : v.doubleValue()};
        return new ArrayRealVector(start, false);
    }

    /** Extracts the y coordinates of {@code points}, in iteration order. */
    private static double[] yValues(List<? extends Vector2D> points) {
        double[] ys = new double[points.size()];
        int i = 0;
        for (Vector2D point : points) {
            ys[i++] = point.getY();
        }
        return ys;
    }

    /** Model whose values are {@link #curve} at each point's x, with a finite-difference Jacobian. */
    private MultivariateJacobianFunction solverFunCumulative(Sigmoid sigmoid, List<? extends Vector2D> list) {
        return params -> evaluateModel(sigmoid, list, params, false);
    }

    /** Model whose values are {@link #curveDerivative} at each point's x, with a finite-difference Jacobian. */
    private MultivariateJacobianFunction solverFunIncidence(Sigmoid sigmoid, List<? extends Vector2D> list) {
        return params -> evaluateModel(sigmoid, list, params, true);
    }

    /**
     * Evaluates the model and its Jacobian at {@code params} for every point.
     *
     * @param incidence {@code true} to use {@link #curveDerivative}, {@code false} for {@link #curve}
     * @return pair of (model values, Jacobian with one row per point and one column per parameter)
     */
    private Pair<RealVector, RealMatrix> evaluateModel(Sigmoid shape, List<? extends Vector2D> points, RealVector params, boolean incidence) {
        double[] values = new double[points.size()];
        Array2DRowRealMatrix jacobian = new Array2DRowRealMatrix(points.size(), PARAMETER_COUNT);
        int row = 0;
        for (Vector2D point : points) {
            double x = point.getX();
            values[row] = incidence ? curveDerivative(shape, x, params) : curve(shape, x, params);
            for (int col = 0; col < PARAMETER_COUNT; col++) {
                RealVector direction = basisVector(col);
                double partial = incidence
                        ? curveDerivativePartial(shape, x, params, direction)
                        : curvePartial(shape, x, params, direction);
                jacobian.setEntry(row, col, partial);
            }
            row++;
        }
        return new Pair<>(new ArrayRealVector(values, false), jacobian);
    }

    /** Returns the standard basis vector e_index of length {@link #PARAMETER_COUNT}. */
    private static RealVector basisVector(int index) {
        ArrayRealVector basis = new ArrayRealVector(PARAMETER_COUNT);
        basis.setEntry(index, 1.0d);
        return basis;
    }

    /**
     * Evaluates the selected sigmoid family at {@code x}. Entries 0..2 of {@code params} are
     * passed to every family; entry 3 is read only by GEN_LOGISTIC.
     */
    private double curve(Sigmoid shape, double x, RealVector params) {
        double p0 = ApacheMathKt.get(params, 0);
        double p1 = ApacheMathKt.get(params, 1);
        double p2 = ApacheMathKt.get(params, 2);
        switch (shape) {
            case LINEAR:
                return SigmoidKt.linear(x, p0, p1, p2);
            case QUADRATIC:
                return SigmoidKt.quadratic(x, p0, p1, p2);
            case LOGISTIC:
                return SigmoidKt.logistic(x, p0, p1, p2);
            case GEN_LOGISTIC:
                return SigmoidKt.generalLogistic(x, p0, p1, p2, ApacheMathKt.get(params, 3));
            case GAUSSIAN:
                return SigmoidKt.gaussianErf(x, p0, p1, p2);
            case GOMPERTZ:
                return SigmoidKt.gompertz(x, p0, p1, p2);
            default:
                throw new NoWhenBranchMatchedException();
        }
    }

    /**
     * Difference of {@link #curve} over the unit window {@code [x - 0.5, x + 0.5]}.
     *
     * <p>NOTE(review): {@code fitIncidence} pairs each increment {@code y[i] - y[i-1]} with
     * {@code x[i]}, whereas this window is centred on {@code x}. For unit-spaced data that is a
     * half-step offset; confirm it is intended.
     */
    private double curveDerivative(Sigmoid shape, double x, RealVector params) {
        return curve(shape, x + 0.5d, params) - curve(shape, x - 0.5d, params);
    }

    /**
     * Central-difference directional derivative of {@link #curve} with respect to the parameters,
     * along the unit vector of {@code delta}: {@code (f(p + h·u) - f(p - h·u)) / (2h)}.
     */
    private double curvePartial(Sigmoid shape, double x, RealVector params, RealVector delta) {
        RealVector step = delta.unitVector().mapMultiply(CURVE_PARTIAL_STEP);
        double forward = curve(shape, x, params.add(step));
        double backward = curve(shape, x, params.subtract(step));
        // The two samples are 2h apart, so the difference is divided by 2h to get a slope.
        return (forward - backward) / (2.0d * CURVE_PARTIAL_STEP);
    }

    /**
     * Central-difference directional derivative of {@link #curveDerivative} with respect to the
     * parameters, along the unit vector of {@code delta}: {@code (g(p + h·u) - g(p - h·u)) / (2h)}.
     */
    private double curveDerivativePartial(Sigmoid shape, double x, RealVector params, RealVector delta) {
        RealVector step = delta.unitVector().mapMultiply(DERIVATIVE_PARTIAL_STEP);
        double forward = curveDerivative(shape, x, params.add(step));
        double backward = curveDerivative(shape, x, params.subtract(step));
        // The two samples are 2h apart, so the difference is divided by 2h to get a slope.
        return (forward - backward) / (2.0d * DERIVATIVE_PARTIAL_STEP);
    }
}
