package com.datarobot.mlops.common.metrics.predictionStats;

import com.datarobot.mlops.common.exceptions.DRCommonException;
import com.datarobot.mlops.common.metrics.PredictionsType;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
import com.google.gson.JsonArray;
import com.google.gson.JsonDeserializationContext;
import com.google.gson.JsonDeserializer;
import com.google.gson.JsonElement;
import com.google.gson.JsonParseException;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import com.ibm.icu.text.PluralRules;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/datarobot/mlops/common/metrics/predictionStats/PredictionStatistics.class */
/**
 * Prediction statistics for a single reporting unit.
 *
 * <p>The payload held in {@code predictions} depends on {@code predictionsType}:
 * <ul>
 *   <li>{@link PredictionsType#REGRESSION} - a single {@link NumericStats};</li>
 *   <li>any other type (binary / multiclass) - a {@link List} of {@link NumericStats},
 *       whose size is treated as the class count by {@link #isValid()}.</li>
 * </ul>
 *
 * <p>{@code predictionsType} is {@code transient}; the custom (de)serializers below write only
 * the payload and re-derive the type from the JSON shape when reading.
 */
public class PredictionStatistics implements StatisticsValidator {
    private static final Logger logger = LoggerFactory.getLogger(PredictionStatistics.class);

    /** A class list of exactly this size is deserialized as {@link PredictionsType#BINARY}. */
    private static final int BINARY_CLASS_COUNT = 2;
    /** Smallest class count accepted by {@link #isValid()}. */
    private static final int MIN_CLASS_COUNT = 2;
    /** Largest class count accepted by {@link #isValid()}. */
    private static final int MAX_CLASS_COUNT = 10;

    private Object predictions;
    private transient PredictionsType predictionsType;

    /**
     * Gson deserializer: a JSON object becomes a regression payload, a JSON array becomes a
     * classification payload.
     *
     * <p>The raw {@code JsonDeserializer} type and the {@code Object} return type are kept so the
     * method signature stays exactly as existing callers see it.
     */
    public static class PredictionStatisticsDeserializer implements JsonDeserializer {
        /**
         * Builds a {@link PredictionStatistics} from {@code jsonElement}.
         *
         * @param jsonElement JSON object (regression) or JSON array of objects (classification)
         * @param type unused
         * @param jsonDeserializationContext used to deserialize each {@link NumericStats}
         * @return the populated {@link PredictionStatistics}
         * @throws JsonParseException if a nested {@link NumericStats} cannot be parsed
         */
        @Override // com.google.gson.JsonDeserializer
        public Object deserialize(JsonElement jsonElement, Type type, JsonDeserializationContext jsonDeserializationContext) throws JsonParseException {
            PredictionStatistics result = new PredictionStatistics();
            if (jsonElement.isJsonObject()) {
                NumericStats regressionStats = jsonDeserializationContext.deserialize(jsonElement.getAsJsonObject(), NumericStats.class);
                result.setPredictionsType(PredictionsType.REGRESSION);
                result.setPredictions(regressionStats);
                return result;
            }

            JsonArray classArray = jsonElement.getAsJsonArray();
            // NOTE: only an array of exactly two entries is BINARY here, whereas the List
            // constructor of PredictionStatistics treats any size <= 2 as BINARY.
            result.setPredictionsType(classArray.size() == BINARY_CLASS_COUNT ? PredictionsType.BINARY : PredictionsType.MULTICLASS);

            List<NumericStats> perClassStats = new ArrayList<>(classArray.size());
            for (JsonElement classElement : classArray) {
                NumericStats classStats = jsonDeserializationContext.deserialize(classElement.getAsJsonObject(), NumericStats.class);
                perClassStats.add(classStats);
            }
            result.setPredictions(perClassStats);
            return result;
        }
    }

    /**
     * Jackson serializer: writes the regression payload as a single value, or the classification
     * payload as a JSON array of {@link NumericStats}.
     */
    public static class PredictionStatisticsJacksonSerializer extends StdSerializer<PredictionStatistics> {
        public PredictionStatisticsJacksonSerializer(Class<PredictionStatistics> cls) {
            super(cls);
        }

        public PredictionStatisticsJacksonSerializer() {
            this(null);
        }

        /**
         * Writes {@code predictionStatistics} to {@code jsonGenerator}.
         *
         * <p>For non-regression types, list entries that are not {@link NumericStats} are logged
         * and skipped. If the payload is not a list at all, the error is logged and JSON
         * {@code null} is written so that exactly one value is always emitted.
         *
         * @throws IOException if the generator fails to write
         */
        @Override // com.fasterxml.jackson.databind.ser.std.StdSerializer, com.fasterxml.jackson.databind.JsonSerializer
        public void serialize(PredictionStatistics predictionStatistics, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException {
            Object payload = predictionStatistics.predictions;
            if (predictionStatistics.predictionsType == PredictionsType.REGRESSION) {
                jsonGenerator.writeObject(payload);
                return;
            }
            if (!(payload instanceof List)) {
                logger.error("Predictions should be list of Numeric Stats, found: {}", describeType(payload));
                // Emit an explicit null instead of no token at all: returning without writing
                // leaves the generator expecting a value. This mirrors the Gson serializer,
                // which returns null in the same situation.
                jsonGenerator.writeNull();
                return;
            }
            jsonGenerator.writeStartArray();
            for (Object entry : (List<?>) payload) {
                if (entry instanceof NumericStats) {
                    jsonGenerator.writeObject(entry);
                } else {
                    logger.warn("Found non numeric prediction: {} Ignoring it", entry);
                }
            }
            jsonGenerator.writeEndArray();
        }
    }

    /**
     * Gson serializer counterpart of {@link PredictionStatisticsJacksonSerializer}.
     *
     * <p>The raw {@code JsonSerializer} type and the {@code Object} parameter are kept so the
     * method signature stays exactly as existing callers see it.
     */
    public static class PredictionStatisticsSerializer implements JsonSerializer {
        /**
         * Converts {@code obj} to a JSON tree.
         *
         * @param obj expected to be a {@link PredictionStatistics}
         * @param type unused
         * @param jsonSerializationContext used to serialize each {@link NumericStats}
         * @return the serialized regression payload, a {@link JsonArray} of per-class stats, or
         *     {@code null} when {@code obj} is not a {@link PredictionStatistics} or a
         *     non-regression payload is not a list
         */
        @Override // com.google.gson.JsonSerializer
        public JsonElement serialize(Object obj, Type type, JsonSerializationContext jsonSerializationContext) {
            if (!(obj instanceof PredictionStatistics)) {
                logger.warn("Received instance of type {} to serialize instead of PredictionStatistics, ignoring it", describeType(obj));
                return null;
            }
            PredictionStatistics stats = (PredictionStatistics) obj;
            // Compared with == rather than switch so that an unset (null) type takes the list
            // branch, as in the Jackson serializer, instead of throwing NullPointerException.
            if (stats.predictionsType == PredictionsType.REGRESSION) {
                return jsonSerializationContext.serialize((NumericStats) stats.predictions);
            }
            if (!(stats.predictions instanceof List)) {
                logger.error("Predictions should be list of Numeric Stats, found: {}", describeType(stats.predictions));
                return null;
            }
            JsonArray serialized = new JsonArray();
            for (Object entry : (List<?>) stats.predictions) {
                if (entry instanceof NumericStats) {
                    serialized.add(jsonSerializationContext.serialize(entry));
                } else {
                    logger.warn("Found non numeric prediction: {} Ignoring it", entry);
                }
            }
            return serialized;
        }
    }

    /** Creates an empty instance; both the payload and the type are left unset. */
    public PredictionStatistics() {
    }

    /**
     * Creates classification statistics.
     *
     * @param list per-class statistics; more than two entries means
     *     {@link PredictionsType#MULTICLASS}, otherwise {@link PredictionsType#BINARY}
     * @throws NullPointerException if {@code list} is {@code null}
     */
    public PredictionStatistics(List<NumericStats> list) {
        Objects.requireNonNull(list, "list");
        this.predictions = list;
        this.predictionsType = list.size() > BINARY_CLASS_COUNT ? PredictionsType.MULTICLASS : PredictionsType.BINARY;
    }

    /**
     * Creates regression statistics.
     *
     * @param numericStats the statistics of the regression predictions
     */
    public PredictionStatistics(NumericStats numericStats) {
        this.predictions = numericStats;
        this.predictionsType = PredictionsType.REGRESSION;
    }

    public void setPredictionsType(PredictionsType predictionsType) {
        this.predictionsType = predictionsType;
    }

    public void setPredictions(Object obj) {
        this.predictions = obj;
    }

    public Object getPredictions() {
        return this.predictions;
    }

    public PredictionsType getPredictionsType() {
        return this.predictionsType;
    }

    /** Two instances are equal when both the payload and the prediction type are equal. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        PredictionStatistics other = (PredictionStatistics) obj;
        return this.predictionsType == other.predictionsType && Objects.equals(this.predictions, other.predictions);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.predictions, this.predictionsType);
    }

    /**
     * Validates the payload against the prediction type.
     *
     * @return {@code null} when valid, otherwise a description of the first problem found
     */
    @Override // com.datarobot.mlops.common.metrics.predictionStats.StatisticsValidator
    public String isValid() {
        if (this.predictionsType == null) {
            // Report instead of failing with NullPointerException on an unset type.
            return "Prediction type is not set";
        }
        if (this.predictionsType == PredictionsType.REGRESSION) {
            return validateRegression();
        }
        return validateClassification();
    }

    /** Checks that the payload is a single valid {@link NumericStats}. */
    private String validateRegression() {
        if (!(this.predictions instanceof NumericStats)) {
            return "Regression prediction is of type '" + describeType(this.predictions) + "' expected numeric stats";
        }
        String problem = ((NumericStats) this.predictions).isValid();
        return problem == null ? null : "Regression Prediction: " + problem;
    }

    /** Checks that the payload is a list of valid {@link NumericStats} with a supported class count. */
    private String validateClassification() {
        if (!(this.predictions instanceof List)) {
            return "Classification prediction is of type '" + describeType(this.predictions) + "' expected list of numeric stats";
        }
        List<?> classStats = (List<?>) this.predictions;
        int classCount = classStats.size();
        if (classCount < MIN_CLASS_COUNT || classCount > MAX_CLASS_COUNT) {
            return "Invalid class count in the input " + classCount + ", API supports  Min: " + MIN_CLASS_COUNT + " Max: " + MAX_CLASS_COUNT + " classes";
        }
        for (int i = 0; i < classCount; i++) {
            Object entry = classStats.get(i);
            if (!(entry instanceof NumericStats)) {
                return "Prediction at index " + i + " is of type '" + describeType(entry) + "', expected numeric stats";
            }
            String problem = ((NumericStats) entry).isValid();
            if (problem != null) {
                // Plain ": " literal; the decompiled source borrowed an unrelated ICU constant
                // with the same value here.
                return "Classification Prediction at index " + i + ": " + problem;
            }
        }
        return null;
    }

    /**
     * Returns the prediction count: the count of the single {@link NumericStats} for regression,
     * or the count of the first list entry for binary / multiclass.
     *
     * @return the count reported by the relevant {@link NumericStats}
     * @throws DRCommonException if the prediction type is unsupported or unset, or if the payload
     *     does not have the shape required by the prediction type
     */
    public Long getNumPredictions() throws DRCommonException {
        if (this.predictionsType == PredictionsType.REGRESSION) {
            if (!(this.predictions instanceof NumericStats)) {
                throw new DRCommonException("Regression prediction is of type '" + describeType(this.predictions) + "' expected numeric stats");
            }
            return ((NumericStats) this.predictions).getCount();
        }
        if (this.predictionsType == PredictionsType.MULTICLASS || this.predictionsType == PredictionsType.BINARY) {
            if (!(this.predictions instanceof List)) {
                throw new DRCommonException("Classification prediction is of type '" + describeType(this.predictions) + "' expected list of numeric stats");
            }
            List<?> classStats = (List<?>) this.predictions;
            if (classStats.isEmpty()) {
                throw new DRCommonException("Classification predictions are empty, cannot determine prediction count");
            }
            Object first = classStats.get(0);
            if (!(first instanceof NumericStats)) {
                throw new DRCommonException("Prediction at index 0 is of type '" + describeType(first) + "', expected numeric stats");
            }
            return ((NumericStats) first).getCount();
        }
        throw new DRCommonException("Prediction type " + this.predictionsType + " not supported");
    }

    /**
     * Null-safe rendering of a value's runtime class for diagnostics. For a non-null value the
     * result is the same text that concatenating {@code value.getClass()} produces.
     */
    private static String describeType(Object value) {
        return value == null ? "null" : String.valueOf(value.getClass());
    }
}
