package org.jpmml.evaluator.general_regression;

import com.google.common.base.Function;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.HasValue;
import org.dmg.pmml.Matrix;
import org.dmg.pmml.PMML;
import org.dmg.pmml.general_regression.BaseCumHazardTables;
import org.dmg.pmml.general_regression.BaselineCell;
import org.dmg.pmml.general_regression.BaselineStratum;
import org.dmg.pmml.general_regression.Categories;
import org.dmg.pmml.general_regression.Category;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.dmg.pmml.general_regression.PCell;
import org.dmg.pmml.general_regression.PMMLAttributes;
import org.dmg.pmml.general_regression.PMMLElements;
import org.dmg.pmml.general_regression.PPCell;
import org.dmg.pmml.general_regression.Parameter;
import org.dmg.pmml.general_regression.ParameterCell;
import org.dmg.pmml.general_regression.ParameterList;
import org.dmg.pmml.general_regression.Predictor;
import org.dmg.pmml.general_regression.PredictorList;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.InvalidElementException;
import org.jpmml.evaluator.MapHolder;
import org.jpmml.evaluator.MatrixUtil;
import org.jpmml.evaluator.MisplacedAttributeException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.MissingValueException;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.NumberUtil;
import org.jpmml.evaluator.Numbers;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.TypeInfo;
import org.jpmml.evaluator.TypeUtil;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.UnsupportedElementException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.ValueUtil;

/* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.5.1.jar:org/jpmml/evaluator/general_regression/GeneralRegressionModelEvaluator.class */
public class GeneralRegressionModelEvaluator extends ModelEvaluator<GeneralRegressionModel> {
    // Per-instance lookup structures. They are reset to null by the two-argument constructor;
    // NOTE(review): presumably filled lazily from the static caches below by accessors
    // that are not visible in this part of the file - confirm.
    private transient BiMap<String, Parameter> parameterRegistry;
    private transient Map<Object, Map<String, Row>> ppMatrixMap;
    private transient Map<Object, List<PCell>> paramMatrixMap;
    private transient List<Object> targetCategories;
    // Model -> immutable copy of the registry parsed from the model's ParameterList.
    private static final LoadingCache<GeneralRegressionModel, BiMap<String, Parameter>> parameterCache = CacheUtil.buildLoadingCache(new CacheLoader<GeneralRegressionModel, BiMap<String, Parameter>>() { // from class: org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.3
        @Override // com.google.common.cache.CacheLoader
        public BiMap<String, Parameter> load(GeneralRegressionModel generalRegressionModel) {
            return ImmutableBiMap.copyOf((Map) GeneralRegressionModelEvaluator.parseParameterRegistry(generalRegressionModel.getParameterList()));
        }
    });
    // Model -> immutable copy of the predictor registry parsed from the model's factor list.
    private static final LoadingCache<GeneralRegressionModel, BiMap<FieldName, Predictor>> factorCache = CacheUtil.buildLoadingCache(new CacheLoader<GeneralRegressionModel, BiMap<FieldName, Predictor>>() { // from class: org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.4
        @Override // com.google.common.cache.CacheLoader
        public BiMap<FieldName, Predictor> load(GeneralRegressionModel generalRegressionModel) {
            return ImmutableBiMap.copyOf((Map) GeneralRegressionModelEvaluator.parsePredictorRegistry(generalRegressionModel.getFactorList()));
        }
    });
    // Model -> immutable copy of the predictor registry parsed from the model's covariate list.
    private static final LoadingCache<GeneralRegressionModel, BiMap<FieldName, Predictor>> covariateCache = CacheUtil.buildLoadingCache(new CacheLoader<GeneralRegressionModel, BiMap<FieldName, Predictor>>() { // from class: org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.5
        @Override // com.google.common.cache.CacheLoader
        public BiMap<FieldName, Predictor> load(GeneralRegressionModel generalRegressionModel) {
            return ImmutableBiMap.copyOf((Map) GeneralRegressionModelEvaluator.parsePredictorRegistry(generalRegressionModel.getCovariateList()));
        }
    });
    // Model -> unmodifiable view of the parsed PPMatrix: outer key -> (parameter name -> Row).
    // The evaluation code in this class looks the outer map up by target category and by null.
    private static final LoadingCache<GeneralRegressionModel, Map<Object, Map<String, Row>>> ppMatrixCache = CacheUtil.buildLoadingCache(new CacheLoader<GeneralRegressionModel, Map<Object, Map<String, Row>>>() { // from class: org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.6
        @Override // com.google.common.cache.CacheLoader
        public Map<Object, Map<String, Row>> load(GeneralRegressionModel generalRegressionModel) {
            return Collections.unmodifiableMap(GeneralRegressionModelEvaluator.parsePPMatrix(generalRegressionModel));
        }
    });
    // Model -> unmodifiable view of the parsed ParamMatrix: outer key -> list of PCell elements.
    // The evaluation code in this class looks the map up by target category and by null.
    private static final LoadingCache<GeneralRegressionModel, Map<Object, List<PCell>>> paramMatrixCache = CacheUtil.buildLoadingCache(new CacheLoader<GeneralRegressionModel, Map<Object, List<PCell>>>() { // from class: org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.7
        @Override // com.google.common.cache.CacheLoader
        public Map<Object, List<PCell>> load(GeneralRegressionModel generalRegressionModel) {
            return Collections.unmodifiableMap(GeneralRegressionModelEvaluator.parseParamMatrix(generalRegressionModel));
        }
    });

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.5.1.jar:org/jpmml/evaluator/general_regression/GeneralRegressionModelEvaluator$Row.class */
    /**
     * One parameter's worth of PPMatrix cells, split into factor handlers and covariate handlers.
     *
     * <p>{@link #evaluate} starts from the value one and lets every handler multiply its
     * contribution into that value, so the result is the product of all cell contributions.</p>
     */
    public static class Row {
        private List<FactorHandler> factorHandlers = new ArrayList<>();
        private List<CovariateHandler> covariateHandlers = new ArrayList<>();

        private Row() {
        }

        /**
         * Computes the product of all factor and covariate contributions of this row.
         *
         * @param valueFactory factory used to create the initial value (one)
         * @param evaluationContext context used to resolve the field of every handler
         * @return the product; {@code null} as soon as a referenced field value is missing.
         *         When the product is already zero after the factor handlers, the covariate
         *         handlers are not consulted at all.
         */
        public <V extends Number> Value<V> evaluate(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
            Value<V> product = valueFactory.newValue(Numbers.DOUBLE_ONE);
            if (!multiplyAll(getFactorHandlers(), product, evaluationContext)) {
                return null;
            }
            if (product.isZero()) {
                return product;
            }
            if (!multiplyAll(getCovariateHandlers(), product, evaluationContext)) {
                return null;
            }
            return product;
        }

        /**
         * Lets every handler, in list order, update {@code product} with its field value.
         *
         * @return {@code false} when a field value is missing (remaining handlers are skipped),
         *         {@code true} otherwise
         */
        private <V extends Number> boolean multiplyAll(List<? extends PredictorHandler> handlers, Value<V> product, EvaluationContext evaluationContext) {
            for (PredictorHandler handler : handlers) {
                FieldValue fieldValue = evaluationContext.evaluate(handler.getField());
                if (FieldValueUtil.isMissing(fieldValue)) {
                    return false;
                }
                // The returned value is deliberately ignored; the handler is expected to
                // update `product` in place.
                handler.updateProduct(product, fieldValue);
            }
            return true;
        }

        /**
         * Registers a factor cell. A predictor that carries a contrast matrix gets a
         * {@link ContrastMatrixHandler}; any other predictor gets a plain {@link FactorHandler}.
         *
         * @throws UnsupportedElementException if the predictor has a matrix but no categories
         */
        public void addFactor(PPCell pPCell, Predictor predictor) {
            List<FactorHandler> handlers = getFactorHandlers();
            Matrix contrastMatrix = predictor.getMatrix();
            FactorHandler handler;
            if (contrastMatrix != null) {
                Categories categories = predictor.getCategories();
                if (categories == null) {
                    throw new UnsupportedElementException(predictor);
                }
                // Lazy view: a category without a value is only reported when it is accessed.
                List<Object> categoryValues = Lists.transform(categories.getCategories(), Row::requireValue);
                handler = new ContrastMatrixHandler(pPCell, contrastMatrix, categoryValues);
            } else {
                handler = new FactorHandler(pPCell);
            }
            handlers.add(handler);
        }

        /** Returns the value of a category, rejecting categories that do not declare one. */
        private static Object requireValue(Category category) {
            Object categoryValue = category.getValue();
            if (categoryValue == null) {
                throw new MissingAttributeException(category, PMMLAttributes.CATEGORY_VALUE);
            }
            return categoryValue;
        }

        /** Registers a covariate cell. */
        public void addCovariate(PPCell pPCell) {
            CovariateHandler handler = new CovariateHandler(pPCell);
            getCovariateHandlers().add(handler);
        }

        public List<FactorHandler> getFactorHandlers() {
            return this.factorHandlers;
        }

        public List<CovariateHandler> getCovariateHandlers() {
            return this.covariateHandlers;
        }

        /** Base class of all cell handlers; wraps a PPCell that must declare a field. */
        public abstract class PredictorHandler {
            private PPCell ppCell;

            private PredictorHandler(PPCell pPCell) {
                setPPCell(pPCell);
                if (pPCell.getField() == null) {
                    throw new MissingAttributeException(pPCell, PMMLAttributes.PPCELL_FIELD);
                }
            }

            /** Multiplies this handler's contribution for {@code fieldValue} into {@code value}. */
            public abstract <V extends Number> Value<V> updateProduct(Value<V> value, FieldValue fieldValue);

            public FieldName getField() {
                return getPPCell().getField();
            }

            public PPCell getPPCell() {
                return this.ppCell;
            }

            private void setPPCell(PPCell pPCell) {
                this.ppCell = pPCell;
            }
        }

        /** Handler of a categorical cell; the cell value is the category it stands for. */
        public class FactorHandler extends PredictorHandler {
            private Object category;

            private FactorHandler(PPCell pPCell) {
                super(pPCell);
                Object cellValue = pPCell.getValue();
                if (cellValue == null) {
                    throw new MissingAttributeException(pPCell, PMMLAttributes.PPCELL_VALUE);
                }
                setCategory(cellValue);
            }

            /** Leaves the product untouched on a match with the cell, multiplies it by zero otherwise. */
            @Override // org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.Row.PredictorHandler
            public <V extends Number> Value<V> updateProduct(Value<V> value, FieldValue fieldValue) {
                if (fieldValue.equals((HasValue<?>) getPPCell())) {
                    return value;
                }
                return value.multiply2(Numbers.DOUBLE_ZERO);
            }

            public Object getCategory() {
                return this.category;
            }

            private void setCategory(Object obj) {
                this.category = obj;
            }
        }

        /** Factor handler whose contribution is looked up in a contrast matrix. */
        public class ContrastMatrixHandler extends FactorHandler {
            private Matrix matrix;
            private List<Object> categories;
            // Typed copies of `categories`, created on the first field value lookup.
            private List<FieldValue> parsedCategories;

            private ContrastMatrixHandler(PPCell pPCell, Matrix matrix, List<Object> list) {
                super(pPCell);
                setMatrix(matrix);
                setCategories(list);
            }

            /**
             * Multiplies the product by the matrix element whose row is the position of the
             * field value and whose column is the position of this handler's category.
             * Matrix coordinates are one-based, hence the {@code + 1}.
             *
             * @throws InvalidElementException if either position is unknown, or the matrix has
             *         no element at the requested coordinates
             */
            @Override // org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.Row.FactorHandler, org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.Row.PredictorHandler
            public <V extends Number> Value<V> updateProduct(Value<V> value, FieldValue fieldValue) {
                Matrix contrastMatrix = getMatrix();
                int rowIndex = getIndex(fieldValue);
                int columnIndex = getIndex(getCategory());
                boolean known = (rowIndex >= 0) && (columnIndex >= 0);
                if (!known) {
                    throw new InvalidElementException(getPPCell());
                }
                Number coefficient = MatrixUtil.getElementAt(contrastMatrix, rowIndex + 1, columnIndex + 1);
                if (coefficient == null) {
                    throw new InvalidElementException(contrastMatrix);
                }
                return value.multiply2(coefficient);
            }

            /** Position of a field value among the typed categories, or -1. */
            public int getIndex(FieldValue fieldValue) {
                List<FieldValue> typedCategories = this.parsedCategories;
                if (typedCategories == null) {
                    // The first field value seen supplies the type information for the copy.
                    typedCategories = ImmutableList.copyOf(parseCategories(fieldValue));
                    this.parsedCategories = typedCategories;
                }
                return typedCategories.indexOf(fieldValue);
            }

            /** Position of a raw category value among the declared categories, or -1. */
            public int getIndex(Object obj) {
                List<Object> declared = getCategories();
                return declared.indexOf(obj);
            }

            private List<FieldValue> parseCategories(TypeInfo typeInfo) {
                return Lists.transform(getCategories(), rawCategory -> FieldValueUtil.create(typeInfo, rawCategory));
            }

            public Matrix getMatrix() {
                return this.matrix;
            }

            private void setMatrix(Matrix matrix) {
                this.matrix = matrix;
            }

            public List<Object> getCategories() {
                return this.categories;
            }

            private void setCategories(List<Object> list) {
                this.categories = list;
            }
        }

        /** Handler of a continuous cell; the cell value is the exponent applied to the field. */
        public class CovariateHandler extends PredictorHandler {
            // null stands for an exponent of exactly one, which needs no power computation.
            private Number exponent;

            private CovariateHandler(PPCell pPCell) {
                super(pPCell);
                Object cellValue = pPCell.getValue();
                if (cellValue == null) {
                    throw new MissingAttributeException(pPCell, PMMLAttributes.PPCELL_VALUE);
                }
                Number parsed = (Number) TypeUtil.parseOrCast(DataType.DOUBLE, cellValue);
                if (parsed.doubleValue() == 1.0d) {
                    setExponent(null);
                } else {
                    setExponent(parsed);
                }
            }

            /** Multiplies the product by the field value, using the exponent when one is set. */
            @Override // org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.Row.PredictorHandler
            public <V extends Number> Value<V> updateProduct(Value<V> value, FieldValue fieldValue) {
                Number power = getExponent();
                if (power == null) {
                    return value.multiply2(fieldValue.asNumber());
                }
                return value.multiply2(fieldValue.asNumber(), power);
            }

            public Number getExponent() {
                return this.exponent;
            }

            private void setExponent(Number number) {
                this.exponent = number;
            }
        }
    }

    /**
     * Creates an evaluator for the {@link GeneralRegressionModel} that
     * {@code PMMLUtil.findModel} locates inside the given PMML document, and delegates to the
     * two-argument constructor.
     *
     * @param pmml the PMML document to search
     */
    public GeneralRegressionModelEvaluator(PMML pmml) {
        this(pmml, (GeneralRegressionModel) PMMLUtil.findModel(pmml, GeneralRegressionModel.class));
    }

    public GeneralRegressionModelEvaluator(PMML pmml, GeneralRegressionModel generalRegressionModel) {
        super(pmml, generalRegressionModel);
        this.parameterRegistry = null;
        this.ppMatrixMap = null;
        this.paramMatrixMap = null;
        this.targetCategories = null;
        if (generalRegressionModel.getModelType() == null) {
            throw new MissingAttributeException(generalRegressionModel, PMMLAttributes.GENERALREGRESSIONMODEL_MODELTYPE);
        }
        if (generalRegressionModel.getParameterList() == null) {
            throw new MissingElementException(generalRegressionModel, PMMLElements.GENERALREGRESSIONMODEL_PARAMETERLIST);
        }
        if (generalRegressionModel.getPPMatrix() == null) {
            throw new MissingElementException(generalRegressionModel, PMMLElements.GENERALREGRESSIONMODEL_PPMATRIX);
        }
        if (generalRegressionModel.getParamMatrix() == null) {
            throw new MissingElementException(generalRegressionModel, PMMLElements.GENERALREGRESSIONMODEL_PARAMMATRIX);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Returns a short human-readable label for the model.
     *
     * @return {@code "Cox regression"} for Cox regression models, {@code "General regression"}
     *         for every other model type
     */
    @Override // org.jpmml.evaluator.Evaluator
    public String getSummary() {
        GeneralRegressionModel.ModelType modelType = ((GeneralRegressionModel) getModel()).getModelType();
        String summary;
        switch (modelType) {
            case COX_REGRESSION:
                summary = "Cox regression";
                break;
            default:
                summary = "General regression";
                break;
        }
        return summary;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Entry point for regression-type evaluation: Cox regression models are routed to
     * {@code evaluateCoxRegression}, every other model type to {@code evaluateGeneralRegression}.
     *
     * @param valueFactory factory for the numeric values used during the computation
     * @param evaluationContext context that supplies the field values
     * @return the result map produced by the selected evaluation routine
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    protected <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        GeneralRegressionModel.ModelType modelType = ((GeneralRegressionModel) getModel()).getModelType();
        Map<FieldName, ?> result;
        switch (modelType) {
            case COX_REGRESSION:
                result = evaluateCoxRegression(valueFactory, evaluationContext);
                break;
            default:
                result = evaluateGeneralRegression(valueFactory, evaluationContext);
                break;
        }
        return result;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Evaluates a Cox regression model.
     *
     * <p>The baseline cells and the maximum time are taken from the stratum selected by the
     * baseline strata variable when the model declares one, otherwise from the cumulative
     * hazard tables themselves. The prediction is
     * {@code cumHazard * exp(dotProduct - referencePoint)}, where {@code cumHazard} belongs to
     * the latest baseline cell whose time does not exceed the end time.</p>
     *
     * <p>The default regression result is returned when no stratum matches, when the end time
     * is greater than the maximum time, or when the dot product or the reference point cannot
     * be computed. A zero prediction is returned when the end time lies before the earliest
     * baseline cell.</p>
     *
     * @param valueFactory factory for the numeric values used during the computation
     * @param evaluationContext context that supplies the field values
     * @return the regression result for the target field
     * @throws MissingAttributeException if the end time variable, a maximum time, a baseline
     *         cell time or the selected cumulative hazard is not declared
     * @throws MissingElementException if the cumulative hazard tables or the baseline cells
     *         are absent
     */
    private <V extends Number> Map<FieldName, ?> evaluateCoxRegression(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        List<BaselineCell> baselineCells;
        Number maxTime;
        GeneralRegressionModel generalRegressionModel = (GeneralRegressionModel) getModel();
        TargetField targetField = getTargetField();
        FieldName endTimeVariable = generalRegressionModel.getEndTimeVariable();
        if (endTimeVariable == null) {
            throw new MissingAttributeException(generalRegressionModel, PMMLAttributes.GENERALREGRESSIONMODEL_ENDTIMEVARIABLE);
        }
        BaseCumHazardTables baseCumHazardTables = generalRegressionModel.getBaseCumHazardTables();
        if (baseCumHazardTables == null) {
            throw new MissingElementException(generalRegressionModel, PMMLElements.GENERALREGRESSIONMODEL_BASECUMHAZARDTABLES);
        }
        FieldName baselineStrataVariable = generalRegressionModel.getBaselineStrataVariable();
        if (baselineStrataVariable != null) {
            // Stratified model: cells and maximum time come from the matching stratum.
            BaselineStratum baselineStratum = getBaselineStratum(baseCumHazardTables, getVariable(baselineStrataVariable, evaluationContext));
            if (baselineStratum == null) {
                return TargetUtil.evaluateRegressionDefault(valueFactory, targetField);
            }
            baselineCells = baselineStratum.getBaselineCells();
            if (baselineCells.isEmpty()) {
                throw new MissingElementException(baselineStratum, PMMLElements.BASELINESTRATUM_BASELINECELLS);
            }
            maxTime = baselineStratum.getMaxTime();
            if (maxTime == null) {
                throw new MissingAttributeException(baselineStratum, PMMLAttributes.BASELINESTRATUM_MAXTIME);
            }
        } else {
            baselineCells = baseCumHazardTables.getBaselineCells();
            if (baselineCells.isEmpty()) {
                throw new MissingElementException(baseCumHazardTables, PMMLElements.BASECUMHAZARDTABLES_BASELINECELLS);
            }
            maxTime = baseCumHazardTables.getMaxTime();
            if (maxTime == null) {
                throw new MissingAttributeException(baseCumHazardTables, PMMLAttributes.BASECUMHAZARDTABLES_MAXTIME);
            }
        }
        // Orders baseline cells by time and rejects cells that do not declare one.
        Comparator<BaselineCell> comparator = new Comparator<BaselineCell>() { // from class: org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.1
            @Override // java.util.Comparator
            public int compare(BaselineCell baselineCell, BaselineCell baselineCell2) {
                return NumberUtil.compare(getTime(baselineCell), getTime(baselineCell2));
            }

            private Number getTime(BaselineCell baselineCell) {
                Number time = baselineCell.getTime();
                if (time == null) {
                    throw new MissingAttributeException(baselineCell, PMMLAttributes.BASELINECELL_TIME);
                }
                return time;
            }
        };
        // The cell list was checked to be non-empty above, so a minimum always exists.
        BaselineCell earliestCell = baselineCells.stream().min(comparator).get();
        Number minTime = earliestCell.getTime();
        if (minTime == null) {
            // The comparator is never invoked for a single-cell list, so its null check
            // does not cover that case; report the missing attribute explicitly.
            throw new MissingAttributeException(earliestCell, PMMLAttributes.BASELINECELL_TIME);
        }
        FieldValue endTime = getVariable(endTimeVariable, evaluationContext);
        if (endTime.compareToValue(maxTime) > 0) {
            return TargetUtil.evaluateRegressionDefault(valueFactory, targetField);
        }
        if (endTime.compareToValue(minTime) < 0) {
            return TargetUtil.evaluateRegression(targetField, valueFactory.newValue(Numbers.DOUBLE_ZERO));
        }
        Number endTimeNumber = endTime.asNumber();
        // Latest baseline cell that is not after the end time.
        BaselineCell matchingCell = baselineCells.stream().filter(cell -> {
            return NumberUtil.compare(cell.getTime(), endTimeNumber) <= 0;
        }).max(comparator).get();
        Number cumHazard = matchingCell.getCumHazard();
        if (cumHazard == null) {
            throw new MissingAttributeException(matchingCell, PMMLAttributes.BASELINECELL_CUMHAZARD);
        }
        Value<V> dotProduct = computeDotProduct(valueFactory, evaluationContext);
        Value<? extends Number> referencePoint = computeReferencePoint(valueFactory);
        if (dotProduct == null || referencePoint == null) {
            return TargetUtil.evaluateRegressionDefault(valueFactory, targetField);
        }
        return TargetUtil.evaluateRegression(targetField, dotProduct.subtract(referencePoint).exp2().multiply2(cumHazard));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Evaluates a non-Cox regression model: computes the dot product of the parameters and,
     * for generalized linear models only, passes it through {@code computeLink}.
     *
     * @param valueFactory factory for the numeric values used during the computation
     * @param evaluationContext context that supplies the field values
     * @return the regression result for the target field; the default regression result when
     *         the dot product cannot be computed
     * @throws InvalidAttributeException if the model type is Cox regression, multinomial
     *         logistic or ordinal multinomial
     * @throws UnsupportedAttributeException for any other unrecognised model type
     */
    private <V extends Number> Map<FieldName, ?> evaluateGeneralRegression(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        GeneralRegressionModel model = (GeneralRegressionModel) getModel();
        TargetField target = getTargetField();

        Value<V> linearPredictor = computeDotProduct(valueFactory, evaluationContext);
        if (linearPredictor == null) {
            return TargetUtil.evaluateRegressionDefault(valueFactory, target);
        }

        GeneralRegressionModel.ModelType type = model.getModelType();
        switch (type) {
            case REGRESSION:
            case GENERAL_LINEAR:
                // The linear predictor is the prediction.
                return TargetUtil.evaluateRegression(target, linearPredictor);
            case GENERALIZED_LINEAR: {
                Value<V> linked = computeLink(linearPredictor, evaluationContext);
                return TargetUtil.evaluateRegression(target, linked);
            }
            case COX_REGRESSION:
            case MULTINOMIAL_LOGISTIC:
            case ORDINAL_MULTINOMIAL:
                throw new InvalidAttributeException(model, type);
            default:
                throw new UnsupportedAttributeException(model, type);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Evaluates the model as a classifier and returns one value per target category.
     *
     * <p>Every target category except the last one gets a value computed from the dot product
     * of its parameter cells:</p>
     * <ul>
     *   <li>generalized linear: the dot product passed through {@code computeLink};</li>
     *   <li>multinomial logistic: the exponential of the dot product;</li>
     *   <li>ordinal multinomial: the dot product passed through {@code computeCumulativeLink}.</li>
     * </ul>
     * <p>The last category is derived instead: one minus the previous category's value
     * (generalized linear), {@code exp(0)} (multinomial logistic) or one (ordinal multinomial).
     * Ordinal multinomial values are then turned into differences of consecutive cumulative
     * values, and multinomial logistic values are normalised with
     * {@code ValueUtil.normalizeSimpleMax}.</p>
     *
     * @param valueFactory factory for the numeric values used during the computation
     * @param evaluationContext context that supplies the field values
     * @return the classification result for the target field; the default classification
     *         result as soon as a dot product cannot be computed
     * @throws InvalidAttributeException if the model type is a regression-only type
     * @throws UnsupportedAttributeException for any other unrecognised model type
     * @throws InvalidElementException if the PPMatrix or ParamMatrix has no usable entry for a
     *         target category
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    protected <V extends Number> Map<FieldName, ? extends Classification<?, V>> evaluateClassification(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        Value<V> categoryValue;
        Map<String, Row> ppRows;
        Iterable<PCell> parameterCells;
        GeneralRegressionModel generalRegressionModel = (GeneralRegressionModel) getModel();
        TargetField targetField = getTargetField();
        List<Object> targetCategories = getTargetCategories();
        GeneralRegressionModel.ModelType modelType = generalRegressionModel.getModelType();
        Map<Object, Map<String, Row>> pPMatrixMap = getPPMatrixMap();
        Map<Object, List<PCell>> paramMatrixMap = getParamMatrixMap();
        ValueMap<Object, V> valueMap = new ValueMap<>(2 * targetCategories.size());
        // Value of the previously processed category (generalized linear only).
        Value<? extends Number> previousValue = null;
        // Cumulative value of the previously processed category (ordinal multinomial only).
        Value<? extends Number> previousCumulative = null;
        for (int i = 0; i < targetCategories.size(); i++) {
            Object targetCategory = targetCategories.get(i);
            if (i < targetCategories.size() - 1) {
                // PPMatrix rows: category-specific entry first, then the entry keyed by null.
                if (pPMatrixMap.isEmpty()) {
                    ppRows = Collections.emptyMap();
                } else {
                    ppRows = pPMatrixMap.get(targetCategory);
                    if (ppRows == null) {
                        ppRows = pPMatrixMap.get(null);
                    }
                    if (ppRows == null) {
                        throw new InvalidElementException(generalRegressionModel.getPPMatrix());
                    }
                }
                switch (modelType) {
                    case COX_REGRESSION:
                    case REGRESSION:
                    case GENERAL_LINEAR:
                        throw new InvalidAttributeException(generalRegressionModel, modelType);
                    case GENERALIZED_LINEAR:
                    case MULTINOMIAL_LOGISTIC:
                        parameterCells = paramMatrixMap.get(targetCategory);
                        // The entry keyed by null is only accepted for two-category targets.
                        if (parameterCells == null && targetCategories.size() == 2) {
                            parameterCells = paramMatrixMap.get(null);
                        }
                        if (parameterCells == null) {
                            throw new InvalidElementException(generalRegressionModel.getParamMatrix());
                        }
                        break;
                    case ORDINAL_MULTINOMIAL:
                        // Exactly one category-specific cell, followed by the cells keyed by null.
                        List<PCell> categoryCells = paramMatrixMap.get(targetCategory);
                        if (categoryCells == null || categoryCells.size() != 1) {
                            throw new InvalidElementException(generalRegressionModel.getParamMatrix());
                        }
                        List<PCell> commonCells = paramMatrixMap.get(null);
                        if (commonCells == null) {
                            throw new InvalidElementException(generalRegressionModel.getParamMatrix());
                        }
                        parameterCells = Iterables.concat(categoryCells, commonCells);
                        break;
                    default:
                        throw new UnsupportedAttributeException(generalRegressionModel, modelType);
                }
                categoryValue = computeDotProduct(valueFactory, parameterCells, ppRows, evaluationContext);
                if (categoryValue == null) {
                    return TargetUtil.evaluateClassificationDefault(valueFactory, targetField);
                }
                switch (modelType) {
                    case COX_REGRESSION:
                    case REGRESSION:
                    case GENERAL_LINEAR:
                        throw new InvalidAttributeException(generalRegressionModel, modelType);
                    case GENERALIZED_LINEAR:
                        categoryValue = computeLink(categoryValue, evaluationContext);
                        break;
                    case MULTINOMIAL_LOGISTIC:
                        // Result is not reassigned: the value is expected to be updated in place.
                        categoryValue.exp2();
                        break;
                    case ORDINAL_MULTINOMIAL:
                        categoryValue = computeCumulativeLink(categoryValue, evaluationContext);
                        break;
                    default:
                        throw new UnsupportedAttributeException(generalRegressionModel, modelType);
                }
            } else {
                // Last target category: derived rather than computed from parameter cells.
                switch (modelType) {
                    case COX_REGRESSION:
                    case REGRESSION:
                    case GENERAL_LINEAR:
                        throw new InvalidAttributeException(generalRegressionModel, modelType);
                    case GENERALIZED_LINEAR:
                        categoryValue = valueFactory.newValue(Numbers.DOUBLE_ONE);
                        if (previousValue != null) {
                            categoryValue.subtract(previousValue);
                        }
                        break;
                    case MULTINOMIAL_LOGISTIC:
                        categoryValue = valueFactory.newValue(Numbers.DOUBLE_ZERO).exp2();
                        break;
                    case ORDINAL_MULTINOMIAL:
                        categoryValue = valueFactory.newValue(Numbers.DOUBLE_ONE);
                        break;
                    default:
                        throw new UnsupportedAttributeException(generalRegressionModel, modelType);
                }
            }
            switch (modelType) {
                case COX_REGRESSION:
                case REGRESSION:
                case GENERAL_LINEAR:
                    throw new InvalidAttributeException(generalRegressionModel, modelType);
                case GENERALIZED_LINEAR:
                    previousValue = categoryValue;
                    break;
                case MULTINOMIAL_LOGISTIC:
                    break;
                case ORDINAL_MULTINOMIAL:
                    // Keep the untouched cumulative value for the next category before
                    // subtracting the previous one from the current value.
                    Value<? extends Number> cumulative = categoryValue.copy2();
                    if (previousCumulative != null) {
                        categoryValue.subtract(previousCumulative);
                    }
                    previousCumulative = cumulative;
                    break;
                default:
                    throw new UnsupportedAttributeException(generalRegressionModel, modelType);
            }
            valueMap.put(targetCategory, categoryValue);
        }
        switch (modelType) {
            case COX_REGRESSION:
            case REGRESSION:
            case GENERAL_LINEAR:
                throw new InvalidAttributeException(generalRegressionModel, modelType);
            case GENERALIZED_LINEAR:
            case ORDINAL_MULTINOMIAL:
                break;
            case MULTINOMIAL_LOGISTIC:
                ValueUtil.normalizeSimpleMax(valueMap);
                break;
            default:
                throw new UnsupportedAttributeException(generalRegressionModel, modelType);
        }
        return TargetUtil.evaluateClassification(targetField, createClassification(valueMap));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Computes the dot product for models that are not split by target category: uses the
     * PPMatrix rows and the ParamMatrix cells that are keyed by {@code null}.
     *
     * @param valueFactory factory for the numeric values used during the computation
     * @param evaluationContext context that supplies the field values
     * @return the dot product, or {@code null} when the four-argument overload returns null
     * @throws InvalidElementException if the PPMatrix is non-empty but has no entry keyed by
     *         null, or if the ParamMatrix does not consist of exactly one entry keyed by null
     */
    private <V extends Number> Value<V> computeDotProduct(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        GeneralRegressionModel model = (GeneralRegressionModel) getModel();

        Map<Object, Map<String, Row>> ppMatrices = getPPMatrixMap();
        Map<String, Row> rows = ppMatrices.isEmpty() ? Collections.<String, Row>emptyMap() : ppMatrices.get(null);
        if (rows == null) {
            throw new InvalidElementException(model.getPPMatrix());
        }

        Map<Object, List<PCell>> paramMatrices = getParamMatrixMap();
        List<PCell> cells = paramMatrices.get(null);
        boolean singleUnkeyedMatrix = (paramMatrices.size() == 1) && (cells != null);
        if (!singleUnkeyedMatrix) {
            throw new InvalidElementException(model.getParamMatrix());
        }

        return computeDotProduct(valueFactory, cells, rows, evaluationContext);
    }

    /**
     * Accumulates the contribution of every parameter cell into a single value.
     *
     * <p>For each cell, the row registered under the cell's parameter name is evaluated and
     * passed together with the cell's beta to {@code add2}; when no row is registered, only
     * the beta is passed.</p>
     *
     * @param valueFactory creates the accumulator and is forwarded to {@code Row.evaluate}
     * @param iterable the parameter cells to process
     * @param map rows keyed by parameter name
     * @param evaluationContext forwarded to {@code Row.evaluate}
     * @return the accumulator; {@code null} if {@code iterable} yields no cells or if any row
     *         evaluates to {@code null}
     * @throws MissingAttributeException if a cell lacks its parameter name or its beta
     */
    private <V extends Number> Value<V> computeDotProduct(ValueFactory<V> valueFactory, Iterable<PCell> iterable, Map<String, Row> map, EvaluationContext evaluationContext) {
        Value<V> sum = null;

        for (PCell cell : iterable) {
            String name = cell.getParameterName();
            if (name == null) {
                throw new MissingAttributeException(cell, PMMLAttributes.PCELL_PARAMETERNAME);
            }

            Number beta = cell.getBeta();
            if (beta == null) {
                throw new MissingAttributeException(cell, PMMLAttributes.PCELL_BETA);
            }

            // The accumulator is created lazily, once the first cell has passed validation.
            if (sum == null) {
                sum = valueFactory.newValue();
            }

            Row row = map.get(name);
            if (row == null) {
                sum.add2(beta);
                continue;
            }

            Value<V> rowValue = row.evaluate(valueFactory, evaluationContext);
            if (rowValue == null) {
                return null;
            }
            sum.add2(beta, rowValue.getValue());
        }

        return sum;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Accumulates, over all cells of the {@code null}-keyed parameter matrix, each cell's beta
     * paired with the reference point of the parameter it names.
     *
     * @param valueFactory creates the accumulator
     * @return the accumulator; {@code null} if there are no cells or if a cell names a
     *         parameter that is absent from the parameter registry
     * @throws InvalidElementException if the ParamMatrix map is not exactly one entry keyed by
     *         {@code null}
     * @throws MissingAttributeException if a cell lacks its parameter name or its beta
     */
    private <V extends Number> Value<V> computeReferencePoint(ValueFactory<V> valueFactory) {
        GeneralRegressionModel model = (GeneralRegressionModel) getModel();
        BiMap<String, Parameter> registry = getParameterRegistry();

        Map<Object, List<PCell>> cellsByCategory = getParamMatrixMap();
        List<PCell> cells = cellsByCategory.get(null);
        if (cells == null || cellsByCategory.size() != 1) {
            throw new InvalidElementException(model.getParamMatrix());
        }

        Value<V> sum = null;
        for (PCell cell : cells) {
            String name = cell.getParameterName();
            if (name == null) {
                throw new MissingAttributeException(cell, PMMLAttributes.PCELL_PARAMETERNAME);
            }

            Number beta = cell.getBeta();
            if (beta == null) {
                throw new MissingAttributeException(cell, PMMLAttributes.PCELL_BETA);
            }

            // Created lazily, once the first cell has passed validation.
            if (sum == null) {
                sum = valueFactory.newValue();
            }

            Parameter parameter = registry.get(name);
            if (parameter == null) {
                return null;
            }
            sum.add2(beta, parameter.getReferencePoint());
        }

        return sum;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Validates the link-function attributes of the model, then applies the offset, the link
     * function and the trials multiplier to {@code value}.
     *
     * <p>Attribute rules enforced here: {@code NEGBIN} requires a distParameter and forbids a
     * linkParameter; {@code ODDSPOWER} and {@code POWER} require a linkParameter and forbid a
     * distParameter; the remaining supported link functions forbid both.</p>
     *
     * @param value the value to operate on; the same instance is returned
     * @param evaluationContext used to resolve the offset and trials variables
     * @return {@code value}
     * @throws MissingAttributeException if the link function or a required parameter is absent
     * @throws MisplacedAttributeException if a forbidden parameter is present
     * @throws UnsupportedAttributeException if the link function is not one handled here
     */
    private <V extends Number> Value<V> computeLink(Value<V> value, EvaluationContext evaluationContext) {
        GeneralRegressionModel model = (GeneralRegressionModel) getModel();

        GeneralRegressionModel.LinkFunction linkFunction = model.getLinkFunction();
        if (linkFunction == null) {
            throw new MissingAttributeException(model, PMMLAttributes.GENERALREGRESSIONMODEL_LINKFUNCTION);
        }

        Number distParameter = model.getDistParameter();
        Number linkParameter = model.getLinkParameter();

        // Classify the link function by which of the two parameters it needs.
        boolean distRequired;
        boolean linkRequired;
        switch (linkFunction) {
            case CLOGLOG:
            case IDENTITY:
            case LOG:
            case LOGC:
            case LOGIT:
            case LOGLOG:
            case PROBIT:
                distRequired = false;
                linkRequired = false;
                break;
            case NEGBIN:
                distRequired = true;
                linkRequired = false;
                break;
            case ODDSPOWER:
            case POWER:
                distRequired = false;
                linkRequired = true;
                break;
            default:
                throw new UnsupportedAttributeException(model, linkFunction);
        }

        // distParameter is always checked before linkParameter.
        if (distRequired) {
            if (distParameter == null) {
                throw new MissingAttributeException(model, PMMLAttributes.GENERALREGRESSIONMODEL_DISTPARAMETER);
            }
        } else if (distParameter != null) {
            throw new MisplacedAttributeException(model, PMMLAttributes.GENERALREGRESSIONMODEL_DISTPARAMETER, distParameter);
        }

        if (linkRequired) {
            if (linkParameter == null) {
                throw new MissingAttributeException(model, PMMLAttributes.GENERALREGRESSIONMODEL_LINKPARAMETER);
            }
        } else if (linkParameter != null) {
            throw new MisplacedAttributeException(model, PMMLAttributes.GENERALREGRESSIONMODEL_LINKPARAMETER, linkParameter);
        }

        Number offset = getOffset(model, evaluationContext);
        if (offset != null) {
            value.add2(offset);
        }

        // Every link function that reaches this point was accepted by the switch above,
        // so no second dispatch is needed.
        GeneralRegressionModelUtil.computeLink(linkFunction, distParameter, linkParameter, value);

        Integer trials = getTrials(model, evaluationContext);
        if (trials != null) {
            value.multiply2(trials);
        }

        return value;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Applies the model's offset and cumulative link function to {@code value}.
     *
     * <p>The offset is applied before the cumulative link function is checked for support,
     * so {@code value} may already have been modified when an
     * {@link UnsupportedAttributeException} is thrown.</p>
     *
     * @param value the value to operate on; the same instance is returned
     * @param evaluationContext used to resolve the offset variable
     * @return {@code value}
     * @throws MissingAttributeException if the model declares no cumulative link function
     * @throws UnsupportedAttributeException if the cumulative link function is not handled here
     */
    private <V extends Number> Value<V> computeCumulativeLink(Value<V> value, EvaluationContext evaluationContext) {
        GeneralRegressionModel model = (GeneralRegressionModel) getModel();

        GeneralRegressionModel.CumulativeLinkFunction function = model.getCumulativeLinkFunction();
        if (function == null) {
            throw new MissingAttributeException(model, PMMLAttributes.GENERALREGRESSIONMODEL_CUMULATIVELINKFUNCTION);
        }

        Number offset = getOffset(model, evaluationContext);
        if (offset != null) {
            value.add2(offset);
        }

        // Guard: only the listed functions are supported.
        switch (function) {
            case LOGIT:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                break;
            default:
                throw new UnsupportedAttributeException(model, function);
        }

        GeneralRegressionModelUtil.computeCumulativeLink(function, value);
        return value;
    }

    /**
     * Returns the parameter registry, loading it through {@code parameterCache} on first use
     * and keeping it in the {@code parameterRegistry} field afterwards.
     *
     * @return the registry of parameters keyed by name
     */
    public BiMap<String, Parameter> getParameterRegistry() {
        BiMap<String, Parameter> registry = this.parameterRegistry;
        if (registry == null) {
            registry = (BiMap) getValue(parameterCache);
            this.parameterRegistry = registry;
        }
        return registry;
    }

    /**
     * Returns the PPMatrix lookup, loading it through {@code ppMatrixCache} on first use and
     * keeping it in the {@code ppMatrixMap} field afterwards.
     *
     * @return rows keyed by target category, then by parameter name
     */
    private Map<Object, Map<String, Row>> getPPMatrixMap() {
        Map<Object, Map<String, Row>> rowsByCategory = this.ppMatrixMap;
        if (rowsByCategory == null) {
            rowsByCategory = (Map) getValue(ppMatrixCache);
            this.ppMatrixMap = rowsByCategory;
        }
        return rowsByCategory;
    }

    /**
     * Returns the ParamMatrix lookup, loading it through {@code paramMatrixCache} on first use
     * and keeping it in the {@code paramMatrixMap} field afterwards.
     *
     * @return parameter cells keyed by target category
     */
    private Map<Object, List<PCell>> getParamMatrixMap() {
        Map<Object, List<PCell>> cellsByCategory = this.paramMatrixMap;
        if (cellsByCategory == null) {
            cellsByCategory = (Map) getValue(paramMatrixCache);
            this.paramMatrixMap = cellsByCategory;
        }
        return cellsByCategory;
    }

    /**
     * Returns the target categories, parsing them on first use and keeping an immutable copy
     * in the {@code targetCategories} field afterwards.
     *
     * @return the immutable list produced from {@link #parseTargetCategories()}
     */
    private List<Object> getTargetCategories() {
        List<Object> cached = this.targetCategories;
        if (cached == null) {
            cached = ImmutableList.copyOf(parseTargetCategories());
            this.targetCategories = cached;
        }
        return cached;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Determines the list of target categories, with the reference category (if any) moved to
     * the end.
     *
     * <p>For {@code GENERALIZED_LINEAR} and {@code MULTINOMIAL_LOGISTIC} models that declare no
     * target reference category, the reference category is taken to be the single target
     * category that has no entry in the ParamMatrix map.</p>
     *
     * @return the target field's own category list when no reference category applies,
     *         otherwise a new list with the reference category relocated to the last position
     * @throws InvalidElementException if the target field is neither categorical nor ordinal,
     *         has fewer than two categories, or the reference category cannot be inferred
     *         unambiguously
     * @throws InvalidAttributeException for {@code COX_REGRESSION}, {@code REGRESSION} and
     *         {@code GENERAL_LINEAR} model types
     * @throws UnsupportedAttributeException for any other unhandled model type
     */
    private List<Object> parseTargetCategories() {
        GeneralRegressionModel model = (GeneralRegressionModel) getModel();
        TargetField targetField = getTargetField();

        // Guard: only categorical and ordinal targets have categories to parse.
        switch (targetField.getOpType()) {
            case CATEGORICAL:
            case ORDINAL:
                break;
            default:
                throw new InvalidElementException(model);
        }

        List<Object> categories = targetField.getCategories();
        if (categories == null || categories.size() < 2) {
            throw new InvalidElementException(model);
        }

        Object referenceCategory = model.getTargetReferenceCategory();
        GeneralRegressionModel.ModelType modelType = model.getModelType();
        switch (modelType) {
            case COX_REGRESSION:
            case REGRESSION:
            case GENERAL_LINEAR:
                throw new InvalidAttributeException(model, modelType);
            case GENERALIZED_LINEAR:
            case MULTINOMIAL_LOGISTIC:
                if (referenceCategory == null) {
                    // Infer the reference category: the one category without parameter cells.
                    Map<Object, List<PCell>> cellsByCategory = getParamMatrixMap();
                    Set<Object> uncovered = categories.stream()
                        .filter(category -> !cellsByCategory.containsKey(category))
                        .collect(Collectors.toSet());
                    if (uncovered.size() != 1) {
                        throw new InvalidElementException(model.getParamMatrix());
                    }
                    referenceCategory = Iterables.getOnlyElement(uncovered);
                }
                break;
            case ORDINAL_MULTINOMIAL:
                break;
            default:
                throw new UnsupportedAttributeException(model, modelType);
        }

        if (referenceCategory == null) {
            return categories;
        }

        // Work on a copy; the reference category only moves if it is actually present.
        List<Object> reordered = new ArrayList<>(categories);
        if (reordered.remove(referenceCategory)) {
            reordered.add(referenceCategory);
        }
        return reordered;
    }

    /**
     * Resolves the model's offset: the value of the offset variable when one is declared,
     * otherwise the model's static offset value.
     *
     * @param generalRegressionModel the model to read the offset settings from
     * @param evaluationContext used to evaluate the offset variable
     * @return the offset, or whatever {@code getOffsetValue()} returns when no variable is set
     */
    private static Number getOffset(GeneralRegressionModel generalRegressionModel, EvaluationContext evaluationContext) {
        FieldName variable = generalRegressionModel.getOffsetVariable();
        if (variable == null) {
            return generalRegressionModel.getOffsetValue();
        }
        return getVariable(variable, evaluationContext).asNumber();
    }

    /**
     * Resolves the model's trials count: the value of the trials variable when one is
     * declared, otherwise the model's static trials value.
     *
     * @param generalRegressionModel the model to read the trials settings from
     * @param evaluationContext used to evaluate the trials variable
     * @return the trials count, or whatever {@code getTrialsValue()} returns when no variable
     *         is set
     */
    private static Integer getTrials(GeneralRegressionModel generalRegressionModel, EvaluationContext evaluationContext) {
        FieldName variable = generalRegressionModel.getTrialsVariable();
        if (variable == null) {
            return generalRegressionModel.getTrialsValue();
        }
        return getVariable(variable, evaluationContext).asInteger();
    }

    /**
     * Evaluates a field in the given context and insists that it has a value.
     *
     * @param fieldName the field to evaluate
     * @param evaluationContext the context to evaluate it in
     * @return the evaluated field value
     * @throws MissingValueException if {@code FieldValueUtil.isMissing} reports the result as
     *         missing
     */
    private static FieldValue getVariable(FieldName fieldName, EvaluationContext evaluationContext) {
        FieldValue result = evaluationContext.evaluate(fieldName);
        if (!FieldValueUtil.isMissing(result)) {
            return result;
        }
        throw new MissingValueException(fieldName);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Finds the baseline stratum matching a field value.
     *
     * <p>When the tables object is a {@code MapHolder}, the lookup is delegated to it using the
     * field value's data type and value. Otherwise the strata are scanned in order and the
     * first one whose value equals the field value is returned.</p>
     *
     * @param baseCumHazardTables the tables holding the strata
     * @param fieldValue the value to match
     * @return the matching stratum, or {@code null} if the scan finds none
     * @throws MissingAttributeException if the scan reaches a stratum without a value
     */
    private static BaselineStratum getBaselineStratum(BaseCumHazardTables baseCumHazardTables, FieldValue fieldValue) {
        if (!(baseCumHazardTables instanceof MapHolder)) {
            for (BaselineStratum stratum : baseCumHazardTables.getBaselineStrata()) {
                Object stratumValue = stratum.getValue();
                if (stratumValue == null) {
                    throw new MissingAttributeException(stratum, PMMLAttributes.BASELINESTRATUM_VALUE);
                }
                if (fieldValue.equalsValue(stratumValue)) {
                    return stratum;
                }
            }
            return null;
        }

        MapHolder holder = (MapHolder) baseCumHazardTables;
        return (BaselineStratum) holder.get(fieldValue.getDataType(), fieldValue.getValue());
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds a name-to-parameter registry from a parameter list.
     *
     * @param parameterList the list to index; must not be {@code null}
     * @return a new {@link HashBiMap}, empty when the list reports no parameters
     */
    public static BiMap<String, Parameter> parseParameterRegistry(ParameterList parameterList) {
        BiMap<String, Parameter> registry = HashBiMap.create();
        if (parameterList.hasParameters()) {
            for (Parameter entry : parameterList.getParameters()) {
                registry.put(entry.getName(), entry);
            }
        }
        return registry;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds a field-to-predictor registry from a predictor list.
     *
     * @param predictorList the list to index; may be {@code null}
     * @return a new {@link HashBiMap}, empty when the list is {@code null} or reports no
     *         predictors
     */
    public static BiMap<FieldName, Predictor> parsePredictorRegistry(PredictorList predictorList) {
        BiMap<FieldName, Predictor> registry = HashBiMap.create();
        if (predictorList != null && predictorList.hasPredictors()) {
            for (Predictor entry : predictorList.getPredictors()) {
                registry.put(entry.getField(), entry);
            }
        }
        return registry;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds the nested lookup of predictor rows from the model's PPMatrix.
     *
     * <p>The PPCells are grouped by target category and, within each category, by parameter
     * name. Every innermost group is converted into one {@code Row}: a cell whose field is a
     * registered factor is added as a factor, a cell whose field is a registered covariate is
     * added as a covariate.</p>
     *
     * @param generalRegressionModel the model whose PPMatrix is parsed; also used to look up
     *        the factor and covariate registries through {@code CacheUtil}
     * @return rows keyed by target category, then by parameter name
     * @throws MissingAttributeException if a PPCell has no field
     * @throws InvalidAttributeException if a PPCell's field is neither a factor nor a covariate
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static Map<Object, Map<String, Row>> parsePPMatrix(final GeneralRegressionModel generalRegressionModel) {
        Function<List<PPCell>, Row> rowBuilder = new Function<List<PPCell>, Row>() {
            // The registries come from the captured method parameter. A qualified
            // "GeneralRegressionModel.this" is not available here because this static method
            // has no enclosing GeneralRegressionModel instance.
            private final BiMap<FieldName, Predictor> factors =
                (BiMap) CacheUtil.getValue(generalRegressionModel, GeneralRegressionModelEvaluator.factorCache);
            private final BiMap<FieldName, Predictor> covariates =
                (BiMap) CacheUtil.getValue(generalRegressionModel, GeneralRegressionModelEvaluator.covariateCache);

            @Override
            public Row apply(List<PPCell> list) {
                Row row = new Row();
                for (PPCell pPCell : list) {
                    FieldName field = pPCell.getField();
                    if (field == null) {
                        throw new MissingAttributeException(pPCell, PMMLAttributes.PPCELL_FIELD);
                    }
                    // Factors take precedence; anything else must be a known covariate.
                    Predictor factor = this.factors.get(field);
                    if (factor != null) {
                        row.addFactor(pPCell, factor);
                    } else {
                        if (this.covariates.get(field) == null) {
                            throw new InvalidAttributeException(pPCell, PMMLAttributes.PPCELL_FIELD, field);
                        }
                        row.addCovariate(pPCell);
                    }
                }
                return row;
            }
        };

        ListMultimap<Object, PPCell> cellsByCategory = groupByTargetCategory(generalRegressionModel.getPPMatrix().getPPCells());

        Map<Object, Map<String, Row>> result = new LinkedHashMap<>();
        for (Map.Entry<Object, List<PPCell>> categoryEntry : asMap(cellsByCategory).entrySet()) {
            Map<String, Row> rowsByParameter = new LinkedHashMap<>();
            for (Map.Entry<String, List<PPCell>> parameterEntry : asMap(groupByParameterName(categoryEntry.getValue())).entrySet()) {
                rowsByParameter.put(parameterEntry.getKey(), rowBuilder.apply(parameterEntry.getValue()));
            }
            result.put(categoryEntry.getKey(), rowsByParameter);
        }
        return result;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Groups the cells of the model's ParamMatrix by target category.
     *
     * @param generalRegressionModel the model whose ParamMatrix is read
     * @return parameter cells keyed by target category
     */
    public static Map<Object, List<PCell>> parseParamMatrix(GeneralRegressionModel generalRegressionModel) {
        ListMultimap<Object, PCell> cellsByCategory = groupByTargetCategory(generalRegressionModel.getParamMatrix().getPCells());
        return asMap(cellsByCategory);
    }

    /**
     * Returns the map view of a list multimap with its values typed as lists.
     *
     * <p>{@code ListMultimap.asMap()} is declared to return {@code Map<K, Collection<V>>}, so
     * its result cannot be returned directly as {@code Map<K, List<C>>}.
     * {@code Multimaps.asMap(ListMultimap)} returns the same view with the corrected type.</p>
     *
     * @param listMultimap the multimap to view as a map
     * @return the map view of {@code listMultimap}
     */
    private static <K, C extends ParameterCell> Map<K, List<C>> asMap(ListMultimap<K, C> listMultimap) {
        // Fully qualified so that no additional import is required.
        return com.google.common.collect.Multimaps.asMap(listMultimap);
    }

    /**
     * Groups cells by the value of {@code getParameterName()}.
     *
     * @param list the cells to group
     * @return a multimap from parameter name to the cells carrying that name
     */
    private static <C extends ParameterCell> ListMultimap<String, C> groupByParameterName(List<C> list) {
        return groupCells(list, ParameterCell::getParameterName);
    }

    /**
     * Groups cells by the value of {@code getTargetCategory()}.
     *
     * @param list the cells to group
     * @return a multimap from target category to the cells carrying that category
     */
    private static <C extends ParameterCell> ListMultimap<Object, C> groupByTargetCategory(List<C> list) {
        return groupCells(list, ParameterCell::getTargetCategory);
    }

    /**
     * Groups cells into an {@link ArrayListMultimap} under the key computed for each cell.
     *
     * @param list the cells to group, visited in list order
     * @param function computes the grouping key of a cell
     * @return a new multimap holding every cell under its computed key
     */
    private static <K, C extends ParameterCell> ListMultimap<K, C> groupCells(List<C> list, Function<C, K> function) {
        ListMultimap<K, C> grouped = ArrayListMultimap.create();
        list.forEach(cell -> grouped.put(function.apply(cell), cell));
        return grouped;
    }
}
