package org.jpmml.evaluator;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.CategoricalPredictor;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NumericPredictor;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PredictorTerm;
import org.dmg.pmml.RegressionModel;
import org.dmg.pmml.RegressionNormalizationMethodType;
import org.dmg.pmml.RegressionTable;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

/* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.1.13.jar:org/jpmml/evaluator/RegressionModelEvaluator.class */
/**
 * Evaluator for PMML {@link RegressionModel} elements.
 *
 * <p>Two mining functions are supported:
 * <ul>
 *   <li>{@code regression}: the model must contain exactly one {@link RegressionTable}; its
 *       (optionally normalized) value becomes the prediction for the target field.</li>
 *   <li>{@code classification}: one {@link RegressionTable} per target category; the per-table
 *       scores are turned into a probability distribution according to the model's
 *       normalization method. Categorical and ordinal target fields are supported.</li>
 * </ul>
 */
public class RegressionModelEvaluator extends ModelEvaluator<RegressionModel> {

    /**
     * Creates an evaluator for the {@link RegressionModel} that {@code find} selects from the
     * models of the given PMML document.
     *
     * @param pmml the PMML document containing the regression model
     */
    public RegressionModelEvaluator(PMML pmml) {
        this(pmml, (RegressionModel) find(pmml.getModels(), RegressionModel.class));
    }

    /**
     * Creates an evaluator for the given regression model.
     *
     * @param pmml the PMML document the model belongs to
     * @param regressionModel the model to evaluate
     */
    public RegressionModelEvaluator(PMML pmml, RegressionModel regressionModel) {
        super(pmml, regressionModel);
    }

    @Override
    public String getSummary() {
        return "Regression";
    }

    /**
     * Scores the model in the given context and computes its output fields.
     *
     * @param context the evaluation context supplying input field values
     * @return the predicted and output field values
     * @throws InvalidResultException if the model is marked as not scorable
     * @throws UnsupportedFeatureException if the mining function is neither
     *         {@code regression} nor {@code classification}
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        RegressionModel regressionModel = getModel();
        if (!regressionModel.isScorable()) {
            throw new InvalidResultException(regressionModel);
        }

        // Regression yields numeric predictions, classification yields probability maps;
        // the common supertype is Map<FieldName, ?>.
        Map<FieldName, ?> predictions;

        MiningFunctionType functionName = regressionModel.getFunctionName();
        switch (functionName) {
            case REGRESSION:
                predictions = evaluateRegression(context);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(context);
                break;
            default:
                throw new UnsupportedFeatureException(regressionModel, functionName);
        }

        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Evaluates the single regression table of a {@code regression} model.
     *
     * <p>If the table cannot be evaluated because of a missing input, the target's default
     * regression result (as computed by {@code TargetUtil.evaluateRegressionDefault}) is returned.
     *
     * @throws InvalidFeatureException if the model does not contain exactly one regression table
     */
    private Map<FieldName, ? extends Number> evaluateRegression(ModelEvaluationContext context) {
        RegressionModel regressionModel = getModel();

        FieldName targetField = regressionModel.getTargetFieldName();
        if (targetField == null) {
            targetField = getTargetField();
        }

        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        if (regressionTables.size() != 1) {
            throw new InvalidFeatureException(regressionModel);
        }

        Double value = evaluateRegressionTable(regressionTables.get(0), context);
        if (value == null) {
            return TargetUtil.evaluateRegressionDefault(context);
        }

        Map<FieldName, Double> result = Collections.singletonMap(targetField, normalizeRegressionResult(value));

        return TargetUtil.evaluateRegression(result, context);
    }

    /**
     * Evaluates one regression table per target category of a {@code classification} model and
     * converts the resulting scores into class probabilities.
     *
     * <p>If any table cannot be evaluated because of a missing input, the target's default
     * classification result (as computed by {@code TargetUtil.evaluateClassificationDefault})
     * is returned.
     *
     * @throws InvalidFeatureException if the target field is continuous, if there are no
     *         regression tables, if the number of tables disagrees with the declared target
     *         categories, or if a table lacks a target category
     * @throws UnsupportedFeatureException if the target field has an unsupported op type
     */
    private Map<FieldName, ? extends ClassificationMap<?>> evaluateClassification(ModelEvaluationContext context) {
        RegressionModel regressionModel = getModel();

        FieldName targetField = regressionModel.getTargetFieldName();
        if (targetField == null) {
            targetField = getTargetField();
        }

        DataField dataField = getDataField(targetField);

        OpType opType = dataField.getOptype();
        switch (opType) {
            case CATEGORICAL:
            case ORDINAL:
                break;
            case CONTINUOUS:
                throw new InvalidFeatureException(dataField);
            default:
                throw new UnsupportedFeatureException(dataField, opType);
        }

        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        if (regressionTables.isEmpty()) {
            throw new InvalidFeatureException(regressionModel);
        }

        // When the DataField declares its categories, there must be exactly one table per category.
        List<String> targetCategories = ArgumentUtil.getTargetCategories(dataField);
        if (!targetCategories.isEmpty() && targetCategories.size() != regressionTables.size()) {
            throw new InvalidFeatureException(dataField);
        }

        ProbabilityClassificationMap result = new ProbabilityClassificationMap();

        for (RegressionTable regressionTable : regressionTables) {
            String targetCategory = regressionTable.getTargetCategory();
            if (targetCategory == null) {
                throw new InvalidFeatureException(regressionTable);
            }

            Double value = evaluateRegressionTable(regressionTable, context);
            if (value == null) {
                return TargetUtil.evaluateClassificationDefault(context);
            }

            result.put(targetCategory, value);
        }

        if (opType == OpType.CATEGORICAL) {
            computeCategoricalProbabilities(result);
        } else {
            computeOrdinalProbabilities(result, targetCategories);
        }

        Map<FieldName, ProbabilityClassificationMap> predictions = Collections.singletonMap(targetField, result);

        return TargetUtil.evaluateClassification(predictions, context);
    }

    /**
     * Computes the raw (un-normalized) value of a regression table.
     *
     * <p>The value is the intercept plus the contributions of all numeric predictors
     * ({@code coefficient * x^exponent}), categorical predictors ({@code coefficient} when the
     * input equals the predictor's value, otherwise 0) and predictor terms
     * ({@code coefficient * product of the referenced field values}).
     *
     * <p>Missing-value handling, as implemented here:
     * <ul>
     *   <li>a missing numeric predictor or predictor-term input makes the whole table missing
     *       ({@code null} is returned) and a warning is recorded;</li>
     *   <li>a missing categorical predictor input records a warning and the predictor is skipped.</li>
     * </ul>
     *
     * @return the table value, or {@code null} if a required input is missing
     * @throws InvalidFeatureException if a predictor term references no fields
     */
    private Double evaluateRegressionTable(RegressionTable regressionTable, EvaluationContext context) {
        double result = 0d;

        result += regressionTable.getIntercept();

        for (NumericPredictor numericPredictor : regressionTable.getNumericPredictors()) {
            FieldName name = numericPredictor.getName();

            FieldValue value = ExpressionUtil.evaluate(name, context);
            if (value == null) {
                context.addWarning("Missing argument \"" + name.getValue() + "\"");

                return null;
            }

            result += numericPredictor.getCoefficient() * Math.pow(value.asNumber().doubleValue(), numericPredictor.getExponent());
        }

        for (CategoricalPredictor categoricalPredictor : regressionTable.getCategoricalPredictors()) {
            FieldName name = categoricalPredictor.getName();

            FieldValue value = ExpressionUtil.evaluate(name, context);
            if (value == null) {
                context.addWarning("Missing argument \"" + name.getValue() + "\"");

                // The categorical contribution is skipped rather than invalidating the table
                continue;
            }

            boolean matches = value.equalsString(categoricalPredictor.getValue());

            result += categoricalPredictor.getCoefficient() * (matches ? 1d : 0d);
        }

        for (PredictorTerm predictorTerm : regressionTable.getPredictorTerms()) {
            List<FieldRef> fieldRefs = predictorTerm.getFieldRefs();
            if (fieldRefs.isEmpty()) {
                throw new InvalidFeatureException(predictorTerm);
            }

            double product = predictorTerm.getCoefficient();

            for (FieldRef fieldRef : fieldRefs) {
                FieldValue value = ExpressionUtil.evaluateFieldRef(fieldRef, context);
                if (value == null) {
                    // Report the missing input consistently with the numeric/categorical predictors
                    context.addWarning("Missing argument \"" + fieldRef.getField().getValue() + "\"");

                    return null;
                }

                product *= value.asNumber().doubleValue();
            }

            result += product;
        }

        return result;
    }

    /**
     * Applies the model's normalization method to a {@code regression} result.
     *
     * <p>{@code softmax} and {@code logit} both map {@code y} to {@code 1 / (1 + exp(-y))};
     * {@code exp} maps it to {@code exp(y)}; {@code none} leaves it unchanged.
     *
     * @throws UnsupportedFeatureException for any other normalization method
     */
    private Double normalizeRegressionResult(Double value) {
        RegressionModel regressionModel = getModel();

        RegressionNormalizationMethodType normalizationMethod = regressionModel.getNormalizationMethod();
        switch (normalizationMethod) {
            case NONE:
                return value;
            case SOFTMAX:
            case LOGIT:
                return 1d / (1d + Math.exp(-value.doubleValue()));
            case EXP:
                return Math.exp(value.doubleValue());
            default:
                throw new UnsupportedFeatureException(regressionModel, normalizationMethod);
        }
    }

    /**
     * Converts per-category scores of a categorical target into probabilities, in place.
     *
     * <p>{@code none} keeps the raw scores, {@code simplemax} and {@code softmax} delegate to
     * {@link ClassificationMap}, and every other method transforms each score independently via
     * {@link #normalizeClassificationResult(Double)} and then rescales the values with
     * {@code normalizeValues()}.
     */
    private void computeCategoricalProbabilities(ClassificationMap<String> values) {
        RegressionNormalizationMethodType normalizationMethod = getModel().getNormalizationMethod();
        switch (normalizationMethod) {
            case NONE:
                break;
            case SIMPLEMAX:
                ClassificationMap.normalize(values);
                break;
            case SOFTMAX:
                ClassificationMap.normalizeSoftMax(values);
                break;
            default:
                // Unsupported methods (e.g. EXP) are rejected by normalizeClassificationResult
                for (Map.Entry<String, Double> entry : values.entrySet()) {
                    entry.setValue(normalizeClassificationResult(entry.getValue()));
                }

                values.normalizeValues();
                break;
        }
    }

    /**
     * Converts per-category scores of an ordinal target into probabilities, in place.
     *
     * <p>Each score is transformed into a cumulative probability via
     * {@link #normalizeClassificationResult(Double)} and then differenced in category order by
     * {@link #calculateCategoryProbabilities(Map, List)}.
     *
     * @param targetCategories the target categories in their declared order
     * @throws UnsupportedFeatureException for {@code softmax} and {@code simplemax}, or for a
     *         method not handled by {@link #normalizeClassificationResult(Double)}
     * @throws InvalidFeatureException if the category order is unknown (no categories declared)
     *         while a cumulative normalization is requested
     */
    private void computeOrdinalProbabilities(ClassificationMap<String> values, List<String> targetCategories) {
        RegressionModel regressionModel = getModel();

        RegressionNormalizationMethodType normalizationMethod = regressionModel.getNormalizationMethod();
        switch (normalizationMethod) {
            case NONE:
                break;
            case SOFTMAX:
            case SIMPLEMAX:
                throw new UnsupportedFeatureException(regressionModel, normalizationMethod);
            default:
                for (Map.Entry<String, Double> entry : values.entrySet()) {
                    entry.setValue(normalizeClassificationResult(entry.getValue()));
                }

                // Without a declared category order the cumulative values cannot be differenced;
                // returning them as-is would silently report cumulative values as probabilities.
                if (targetCategories.isEmpty()) {
                    throw new InvalidFeatureException(regressionModel);
                }

                calculateCategoryProbabilities(values, targetCategories);
                break;
        }
    }

    /**
     * Applies a link-function inverse to a single classification score.
     *
     * <ul>
     *   <li>{@code logit}: {@code 1 / (1 + exp(-y))}</li>
     *   <li>{@code probit}: standard normal CDF of {@code y}</li>
     *   <li>{@code cloglog}: {@code 1 - exp(-exp(y))}</li>
     *   <li>{@code loglog}: {@code exp(-exp(-y))}</li>
     *   <li>{@code cauchit}: {@code 0.5 + atan(y) / pi}</li>
     * </ul>
     *
     * @throws UnsupportedFeatureException for any other normalization method
     */
    private Double normalizeClassificationResult(Double value) {
        RegressionModel regressionModel = getModel();

        RegressionNormalizationMethodType normalizationMethod = regressionModel.getNormalizationMethod();
        switch (normalizationMethod) {
            case LOGIT:
                return 1d / (1d + Math.exp(-value.doubleValue()));
            case PROBIT:
                return NormalDistributionUtil.cumulativeProbability(value.doubleValue());
            case CLOGLOG:
                return 1d - Math.exp(-Math.exp(value.doubleValue()));
            case LOGLOG:
                return Math.exp(-Math.exp(-value.doubleValue()));
            case CAUCHIT:
                return 0.5d + (1d / Math.PI) * Math.atan(value.doubleValue());
            default:
                throw new UnsupportedFeatureException(regressionModel, normalizationMethod);
        }
    }

    /**
     * Converts cumulative probabilities into per-category probabilities, in place.
     *
     * <p>For every category except the last, its value is replaced by
     * {@code P(<= category) - P(<= previous category)}. The last category receives the remaining
     * mass {@code 1 - P(<= second-to-last category)}; with a single category that is {@code 1}.
     * An empty category list leaves {@code values} unchanged.
     *
     * @param values cumulative probabilities keyed by category; updated in place
     * @param categories the categories in ascending order
     * @throws EvaluationException if a non-last category has no value, a cumulative probability
     *         exceeds 1, or the cumulative probabilities decrease
     */
    public static void calculateCategoryProbabilities(Map<String, Double> values, List<String> categories) {
        double offset = 0d;

        int last = categories.size() - 1;

        for (int i = 0; i < last; i++) {
            String category = categories.get(i);

            Double cumulativeProbability = values.get(category);
            if (cumulativeProbability == null || cumulativeProbability.doubleValue() > 1d) {
                throw new EvaluationException();
            }

            double probability = cumulativeProbability.doubleValue() - offset;
            if (probability < 0d) {
                throw new EvaluationException();
            }

            values.put(category, probability);

            offset = cumulativeProbability.doubleValue();
        }

        // The last category's cumulative probability is 1 by definition
        if (!categories.isEmpty()) {
            values.put(categories.get(last), 1d - offset);
        }
    }
}
