package org.nd4j.evaluation;

import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import lombok.NonNull;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.AtomicDouble;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicBoolean;
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicDouble;
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicBoolean;
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicDouble;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.nd4j.shade.jackson.databind.module.SimpleModule;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;

/* loaded from: input_file:org/nd4j/evaluation/BaseEvaluation.class */
public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvaluation<T> {

    /** Shared JSON mapper used by {@link #toJson()} and {@link #fromJson(String, Class)}. */
    private static final ObjectMapper objectMapper = configureMapper(new ObjectMapper());

    /** Shared YAML mapper used by {@link #toYaml()} and {@link #fromYaml(String, Class)}. */
    private static final ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory()));

    /**
     * Applies the common serialization settings to {@code mapper}.
     *
     * <p>The settings are:
     * <ul>
     *   <li>unknown properties are ignored on read;</li>
     *   <li>empty beans are allowed on write;</li>
     *   <li>properties are not sorted alphabetically;</li>
     *   <li>output is indented;</li>
     *   <li>{@code AtomicDouble}/{@code AtomicBoolean} get the ND4J custom (de)serializers;</li>
     *   <li>auto-detection covers fields and creators of any visibility, while getters and setters are ignored.</li>
     * </ul>
     *
     * @param mapper the mapper to configure; it is mutated in place
     * @return the same {@code mapper}, for use in field initializers
     */
    private static ObjectMapper configureMapper(ObjectMapper mapper) {
        mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
        mapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false);
        mapper.enable(SerializationFeature.INDENT_OUTPUT);

        SimpleModule atomicsModule = new SimpleModule();
        atomicsModule.addSerializer(AtomicDouble.class, new JsonSerializerAtomicDouble());
        atomicsModule.addSerializer(AtomicBoolean.class, new JsonSerializerAtomicBoolean());
        atomicsModule.addDeserializer(AtomicDouble.class, new JsonDeserializerAtomicDouble());
        atomicsModule.addDeserializer(AtomicBoolean.class, new JsonDeserializerAtomicBoolean());
        mapper.registerModule(atomicsModule);

        mapper.setVisibilityChecker(mapper.getSerializationConfig().getDefaultVisibilityChecker()
                .withFieldVisibility(JsonAutoDetect.Visibility.ANY)
                .withGetterVisibility(JsonAutoDetect.Visibility.NONE)
                .withSetterVisibility(JsonAutoDetect.Visibility.NONE)
                .withCreatorVisibility(JsonAutoDetect.Visibility.ANY));
        return mapper;
    }

    /**
     * Deserializes an evaluation from YAML.
     *
     * @param str YAML text, e.g. as produced by {@link #toYaml()}
     * @param cls the concrete evaluation class to read
     * @return the deserialized evaluation
     * @throws RuntimeException wrapping the underlying {@link IOException} if parsing/mapping fails
     */
    public static <T extends IEvaluation> T fromYaml(String str, Class<T> cls) {
        try {
            return yamlMapper.readValue(str, cls);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deserializes an evaluation from JSON.
     *
     * <p>If Jackson reports "Could not resolve type id", the JSON is assumed to possibly come from the
     * legacy {@code org.deeplearning4j.eval} package. In that case it is retried via
     * {@link #attempFromLegacyFromJson(String, InvalidTypeIdException)}.
     *
     * @param str JSON text, e.g. as produced by {@link #toJson()}
     * @param cls the concrete evaluation class to read
     * @return the deserialized evaluation
     * @throws RuntimeException if the JSON cannot be deserialized, either directly or via the legacy fallback
     */
    public static <T extends IEvaluation> T fromJson(String str, Class<T> cls) {
        try {
            return objectMapper.readValue(str, cls);
        } catch (InvalidTypeIdException e) {
            String message = e.getMessage();
            if (message == null || !message.contains("Could not resolve type id")) {
                throw new RuntimeException(e);
            }
            try {
                return attempFromLegacyFromJson(str, e);
            } catch (Exception legacyFailure) {
                throw new RuntimeException("Cannot deserialize from JSON - JSON is invalid?", legacyFailure);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Retries deserialization of JSON written by the legacy {@code org.deeplearning4j.eval} classes.
     *
     * <p>The method rewrites legacy class names to their {@code org.nd4j.evaluation} equivalents and
     * then calls {@link #fromJson(String, Class)} again.
     *
     * @param str                    the original JSON text
     * @param invalidTypeIdException the exception from the first attempt; rethrown if no legacy name is found
     * @return the deserialized evaluation
     * @throws InvalidTypeIdException {@code invalidTypeIdException}, when no known legacy class name is present
     */
    @SuppressWarnings("unchecked")
    protected static <T extends IEvaluation> T attempFromLegacyFromJson(String str, InvalidTypeIdException invalidTypeIdException) throws InvalidTypeIdException {
        // The most specific names must be tested first:
        // - "org.deeplearning4j.eval.Evaluation" is a prefix of EvaluationBinary/EvaluationCalibration;
        // - "org.deeplearning4j.eval.ROC" is a prefix of ROCBinary/ROCMultiClass.
        // Testing a shorter name first would deserialize into the wrong target class.
        // String.replace (not replaceAll) is used so that '.' is matched literally rather than as a regex wildcard.
        if (str.contains("org.deeplearning4j.eval.EvaluationBinary")) {
            String converted = str
                    .replace("org.deeplearning4j.eval.EvaluationBinary", "org.nd4j.evaluation.classification.EvaluationBinary")
                    .replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC")
                    .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
            return (T) fromJson(converted, EvaluationBinary.class);
        }
        if (str.contains("org.deeplearning4j.eval.EvaluationCalibration")) {
            String converted = str
                    .replace("org.deeplearning4j.eval.EvaluationCalibration", "org.nd4j.evaluation.classification.EvaluationCalibration")
                    .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
            return (T) fromJson(converted, EvaluationCalibration.class);
        }
        if (str.contains("org.deeplearning4j.eval.Evaluation")) {
            String converted = str
                    .replace("org.deeplearning4j.eval.Evaluation", "org.nd4j.evaluation.classification.Evaluation");
            return (T) fromJson(converted, Evaluation.class);
        }
        if (str.contains("org.deeplearning4j.eval.ROCBinary")) {
            String converted = str
                    .replace("org.deeplearning4j.eval.ROCBinary", "org.nd4j.evaluation.classification.ROCBinary")
                    .replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC")
                    .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
            return (T) fromJson(converted, ROCBinary.class);
        }
        if (str.contains("org.deeplearning4j.eval.ROCMultiClass")) {
            String converted = str
                    .replace("org.deeplearning4j.eval.ROCMultiClass", "org.nd4j.evaluation.classification.ROCMultiClass")
                    .replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC")
                    .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
            return (T) fromJson(converted, ROCMultiClass.class);
        }
        if (str.contains("org.deeplearning4j.eval.ROC")) {
            String converted = str
                    .replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC")
                    .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
            return (T) fromJson(converted, ROC.class);
        }
        if (str.contains("org.deeplearning4j.eval.RegressionEvaluation")) {
            String converted = str
                    .replace("org.deeplearning4j.eval.RegressionEvaluation", "org.nd4j.evaluation.regression.RegressionEvaluation");
            return (T) fromJson(converted, RegressionEvaluation.class);
        }
        throw invalidTypeIdException;
    }

    /**
     * Converts labels/predictions (and an optional mask) into a 2d form suitable for evaluation.
     *
     * <ul>
     *   <li><b>Rank 2:</b> only {@code axis == 1} is accepted.
     *     <ul>
     *       <li>A vector (1d or column) mask removes masked-out rows.</li>
     *       <li>A mask with the same shape as the labels is passed through unchanged.</li>
     *     </ul>
     *   </li>
     *   <li><b>Rank 3:</b>
     *     <ul>
     *       <li>A 2d mask removes masked time steps, via {@code EvaluationUtils.extractNonMaskedTimeSteps}.</li>
     *       <li>Otherwise the mask must match the labels' shape.</li>
     *     </ul>
     *   </li>
     *   <li><b>Rank 4/5:</b> a mask that differs in shape from the labels is broadcast to the labels' shape
     *     before reshaping. This applies to a 1d per-example mask or a same-rank broadcastable mask.</li>
     * </ul>
     * For ranks 3-5, arrays are reshaped to {@code [N, size(axis)]} by moving {@code axis} last.
     *
     * @param labels      labels array, rank 2 to 5
     * @param predictions predictions array, same shape as {@code labels}
     * @param mask        optional mask array (may be {@code null})
     * @param axis        the dimension holding the outputs/classes
     * @return a triple of (labels, predictions, mask-or-null) in 2d form, or {@code null} when every example/step is masked out
     * @throws IllegalStateException         for unsupported ranks or inconsistent mask shapes
     * @throws UnsupportedOperationException for rank 4/5 masks that cannot be broadcast to the labels
     */
    public static Triple<INDArray, INDArray, INDArray> reshapeAndExtractNotMasked(INDArray labels, INDArray predictions, INDArray mask, int axis) {
        if (labels.rank() == 2) {
            return extractNotMasked2d(labels, predictions, mask, axis);
        }
        if (labels.rank() != 3 && labels.rank() != 4 && labels.rank() != 5) {
            throw new IllegalStateException("Unknown array type passed to evaluation: labels array rank " + labels.rank() + " with shape " + labels.shapeInfoToString() + ". Labels and predictions must always be rank 2 or higher, with leading dimension being minibatch dimension");
        }
        if (mask == null) {
            return reshapeSameShapeTo2d(axis, labels, predictions, null);
        }

        if (labels.rank() == 3) {
            if (mask.rank() == 2) {
                // Per-time-step mask: drop masked steps entirely rather than carrying a mask forward.
                Pair<INDArray, INDArray> notMasked = EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, mask);
                if (notMasked == null) {
                    return null;
                }
                return new Triple<>(notMasked.getFirst(), notMasked.getSecond(), null);
            }
            Preconditions.checkState(labels.equalShapes(mask), "If a mask array is present for 3d data, it must either be 2d (shape [minibatch, sequenceLength]) or have shape equal to the labels (for per-output masking, when supported). Got labels shape %ndShape, mask shape %ndShape", labels, mask);
            return reshapeSameShapeTo2d(axis, labels, predictions, mask);
        }

        // Rank 4 or 5 labels from here on.
        if (labels.equalShapes(mask)) {
            return reshapeSameShapeTo2d(axis, labels, predictions, mask);
        }

        INDArray broadcastSource;
        if (mask.rank() == 1) {
            Preconditions.checkState(mask.length() == labels.size(0), "For rank %s labels with shape %ndShape and 1d mask of shape %ndShape, the mask array length must equal labels dimension 0 size", Integer.valueOf(labels.rank()), labels, mask);
            // Reshape the per-example mask to [minibatch, 1, 1, ...] so that it broadcasts over all non-batch dimensions.
            long[] perExampleShape = ArrayUtil.nTimes(labels.rank(), 1L);
            perExampleShape[0] = mask.size(0);
            broadcastSource = mask.reshape(perExampleShape);
        } else {
            if (mask.rank() != labels.rank() || !Shape.areShapesBroadcastable(mask.shape(), labels.shape())) {
                throw new UnsupportedOperationException("Evaluation case not supported: labels shape " + Arrays.toString(labels.shape()) + " with mask shape " + Arrays.toString(mask.shape()));
            }
            broadcastSource = mask;
        }
        INDArray fullMask = Nd4j.createUninitialized(mask.dataType(), labels.shape());
        Nd4j.exec(new BroadcastTo(broadcastSource, labels.shape(), fullMask));
        return reshapeSameShapeTo2d(axis, labels, predictions, fullMask);
    }

    /**
     * Handles the rank-2 case of {@link #reshapeAndExtractNotMasked(INDArray, INDArray, INDArray, int)}.
     *
     * <p>A vector mask (1d or column vector) selects the rows to keep. Any other mask must match the labels'
     * shape and is returned unchanged.
     *
     * @return the triple of (labels, predictions, mask-or-null), or {@code null} if every row is masked out
     */
    private static Triple<INDArray, INDArray, INDArray> extractNotMasked2d(INDArray labels, INDArray predictions, INDArray mask, int axis) {
        Preconditions.checkState(axis == 1, "Only axis=1 is supported 2d data - got axis=%s for labels array shape %ndShape", Integer.valueOf(axis), labels);
        if (mask == null) {
            return new Triple<>(labels, predictions, null);
        }
        if (mask.rank() != 1 && !mask.isColumnVector()) {
            Preconditions.checkState(labels.equalShapes(mask), "If a mask array is present for 2d data, it must either be a vector (column vector) or have shape equal to the labels (for per-output masking, when supported). Got labels shape %ndShape, mask shape %ndShape", labels, mask);
            return new Triple<>(labels, predictions, mask);
        }

        Preconditions.checkState(mask.length() == labels.size(0), "For 2d labels with shape %ndShape, a vector mask must have one entry per example (length equal to labels dimension 0 size); got mask shape %ndShape", labels, mask);

        // The row count and the row indices must both come from the same 0/1 array. Calling toIntVector() on the
        // raw mask would truncate non-integer values such as 0.5 to 0, so the indices would disagree with the count.
        INDArray notMaskedFlags = mask.neq(0.0).castTo(DataType.INT);
        int notMaskedCount = notMaskedFlags.sumNumber().intValue();
        if (notMaskedCount == 0) {
            return null;
        }
        if (notMaskedCount == mask.length()) {
            return new Triple<>(labels, predictions, null);
        }

        int[] flags = notMaskedFlags.toIntVector();
        int[] rowsToKeep = new int[notMaskedCount];
        int next = 0;
        for (int row = 0; row < flags.length; row++) {
            if (flags[row] != 0) {
                rowsToKeep[next++] = row;
            }
        }
        return new Triple<>(Nd4j.pullRows(labels, 1, rowsToKeep, 'c'), Nd4j.pullRows(predictions, 1, rowsToKeep, 'c'), null);
    }

    /**
     * Reshapes same-shaped labels, predictions and (optional) mask to 2d.
     *
     * <p>Dimension {@code axis} is moved to the end, keeping the other dimensions in their original order. All
     * leading dimensions are then flattened in c order into {@code [product of other dims, size(axis)]}.
     *
     * @throws IllegalStateException if {@code axis} is not a valid dimension of {@code labels}
     */
    private static Triple<INDArray, INDArray, INDArray> reshapeSameShapeTo2d(int axis, INDArray labels, INDArray predictions, INDArray mask) {
        int rank = labels.rank();
        Preconditions.checkState(axis >= 0 && axis < rank, "Axis %s is out of range for labels array with shape %ndShape", Integer.valueOf(axis), labels);

        int[] permutation = new int[rank];
        int next = 0;
        for (int dim = 0; dim < rank; dim++) {
            if (dim != axis) {
                permutation[next++] = dim;
            }
        }
        permutation[next] = axis;

        long rows = 1;
        for (int i = 0; i < rank - 1; i++) {
            rows *= labels.size(permutation[i]);
        }
        long columns = labels.size(axis);

        return new Triple<>(
                permuteTo2d(labels, permutation, rows, columns),
                permuteTo2d(predictions, permutation, rows, columns),
                mask == null ? null : permuteTo2d(mask, permutation, rows, columns));
    }

    /** Permutes {@code array}, copies it into a c-order buffer and reshapes it to {@code [rows, columns]}. */
    private static INDArray permuteTo2d(INDArray array, int[] permutation, long rows, long columns) {
        return array.permute(permutation).dup('c').reshape('c', rows, columns);
    }

    /** Evaluates without a mask or record metadata. */
    @Override // org.nd4j.evaluation.IEvaluation
    public void eval(INDArray labels, INDArray predictions) {
        eval(labels, predictions, null, null);
    }

    /**
     * Evaluates with record metadata and no mask.
     *
     * @throws NullPointerException if {@code labels} or {@code predictions} is null
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public void eval(@NonNull INDArray labels, @NonNull INDArray predictions, List<? extends Serializable> recordMetaData) {
        if (labels == null) {
            throw new NullPointerException("labels is marked @NonNull but is null");
        }
        if (predictions == null) {
            throw new NullPointerException("predictions is marked @NonNull but is null");
        }
        eval(labels, predictions, null, recordMetaData);
    }

    /** Evaluates with a mask and no record metadata. */
    @Override // org.nd4j.evaluation.IEvaluation
    public void eval(INDArray labels, INDArray predictions, INDArray mask) {
        eval(labels, predictions, mask, null);
    }

    /** Evaluates time series data without a mask. */
    @Override // org.nd4j.evaluation.IEvaluation
    public void evalTimeSeries(INDArray labels, INDArray predictions) {
        evalTimeSeries(labels, predictions, null);
    }

    /**
     * Evaluates time series data after removing masked time steps.
     *
     * <p>Does nothing when {@code EvaluationUtils.extractNonMaskedTimeSteps} returns {@code null}.
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray mask) {
        Pair<INDArray, INDArray> notMasked = EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, mask);
        if (notMasked == null) {
            return;
        }
        eval(notMasked.getFirst(), notMasked.getSecond());
    }

    /** Serializes this evaluation to JSON using the shared JSON mapper. */
    @Override // org.nd4j.evaluation.IEvaluation
    public String toJson() {
        try {
            return objectMapper.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns {@link #stats()}. */
    @Override
    public String toString() {
        return stats();
    }

    /** Serializes this evaluation to YAML using the shared YAML mapper. */
    @Override // org.nd4j.evaluation.IEvaluation
    public String toYaml() {
        try {
            return yamlMapper.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Equality at the base-class level.
     *
     * <p>This class declares no state, so any two instances compare equal whenever the other side's
     * {@link #canEqual(Object)} accepts this one. Subclasses with state should override
     * {@code equals}, {@code hashCode} and {@code canEqual} together.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        return (obj instanceof BaseEvaluation) && ((BaseEvaluation) obj).canEqual(this);
    }

    /** Symmetry hook for {@link #equals(Object)}: accepts any {@code BaseEvaluation}. */
    protected boolean canEqual(Object obj) {
        return obj instanceof BaseEvaluation;
    }

    /** Constant hash, consistent with the stateless {@link #equals(Object)} above. */
    @Override
    public int hashCode() {
        return 1;
    }

    /** @return the shared JSON mapper */
    public static ObjectMapper getObjectMapper() {
        return objectMapper;
    }

    /** @return the shared YAML mapper */
    public static ObjectMapper getYamlMapper() {
        return yamlMapper;
    }
}
