package org.jpmml.evaluator.support_vector_machine;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Floats;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.Array;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasValue;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.RealSparseArray;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.support_vector_machine.Coefficient;
import org.dmg.pmml.support_vector_machine.Coefficients;
import org.dmg.pmml.support_vector_machine.Kernel;
import org.dmg.pmml.support_vector_machine.PMMLAttributes;
import org.dmg.pmml.support_vector_machine.PMMLElements;
import org.dmg.pmml.support_vector_machine.SupportVector;
import org.dmg.pmml.support_vector_machine.SupportVectorMachine;
import org.dmg.pmml.support_vector_machine.SupportVectorMachineModel;
import org.dmg.pmml.support_vector_machine.VectorDictionary;
import org.dmg.pmml.support_vector_machine.VectorFields;
import org.dmg.pmml.support_vector_machine.VectorInstance;
import org.jpmml.evaluator.ArrayUtil;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.InvalidElementException;
import org.jpmml.evaluator.InvalidElementListException;
import org.jpmml.evaluator.MisplacedAttributeException;
import org.jpmml.evaluator.MisplacedElementException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.MissingValueException;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.Numbers;
import org.jpmml.evaluator.PMMLException;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.SparseArrayUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.model.ReflectionUtil;
import org.jpmml.model.XPathUtil;

/* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.5.1.jar:org/jpmml/evaluator/support_vector_machine/SupportVectorMachineModelEvaluator.class */
/**
 * Evaluator for PMML {@code SupportVectorMachineModel} elements.
 *
 * <p>Only the {@code SUPPORT_VECTORS} representation is accepted, and a model with
 * {@code maxWins} enabled is rejected at construction time. Parsed vector dictionaries are kept
 * in a static cache keyed by the model object, so evaluators built over the same model share them.
 */
public class SupportVectorMachineModelEvaluator extends ModelEvaluator<SupportVectorMachineModel> {
    /** Lazily resolved vector dictionary: vector id to a {@code float[]} or {@code double[]}. */
    private transient Map<String, Object> vectorMap;

    /** Static cache of parsed, immutable vector dictionaries, keyed by model. */
    private static final LoadingCache<SupportVectorMachineModel, Map<String, Object>> vectorCache = CacheUtil.buildLoadingCache(new CacheLoader<SupportVectorMachineModel, Map<String, Object>>() {
        @Override
        public Map<String, Object> load(SupportVectorMachineModel supportVectorMachineModel) {
            return ImmutableMap.copyOf(SupportVectorMachineModelEvaluator.parseVectorDictionary(supportVectorMachineModel));
        }
    });

    /**
     * Creates an evaluator for the first {@code SupportVectorMachineModel} found in {@code pmml}.
     *
     * @param pmml the PMML document containing the model
     */
    public SupportVectorMachineModelEvaluator(PMML pmml) {
        this(pmml, (SupportVectorMachineModel) PMMLUtil.findModel(pmml, SupportVectorMachineModel.class));
    }

    /**
     * Creates an evaluator for the given model and validates its structure eagerly.
     *
     * @param pmml the enclosing PMML document
     * @param supportVectorMachineModel the model to evaluate
     * @throws UnsupportedAttributeException if {@code maxWins} is enabled or the representation is
     *     not {@code SUPPORT_VECTORS}
     * @throws MissingElementException if the vector dictionary, its vector fields, or the support
     *     vector machine list is absent
     */
    public SupportVectorMachineModelEvaluator(PMML pmml, SupportVectorMachineModel supportVectorMachineModel) {
        super(pmml, supportVectorMachineModel);
        this.vectorMap = null;
        boolean isMaxWins = supportVectorMachineModel.isMaxWins();
        if (isMaxWins) {
            throw new UnsupportedAttributeException(supportVectorMachineModel, PMMLAttributes.SUPPORTVECTORMACHINEMODEL_MAXWINS, Boolean.valueOf(isMaxWins));
        }
        SupportVectorMachineModel.Representation representation = supportVectorMachineModel.getRepresentation();
        switch (representation) {
            case SUPPORT_VECTORS:
                VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
                if (vectorDictionary == null) {
                    throw new MissingElementException(supportVectorMachineModel, PMMLElements.SUPPORTVECTORMACHINEMODEL_VECTORDICTIONARY);
                }
                if (vectorDictionary.getVectorFields() == null) {
                    throw new MissingElementException(vectorDictionary, PMMLElements.VECTORDICTIONARY_VECTORFIELDS);
                }
                if (!supportVectorMachineModel.hasSupportVectorMachines()) {
                    throw new MissingElementException(supportVectorMachineModel, PMMLElements.SUPPORTVECTORMACHINEMODEL_SUPPORTVECTORMACHINES);
                }
                return;
            default:
                throw new UnsupportedAttributeException(supportVectorMachineModel, representation);
        }
    }

    @Override
    public String getSummary() {
        return "Support vector machine";
    }

    /**
     * Regression: requires exactly one {@code SupportVectorMachine} and reports its decision value
     * as the target.
     *
     * @throws InvalidElementListException if the model does not contain exactly one machine
     */
    @Override
    protected <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        List<SupportVectorMachine> supportVectorMachines = ((SupportVectorMachineModel) getModel()).getSupportVectorMachines();
        if (supportVectorMachines.size() != 1) {
            throw new InvalidElementListException(supportVectorMachines);
        }
        return TargetUtil.evaluateRegression(getTargetField(), evaluateSupportVectorMachine(valueFactory, supportVectorMachines.get(0), createInput(evaluationContext)));
    }

    /**
     * Classification over all {@code SupportVectorMachine} elements.
     *
     * <p>{@code ONE_AGAINST_ALL}: each machine's decision value is stored under its
     * {@code targetCategory} and the result is a {@link DistanceDistribution}.
     *
     * <p>{@code ONE_AGAINST_ONE}: each machine casts one vote, and the result is a
     * {@link VoteProbabilityDistribution}. When the model declares
     * {@code alternateBinaryTargetCategory}, the rounded decision value must be 0 (vote for the
     * alternate binary category) or 1 (vote for {@code targetCategory}). Otherwise the value is
     * compared against the machine's threshold, falling back to the model's threshold. A value
     * below the threshold votes for {@code targetCategory}, otherwise for
     * {@code alternateTargetCategory}.
     */
    @Override
    protected <V extends Number> Map<FieldName, ? extends Classification<?, V>> evaluateClassification(final ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        ValueMap valueMap;
        Classification result;
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel) getModel();
        List<SupportVectorMachine> supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines();
        Object alternateBinaryTargetCategory = supportVectorMachineModel.getAlternateBinaryTargetCategory();
        SupportVectorMachineModel.ClassificationMethod classificationMethod = getClassificationMethod();
        switch (classificationMethod) {
            case ONE_AGAINST_ALL:
                valueMap = new ValueMap(2 * supportVectorMachines.size());
                break;
            case ONE_AGAINST_ONE:
                valueMap = new VoteMap<Object, V>(2 * supportVectorMachines.size()) {
                    @Override
                    public ValueFactory<V> getValueFactory() {
                        return valueFactory;
                    }
                };
                break;
            default:
                throw new UnsupportedAttributeException(supportVectorMachineModel, classificationMethod);
        }
        Object input = createInput(evaluationContext);
        for (SupportVectorMachine supportVectorMachine : supportVectorMachines) {
            Object targetCategory = supportVectorMachine.getTargetCategory();
            if (targetCategory == null) {
                throw new MissingAttributeException(supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_TARGETCATEGORY);
            }
            Object alternateTargetCategory = supportVectorMachine.getAlternateTargetCategory();
            Value<V> decisionValue = evaluateSupportVectorMachine(valueFactory, supportVectorMachine, input);
            switch (classificationMethod) {
                case ONE_AGAINST_ALL:
                    if (alternateTargetCategory != null) {
                        throw new MisplacedAttributeException(supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_ALTERNATETARGETCATEGORY, alternateTargetCategory);
                    }
                    valueMap.put(targetCategory, decisionValue);
                    break;
                case ONE_AGAINST_ONE:
                    Object label;
                    if (alternateBinaryTargetCategory != null) {
                        if (alternateTargetCategory != null) {
                            throw new MisplacedAttributeException(supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_ALTERNATETARGETCATEGORY, alternateTargetCategory);
                        }
                        decisionValue.round2();
                        if (decisionValue.isZero()) {
                            label = alternateBinaryTargetCategory;
                        } else {
                            if (!decisionValue.isOne()) {
                                throw new EvaluationException("Expected " + PMMLException.formatValue(Numbers.DOUBLE_ZERO) + " or " + PMMLException.formatValue(Numbers.DOUBLE_ONE) + ", got " + PMMLException.formatValue(decisionValue.getValue()));
                            }
                            label = targetCategory;
                        }
                    } else {
                        if (alternateTargetCategory == null) {
                            throw new MissingAttributeException(supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_ALTERNATETARGETCATEGORY);
                        }
                        // A machine-level threshold takes precedence over the model-level one.
                        Number threshold = supportVectorMachine.getThreshold();
                        if (threshold == null) {
                            threshold = supportVectorMachineModel.getThreshold();
                        }
                        label = decisionValue.compareTo(threshold) < 0 ? targetCategory : alternateTargetCategory;
                    }
                    ((VoteMap) valueMap).increment(label);
                    break;
            }
        }
        switch (classificationMethod) {
            case ONE_AGAINST_ALL:
                result = new DistanceDistribution(valueMap);
                break;
            case ONE_AGAINST_ONE:
                result = new VoteProbabilityDistribution(valueMap);
                break;
            default:
                throw new UnsupportedAttributeException(supportVectorMachineModel, classificationMethod);
        }
        return TargetUtil.evaluateClassification(getTargetField(), result);
    }

    /**
     * Computes the decision value of a single machine.
     *
     * <p>Each {@code Coefficient} is paired positionally with a {@code SupportVector}. The
     * coefficient value and the kernel evaluated on {@code input} and the referenced dictionary
     * vector are accumulated into the result via {@code Value.add2(Number, Number)}. The
     * coefficients' {@code absoluteValue} is then added. NOTE(review): the exact arithmetic of
     * {@code add2} is defined in {@code Value}; presumably it is a multiply-accumulate.
     *
     * @param input a {@code float[]} or {@code double[]} as produced by {@link #createInput}
     * @throws MissingElementException if the model has no kernel, or the machine has no
     *     {@code Coefficients} or {@code SupportVectors} element
     * @throws InvalidAttributeException if a support vector references an unknown vector id
     * @throws InvalidElementException if the numbers of coefficients and support vectors differ
     */
    private <V extends Number> Value<V> evaluateSupportVectorMachine(ValueFactory<V> valueFactory, SupportVectorMachine supportVectorMachine, Object input) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel) getModel();
        Value<V> result = valueFactory.newValue();
        Kernel kernel = supportVectorMachineModel.getKernel();
        if (kernel == null) {
            throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(supportVectorMachineModel.getClass()) + "/<Kernel>"), supportVectorMachine);
        }
        Coefficients coefficients = supportVectorMachine.getCoefficients();
        // Report malformed models through the library's exception hierarchy instead of an NPE.
        if (coefficients == null) {
            throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(supportVectorMachine.getClass()) + "/<Coefficients>"), supportVectorMachine);
        }
        if (supportVectorMachine.getSupportVectors() == null) {
            throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(supportVectorMachine.getClass()) + "/<SupportVectors>"), supportVectorMachine);
        }
        Iterator<Coefficient> coefficientIt = coefficients.iterator();
        Iterator<SupportVector> supportVectorIt = supportVectorMachine.getSupportVectors().iterator();
        Map<String, Object> vectorMap = getVectorMap();
        while (coefficientIt.hasNext() && supportVectorIt.hasNext()) {
            Coefficient coefficient = coefficientIt.next();
            SupportVector supportVector = supportVectorIt.next();
            String vectorId = supportVector.getVectorId();
            if (vectorId == null) {
                throw new MissingAttributeException(supportVector, PMMLAttributes.SUPPORTVECTOR_VECTORID);
            }
            Object vector = vectorMap.get(vectorId);
            if (vector == null) {
                throw new InvalidAttributeException(supportVector, PMMLAttributes.SUPPORTVECTOR_VECTORID, vectorId);
            }
            result.add2(coefficient.getValue(), KernelUtil.evaluate(kernel, valueFactory, input, vector).getValue());
        }
        // Coefficients and support vectors must pair up one-to-one.
        if (coefficientIt.hasNext() || supportVectorIt.hasNext()) {
            throw new InvalidElementException(supportVectorMachine);
        }
        result.add2(coefficients.getAbsoluteValue());
        return result;
    }

    /**
     * Determines the classification method.
     *
     * <p>An explicitly set attribute wins. It is read through reflection on the raw field,
     * presumably so that an absent attribute yields {@code null} rather than a getter default.
     * Otherwise the method is inferred: a model-level {@code alternateBinaryTargetCategory}
     * implies {@code ONE_AGAINST_ONE} with exactly one machine. Otherwise the first machine
     * decides: with {@code alternateTargetCategory} it is {@code ONE_AGAINST_ONE}, without it
     * {@code ONE_AGAINST_ALL}.
     *
     * @throws InvalidElementException if the inference preconditions are not met
     */
    private SupportVectorMachineModel.ClassificationMethod getClassificationMethod() {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel) getModel();
        SupportVectorMachineModel.ClassificationMethod classificationMethod = (SupportVectorMachineModel.ClassificationMethod) ReflectionUtil.getFieldValue(PMMLAttributes.SUPPORTVECTORMACHINEMODEL_CLASSIFICATIONMETHOD, supportVectorMachineModel);
        if (classificationMethod != null) {
            return classificationMethod;
        }
        List<SupportVectorMachine> supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines();
        if (supportVectorMachineModel.getAlternateBinaryTargetCategory() != null) {
            if (supportVectorMachines.size() != 1) {
                throw new InvalidElementException(supportVectorMachineModel);
            }
            SupportVectorMachine supportVectorMachine = supportVectorMachines.get(0);
            if (supportVectorMachine.getTargetCategory() != null) {
                return SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ONE;
            }
            throw new InvalidElementException(supportVectorMachine);
        }
        Iterator<SupportVectorMachine> it = supportVectorMachines.iterator();
        if (!it.hasNext()) {
            throw new InvalidElementException(supportVectorMachineModel);
        }
        SupportVectorMachine first = it.next();
        Object targetCategory = first.getTargetCategory();
        Object alternateTargetCategory = first.getAlternateTargetCategory();
        if (targetCategory != null) {
            return alternateTargetCategory != null ? SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ONE : SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ALL;
        }
        throw new InvalidElementException(first);
    }

    /**
     * Builds the input vector in {@code VectorFields} order.
     *
     * <p>A {@code FieldRef} contributes the field's numeric value. A {@code CategoricalPredictor}
     * contributes 1.0 when the field value equals the predictor's value, and 0.0 otherwise. Its
     * coefficient must be absent or 1.
     *
     * @return a {@code float[]} or {@code double[]} depending on the model's math context
     * @throws MissingValueException if any referenced field value is missing
     * @throws MisplacedElementException for any other child element of {@code VectorFields}
     */
    private Object createInput(EvaluationContext evaluationContext) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel) getModel();
        VectorFields vectorFields = supportVectorMachineModel.getVectorDictionary().getVectorFields();
        List<PMMLObject> content = vectorFields.getContent();
        List<Number> values = new ArrayList<>(content.size());
        for (int i = 0; i < content.size(); i++) {
            PMMLObject pmmlObject = content.get(i);
            if (pmmlObject instanceof FieldRef) {
                FieldRef fieldRef = (FieldRef) pmmlObject;
                FieldName field = fieldRef.getField();
                FieldValue value = ExpressionUtil.evaluate(fieldRef, evaluationContext);
                if (FieldValueUtil.isMissing(value)) {
                    throw new MissingValueException(field, vectorFields);
                }
                values.add(value.asNumber());
            } else {
                if (!(pmmlObject instanceof CategoricalPredictor)) {
                    throw new MisplacedElementException(pmmlObject);
                }
                CategoricalPredictor categoricalPredictor = (CategoricalPredictor) pmmlObject;
                FieldName name = categoricalPredictor.getName();
                if (name == null) {
                    throw new MissingAttributeException(categoricalPredictor, org.dmg.pmml.regression.PMMLAttributes.CATEGORICALPREDICTOR_FIELD);
                }
                FieldValue value = evaluationContext.evaluate(name);
                if (FieldValueUtil.isMissing(value)) {
                    throw new MissingValueException(name, categoricalPredictor);
                }
                // Only indicator (0/1) encoding is supported; a non-unit coefficient is rejected.
                Number coefficient = categoricalPredictor.getCoefficient();
                if (coefficient != null && coefficient.doubleValue() != 1.0d) {
                    throw new InvalidAttributeException(categoricalPredictor, org.dmg.pmml.regression.PMMLAttributes.CATEGORICALPREDICTOR_COEFFICIENT, coefficient);
                }
                values.add(value.equals((HasValue<?>) categoricalPredictor) ? Numbers.DOUBLE_ONE : Numbers.DOUBLE_ZERO);
            }
        }
        return toArray(supportVectorMachineModel, values);
    }

    /** Returns the parsed vector dictionary, fetching it from the shared cache on first use. */
    private Map<String, Object> getVectorMap() {
        if (this.vectorMap == null) {
            this.vectorMap = (Map<String, Object>) getValue(vectorCache);
        }
        return this.vectorMap;
    }

    /**
     * Parses every {@code VectorInstance} into an array keyed by its id, preserving document order.
     *
     * <p>Each instance must carry exactly one of {@code Array} or {@code REAL-SparseArray}, and
     * its length must match the number of {@code VectorFields} children. Kept public for
     * compatibility: the decompiler widened it from private for the cache loader.
     *
     * @throws MissingAttributeException if an instance has no id
     * @throws InvalidElementException if an instance has both or neither array form, or a length
     *     mismatch
     */
    public static Map<String, Object> parseVectorDictionary(SupportVectorMachineModel supportVectorMachineModel) {
        VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
        List<PMMLObject> content = vectorDictionary.getVectorFields().getContent();
        Map<String, Object> result = new LinkedHashMap<>();
        for (VectorInstance vectorInstance : vectorDictionary.getVectorInstances()) {
            String id = vectorInstance.getId();
            if (id == null) {
                throw new MissingAttributeException(vectorInstance, PMMLAttributes.VECTORINSTANCE_ID);
            }
            Array array = vectorInstance.getArray();
            RealSparseArray realSparseArray = vectorInstance.getRealSparseArray();
            List<? extends Number> numbers;
            if (array != null && realSparseArray == null) {
                numbers = ArrayUtil.asNumberList(array);
            } else {
                if (array != null || realSparseArray == null) {
                    throw new InvalidElementException(vectorInstance);
                }
                numbers = SparseArrayUtil.asNumberList(realSparseArray);
            }
            if (content.size() != numbers.size()) {
                throw new InvalidElementException(vectorInstance);
            }
            result.put(id, toArray(supportVectorMachineModel, numbers));
        }
        return result;
    }

    /**
     * Converts {@code list} to a primitive array matching the model's math context.
     *
     * @return a {@code float[]} for {@code FLOAT}, a {@code double[]} for {@code DOUBLE}
     * @throws UnsupportedAttributeException for any other math context
     */
    private static Object toArray(SupportVectorMachineModel supportVectorMachineModel, List<? extends Number> list) {
        MathContext mathContext = supportVectorMachineModel.getMathContext();
        switch (mathContext) {
            case FLOAT:
                return Floats.toArray(list);
            case DOUBLE:
                return Doubles.toArray(list);
            default:
                throw new UnsupportedAttributeException(supportVectorMachineModel, mathContext);
        }
    }
}
