package org.deeplearning4j.nn.simple.multiclass;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/simple/multiclass/RankClassificationResult.class */
/**
 * Holds a ranked view of a classifier output array: for every row, the column
 * indices produced by {@code Nd4j.sortWithIndices}, the values read from the
 * output array, the label names, and the top-ranked label of each row.
 *
 * <p>The class is mutable and its accessors expose the internal arrays and
 * lists directly (no defensive copies). The setters only assign the given
 * reference; in particular they do not recompute the cached
 * {@link #getMaxLabels() max labels}.
 */
public class RankClassificationResult implements Serializable {
    /** Per row, the column indices taken from element 0 of {@code Nd4j.sortWithIndices}. */
    private int[][] rankedIndices;
    /** Per row, the values read from the constructor's input array. */
    private float[][] probabilities;
    /** Label names, looked up by column index. */
    private List<String> labels;
    /** Cached label of the first-ranked index of every row; built lazily by {@link #maxOutcomes()}. */
    private List<String> maxLabels;

    /**
     * Builds a result whose labels are the column indices rendered as strings
     * ({@code "0"}, {@code "1"}, ...).
     *
     * @param iNDArray the classifier output, of rank 2 or lower
     * @throws ND4JIllegalStateException if the array has a rank greater than 2
     */
    public RankClassificationResult(INDArray iNDArray) {
        this(iNDArray, null);
    }

    /**
     * Builds a result from a classifier output array and optional label names.
     *
     * @param iNDArray the classifier output, of rank 2 or lower
     * @param list     label names indexed by column, copied into a new list; when
     *                 {@code null}, the column indices are used as labels
     * @throws ND4JIllegalStateException if the array has a rank greater than 2
     * @throws IndexOutOfBoundsException if a row's first-ranked index has no
     *                                   entry in the label list
     */
    public RankClassificationResult(INDArray iNDArray, List<String> list) {
        if (iNDArray.rank() > 2) {
            throw new ND4JIllegalStateException("Only works with vectors and matrices right now");
        }
        // Element 0 of the returned pair is used as the index array. The
        // arguments (-1, false) presumably select the last dimension in
        // descending order -- confirm against the ND4J Javadoc.
        INDArray sortedIndices = Nd4j.sortWithIndices(iNDArray, -1, false)[0];

        if (list == null) {
            // No names supplied: fall back to "0".."columns - 1".
            int labelCount = iNDArray.columns();
            List<String> defaultLabels = new ArrayList<String>(labelCount);
            for (int column = 0; column < labelCount; column++) {
                defaultLabels.add(String.valueOf(column));
            }
            this.labels = defaultLabels;
        } else {
            // Copy so that later changes to the caller's list are not visible here.
            this.labels = new ArrayList<String>(list);
        }

        int rowCount = sortedIndices.rows();
        int columnCount = sortedIndices.columns();
        this.rankedIndices = new int[rowCount][columnCount];
        this.probabilities = new float[iNDArray.rows()][iNDArray.columns()];
        // NOTE(review): the values are read after the sort call. If
        // Nd4j.sortWithIndices reorders its argument in place, they are in
        // ranked order rather than original column order -- confirm for the
        // ND4J version in use.
        for (int row = 0; row < rowCount; row++) {
            for (int column = 0; column < columnCount; column++) {
                this.rankedIndices[row][column] = sortedIndices.getInt(row, column);
                this.probabilities[row][column] = iNDArray.getFloat(new int[]{row, column});
            }
        }
        // Populates the maxLabels cache. This is an overridable method invoked
        // from a constructor; the call is kept so existing subclasses behave
        // exactly as before.
        maxOutcomes();
    }

    /**
     * Returns the label of the first-ranked index of the given row.
     *
     * @param i the row to look up
     * @return the label stored at position {@code rankedIndices[i][0]}
     * @throws IndexOutOfBoundsException if the row does not exist, has no
     *                                   columns, or its first index has no label
     */
    public String maxOutcomeForRow(int i) {
        int topIndex = this.rankedIndices[i][0];
        return this.labels.get(topIndex);
    }

    /**
     * Returns the first-ranked label of every row, computing and caching the
     * list on first use. The cached list is returned as is on later calls, even
     * if the labels or ranked indices have been replaced through the setters.
     *
     * @return the cached list of per-row top labels (not a copy)
     */
    public List<String> maxOutcomes() {
        if (this.maxLabels == null) {
            int rowCount = this.rankedIndices.length;
            this.maxLabels = new ArrayList<String>(rowCount);
            for (int row = 0; row < rowCount; row++) {
                this.maxLabels.add(maxOutcomeForRow(row));
            }
        }
        return this.maxLabels;
    }

    /** @return the internal ranked-index array (not a copy) */
    public int[][] getRankedIndices() {
        return this.rankedIndices;
    }

    /** @return the internal value array (not a copy) */
    public float[][] getProbabilities() {
        return this.probabilities;
    }

    /** @return the internal label list (not a copy) */
    public List<String> getLabels() {
        return this.labels;
    }

    /**
     * @return the cached per-row top labels without computing them; {@code null}
     *         if the cache was cleared via {@link #setMaxLabels(List)}
     */
    public List<String> getMaxLabels() {
        return this.maxLabels;
    }

    /**
     * Replaces the ranked indices. The cached max labels are left untouched.
     *
     * @param iArr the new ranked-index array, stored by reference
     */
    public void setRankedIndices(int[][] iArr) {
        this.rankedIndices = iArr;
    }

    /**
     * Replaces the stored values.
     *
     * @param fArr the new value array, stored by reference
     */
    public void setProbabilities(float[][] fArr) {
        this.probabilities = fArr;
    }

    /**
     * Replaces the labels. The cached max labels are left untouched.
     *
     * @param list the new label list, stored by reference
     */
    public void setLabels(List<String> list) {
        this.labels = list;
    }

    /**
     * Replaces the cached max labels. Passing {@code null} makes the next call
     * to {@link #maxOutcomes()} rebuild the list.
     *
     * @param list the new list of per-row top labels, stored by reference
     */
    public void setMaxLabels(List<String> list) {
        this.maxLabels = list;
    }

    /**
     * Compares all four properties: the two arrays with
     * {@link Arrays#deepEquals(Object[], Object[])} and the two lists with
     * {@link List#equals(Object)}.
     *
     * @param obj the object to compare with
     * @return {@code true} if {@code obj} is a compatible instance with equal properties
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof RankClassificationResult)) {
            return false;
        }
        RankClassificationResult other = (RankClassificationResult) obj;
        // Gives the other side a chance to refuse the comparison, which keeps
        // equals symmetric when a subclass redefines equality.
        if (!other.canEqual(this)) {
            return false;
        }
        if (!Arrays.deepEquals(getRankedIndices(), other.getRankedIndices())) {
            return false;
        }
        if (!Arrays.deepEquals(getProbabilities(), other.getProbabilities())) {
            return false;
        }
        List<String> ownLabels = getLabels();
        List<String> otherLabels = other.getLabels();
        if (ownLabels == null) {
            if (otherLabels != null) {
                return false;
            }
        } else if (!ownLabels.equals(otherLabels)) {
            return false;
        }
        List<String> ownMaxLabels = getMaxLabels();
        List<String> otherMaxLabels = other.getMaxLabels();
        if (ownMaxLabels == null) {
            return otherMaxLabels == null;
        }
        return ownMaxLabels.equals(otherMaxLabels);
    }

    /**
     * Hook used by {@link #equals(Object)} to check whether the given object
     * may be compared with this one.
     *
     * @param obj the object that wants to compare itself with this instance
     * @return {@code true} if {@code obj} is a {@code RankClassificationResult}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof RankClassificationResult;
    }

    /**
     * Combines the hashes of the same four properties that
     * {@link #equals(Object)} compares, using 59 as multiplier and 43 for a
     * {@code null} list.
     *
     * @return the hash code of this result
     */
    @Override
    public int hashCode() {
        int result = 59 + Arrays.deepHashCode(getRankedIndices());
        result = (result * 59) + Arrays.deepHashCode(getProbabilities());
        List<String> currentLabels = getLabels();
        result = (result * 59) + (currentLabels == null ? 43 : currentLabels.hashCode());
        List<String> currentMaxLabels = getMaxLabels();
        return (result * 59) + (currentMaxLabels == null ? 43 : currentMaxLabels.hashCode());
    }

    /**
     * @return a string listing all four properties, with the arrays rendered by
     *         {@link Arrays#deepToString(Object[])}
     */
    @Override
    public String toString() {
        return "RankClassificationResult(rankedIndices=" + Arrays.deepToString(getRankedIndices())
                + ", probabilities=" + Arrays.deepToString(getProbabilities())
                + ", labels=" + getLabels()
                + ", maxLabels=" + getMaxLabels() + ")";
    }
}
