package org.nd4j.finitedifferences;

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/finitedifferences/TwoPointApproximation.class */
/**
 * Two-point (forward) finite-difference approximation of derivatives / Jacobians.
 *
 * <p>The helpers mirror SciPy's {@code scipy.optimize._numdiff} functions
 * ({@code _compute_absolute_step}, {@code _adjust_scheme_to_bounds}, {@code _dense_difference}).
 */
public class TwoPointApproximation {

    /**
     * Expands a two-element {@code [lower, upper]} bounds array into constant per-element bound
     * arrays shaped like {@code x}.
     *
     * @param bounds array whose element 0 is the lower bound and element 1 the upper bound;
     *               {@code null} means unbounded ({@code -inf}, {@code +inf})
     * @param x      array whose shape the returned bound arrays take
     * @return {@code {lowerBounds, upperBounds}}
     */
    public static INDArray[] prepareBounds(INDArray bounds, INDArray x) {
        double lower = bounds == null ? Double.NEGATIVE_INFINITY : bounds.getDouble(0);
        double upper = bounds == null ? Double.POSITIVE_INFINITY : bounds.getDouble(1);
        return new INDArray[] {Nd4j.valueArrayOf(x.shape(), lower), Nd4j.valueArrayOf(x.shape(), upper)};
    }

    /**
     * Shrinks finite-difference steps so that the probe points stay inside {@code [lowerBound, upperBound]}.
     *
     * <p>Per element:
     * <ul>
     *   <li>If the full step {@code h * numSteps} fits on both sides, the step is kept.</li>
     *   <li>Otherwise the step is moved to the side with more room. It becomes
     *       {@code +min(h, 0.5 * upperDist / numSteps)} or {@code -min(h, 0.5 * lowerDist / numSteps)},
     *       and the element is flagged one-sided.</li>
     *   <li>If that result is no larger than {@code min(upperDist, lowerDist) / numSteps}, the step
     *       becomes that distance and the one-sided flag is cleared.</li>
     * </ul>
     *
     * <p>The flag array starts as all ones. Elements that keep their step therefore stay flagged
     * {@code 1}, as do all elements when the early unbounded exit is taken.
     *
     * @param x          point at which derivatives are taken
     * @param h          signed step per element
     * @param numSteps   number of steps taken on one side
     * @param lowerBound per-element lower bounds (same length as {@code h})
     * @param upperBound per-element upper bounds (same length as {@code h})
     * @return {@code {adjustedSteps, oneSidedFlags}}; if every element is unbounded, {@code h}
     *         itself is returned (not a copy) together with an all-ones flag array
     */
    public static INDArray[] adjustSchemeToBounds(INDArray x, INDArray h, int numSteps,
                                                  INDArray lowerBound, INDArray upperBound) {
        INDArray oneSided = Nd4j.onesLike(h);
        if (isUnbounded(lowerBound, upperBound)) {
            return new INDArray[] {h, oneSided};
        }
        INDArray hAdjusted = h.dup();
        // Element-wise loop: INDArray.get/put(INDArray) are index-based, not boolean-mask based.
        for (int i = 0; i < h.length(); i++) {
            double step = h.getDouble(i);
            double xi = x.getDouble(i);
            double lowerDist = xi - lowerBound.getDouble(i);
            double upperDist = upperBound.getDouble(i) - xi;
            double hTotal = step * numSteps;
            if (lowerDist >= hTotal && upperDist >= hTotal) {
                // Enough room on both sides: keep the original step.
                continue;
            }
            double adjusted = upperDist >= lowerDist
                    ? Math.min(step, 0.5 * upperDist / numSteps)    // more room above: step forward
                    : -Math.min(step, 0.5 * lowerDist / numSteps);  // more room below: step backward
            double flag = 1.0;
            double minDist = Math.min(upperDist, lowerDist) / numSteps;
            if (Math.abs(adjusted) <= minDist) {
                adjusted = minDist;
                flag = 0.0;
            }
            hAdjusted.putScalar(i, adjusted);
            oneSided.putScalar(i, flag);
        }
        return new INDArray[] {hAdjusted, oneSided};
    }

    /** Returns {@code true} when every element has lower bound {@code -inf} and upper bound {@code +inf}. */
    private static boolean isUnbounded(INDArray lowerBound, INDArray upperBound) {
        for (int i = 0; i < lowerBound.length(); i++) {
            if (lowerBound.getDouble(i) != Double.NEGATIVE_INFINITY
                    || upperBound.getDouble(i) != Double.POSITIVE_INFINITY) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes absolute steps for {@code x} using relative step {@code sqrt(Nd4j.EPS_THRESHOLD)}.
     *
     * <p>NOTE(review): this differs from passing a {@code null} relative step to
     * {@link #computeAbsoluteStep(INDArray, INDArray)}, which uses the square root of machine epsilon.
     *
     * @param x point at which derivatives are taken
     * @return signed absolute step per element
     */
    public static INDArray computeAbsoluteStep(INDArray x) {
        return computeAbsoluteStep(Transforms.pow(Nd4j.scalar(Nd4j.EPS_THRESHOLD), 0.5), x);
    }

    /**
     * Returns machine epsilon for the data type of {@code arr}.
     *
     * @param arr array whose buffer data type is inspected
     * @return the single-precision epsilon for {@code FLOAT} buffers, otherwise the double-precision epsilon
     */
    public static double getEpsRelativeTo(INDArray arr) {
        return arr.data().dataType() == DataBuffer.Type.FLOAT ? 1.1920929E-7d : 2.220446049250313E-16d;
    }

    /**
     * Computes {@code relStep * sign(x) * max(|x|, 1)} element-wise, where {@code sign(0)} is taken as {@code +1}.
     *
     * @param relStep relative step; if {@code null}, {@code sqrt(eps)} for {@code x}'s data type is used
     * @param x       point at which derivatives are taken
     * @return signed absolute step per element
     */
    public static INDArray computeAbsoluteStep(INDArray relStep, INDArray x) {
        INDArray rel = relStep != null ? relStep : Transforms.pow(Nd4j.scalar(getEpsRelativeTo(x)), 0.5);
        // (x >= 0) * 2 - 1 maps non-negative entries to +1 and negative entries to -1.
        INDArray sign = x.gte(0).muli((Number) 2).subi((Number) 1);
        return sign.mul(rel).muli(Transforms.max(Transforms.abs(x), 1.0d));
    }

    /**
     * Approximates the derivative of {@code function} at {@code x} with two-point forward differences.
     *
     * @param function function to differentiate
     * @param x        point at which to differentiate (rank at most 2)
     * @param relStep  relative step, or {@code null} for the default
     * @param f0       {@code function(x)} if already known, or {@code null} to evaluate it
     * @param bounds   {@code [lower, upper]} bounds, or {@code null} for unbounded
     * @return the result of {@link #denseDifference}
     * @throws ND4JIllegalArgumentException if {@code x.rank() > 2}
     */
    public static INDArray approximateDerivative(Function<INDArray, INDArray> function, INDArray x,
                                                 INDArray relStep, INDArray f0, INDArray bounds) {
        if (x.rank() > 2) {
            throw new ND4JIllegalArgumentException("Argument must be a vector or scalar");
        }
        INDArray h = computeAbsoluteStep(relStep, x);
        INDArray[] lowerUpper = prepareBounds(bounds, x);
        // NOTE(review): the bound-adjusted step (index 0) is discarded and the unadjusted h is used,
        // so bounds do not currently limit the probe points. Only the one-sided flags (index 1) are
        // forwarded, and denseDifference does not read them.
        INDArray[] adjusted = adjustSchemeToBounds(x, h, 1, lowerUpper[0], lowerUpper[1]);
        return denseDifference(function, x, f0, h, adjusted[1]);
    }

    /**
     * Builds the forward-difference derivative matrix by perturbing one coordinate at a time.
     *
     * <p>Row {@code i} holds {@code (function(x0 + h_i e_i) - f0) / dx_i}, i.e. the rows are indexed
     * by input coordinate and the columns by output element. If {@code f0} has a single element, the
     * result is flattened with {@code ravel()}.
     *
     * @param function function to differentiate
     * @param x0       base point
     * @param f0       {@code function(x0)}, or {@code null} to evaluate it here
     * @param h        step per input coordinate
     * @param oneSided one-sided flags; currently unused by this method
     * @return derivative matrix with shape {@code [x0.length(), f0.length()]}, or flattened when
     *         {@code f0} has one element
     */
    public static INDArray denseDifference(Function<INDArray, INDArray> function, INDArray x0,
                                           INDArray f0, INDArray h, INDArray oneSided) {
        INDArray fAtX0 = f0 != null ? f0 : function.apply(x0);
        // Row i of this diagonal matrix is the perturbation vector h_i * e_i.
        INDArray perturbations = Nd4j.diag(h.reshape(1, h.length()));
        INDArray result = Nd4j.create(x0.length(), fAtX0.length());
        for (int i = 0; i < h.length(); i++) {
            INDArray shifted = x0.add(perturbations.slice(i));
            // Recompute dx from the perturbed point so the divisor is the step actually realised in floating point.
            INDArray dx = shifted.slice(i).sub(x0.slice(i));
            result.putSlice(i, function.apply(shifted).sub(fAtX0).div(dx));
        }
        if (fAtX0.length() == 1) {
            result = result.ravel();
        }
        return result;
    }
}
