package org.jpmml.evaluator;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.Array;
import org.dmg.pmml.Coefficient;
import org.dmg.pmml.Coefficients;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.KernelType;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.RealSparseArray;
import org.dmg.pmml.SupportVector;
import org.dmg.pmml.SupportVectorMachine;
import org.dmg.pmml.SupportVectorMachineModel;
import org.dmg.pmml.SvmClassificationMethodType;
import org.dmg.pmml.SvmRepresentationType;
import org.dmg.pmml.VectorDictionary;
import org.dmg.pmml.VectorFields;
import org.dmg.pmml.VectorInstance;
import org.jpmml.evaluator.ClassificationMap;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.PMMLObjectUtil;
import org.jpmml.manager.UnsupportedFeatureException;

/* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.1.13.jar:org/jpmml/evaluator/SupportVectorMachineModelEvaluator.class */
/**
 * Evaluator for {@link SupportVectorMachineModel} elements.
 *
 * <p>Only the {@code SUPPORT_VECTORS} representation is handled; any other representation is
 * rejected with an {@link UnsupportedFeatureException}. Regression models must declare exactly one
 * {@link SupportVectorMachine}; classification models combine one or more machines using either
 * the one-against-all or the one-against-one method.
 */
public class SupportVectorMachineModelEvaluator extends ModelEvaluator<SupportVectorMachineModel> {

    /**
     * Parsed vector dictionaries ({@code VectorInstance} id to dense coordinates), cached per model
     * with weak keys. Each cached value is an immutable snapshot of
     * {@link #parseVectorDictionary(SupportVectorMachineModel)}.
     */
    private static final LoadingCache<SupportVectorMachineModel, Map<String, double[]>> vectorCache = CacheBuilder.newBuilder()
        .weakKeys()
        .build(new CacheLoader<SupportVectorMachineModel, Map<String, double[]>>() {
            @Override
            public Map<String, double[]> load(SupportVectorMachineModel supportVectorMachineModel) {
                return ImmutableMap.copyOf(SupportVectorMachineModelEvaluator.parseVectorDictionary(supportVectorMachineModel));
            }
        });

    /**
     * Creates an evaluator for the {@link SupportVectorMachineModel} located among the models of
     * the given PMML document.
     *
     * @param pmml the PMML document
     */
    public SupportVectorMachineModelEvaluator(PMML pmml) {
        this(pmml, (SupportVectorMachineModel) find(pmml.getModels(), SupportVectorMachineModel.class));
    }

    /**
     * Creates an evaluator for the given model.
     *
     * @param pmml the PMML document
     * @param supportVectorMachineModel the model to evaluate
     */
    public SupportVectorMachineModelEvaluator(PMML pmml, SupportVectorMachineModel supportVectorMachineModel) {
        super(pmml, supportVectorMachineModel);
    }

    @Override
    public String getSummary() {
        return "Support vector machine";
    }

    /**
     * Scores the model and passes the prediction through {@code OutputUtil.evaluate}.
     *
     * @param modelEvaluationContext the evaluation context supplying the input field values
     * @return the value returned by {@code OutputUtil.evaluate} for the prediction
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException if the SVM representation is not {@code SUPPORT_VECTORS}
     *     or the mining function is neither regression nor classification
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        SupportVectorMachineModel model = (SupportVectorMachineModel) getModel();
        if (!model.isScorable()) {
            throw new InvalidResultException(model);
        }

        SvmRepresentationType representation = model.getSvmRepresentation();
        switch (representation) {
            case SUPPORT_VECTORS:
                break;
            default:
                throw new UnsupportedFeatureException(model, representation);
        }

        // Regression yields Number values and classification yields ClassificationMap values,
        // so the common type of both branches is the unbounded wildcard.
        Map<FieldName, ?> predictions;
        MiningFunctionType miningFunction = model.getFunctionName();
        switch (miningFunction) {
            case REGRESSION:
                predictions = evaluateRegression(modelEvaluationContext);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(modelEvaluationContext);
                break;
            default:
                throw new UnsupportedFeatureException(model, miningFunction);
        }
        return OutputUtil.evaluate(predictions, modelEvaluationContext);
    }

    /**
     * Evaluates a regression model, which must declare exactly one support vector machine.
     *
     * @throws InvalidFeatureException if the number of machines is not exactly one
     */
    private Map<FieldName, ? extends Number> evaluateRegression(ModelEvaluationContext context) {
        SupportVectorMachineModel model = (SupportVectorMachineModel) getModel();
        List<SupportVectorMachine> machines = model.getSupportVectorMachines();
        if (machines.size() != 1) {
            throw new InvalidFeatureException(model);
        }
        SupportVectorMachine machine = machines.get(0);
        Double value = evaluateSupportVectorMachine(machine, createInput(context));
        return TargetUtil.evaluateRegression(value, context);
    }

    /**
     * Evaluates a classification model.
     *
     * <p>With one-against-all, the raw value of every machine is stored under the machine's target
     * category in a distance-typed {@link ClassificationMap}. With one-against-one, every machine
     * casts a single vote for one label.
     *
     * @throws InvalidFeatureException if no machines are declared or a machine's category
     *     attributes do not fit the classification method
     * @throws UnsupportedFeatureException if the classification method is not supported
     */
    private Map<FieldName, ? extends ClassificationMap<?>> evaluateClassification(ModelEvaluationContext context) {
        SupportVectorMachineModel model = (SupportVectorMachineModel) getModel();
        List<SupportVectorMachine> machines = model.getSupportVectorMachines();
        if (machines.size() < 1) {
            throw new InvalidFeatureException(model);
        }
        String alternateBinaryTargetCategory = model.getAlternateBinaryTargetCategory();

        ClassificationMap<String> result;
        SvmClassificationMethodType method = getClassificationMethod();
        switch (method) {
            case ONE_AGAINST_ALL:
                result = new ClassificationMap<String>(ClassificationMap.Type.DISTANCE);
                break;
            case ONE_AGAINST_ONE:
                result = new VoteClassificationMap<String>();
                break;
            default:
                throw new UnsupportedFeatureException(model, method);
        }

        double[] input = createInput(context);
        for (SupportVectorMachine machine : machines) {
            // The machine is scored before its category attributes are validated.
            Double value = evaluateSupportVectorMachine(machine, input);
            switch (method) {
                case ONE_AGAINST_ALL:
                    String targetCategory = machine.getTargetCategory();
                    if (targetCategory == null || machine.getAlternateTargetCategory() != null) {
                        throw new InvalidFeatureException(machine);
                    }
                    result.put(targetCategory, value);
                    break;
                case ONE_AGAINST_ONE:
                    addVote(result, selectOneAgainstOneLabel(model, machine, alternateBinaryTargetCategory, value));
                    break;
                default:
                    // Unreachable: any other method was rejected when the result map was created.
                    break;
            }
        }
        return TargetUtil.evaluateClassification(result, context);
    }

    /**
     * Chooses the label that a single machine votes for under the one-against-one method.
     *
     * <p>When the model declares no alternate binary target category, the machine must declare
     * both a target and an alternate target category; the target category wins when the value
     * compares below the threshold (the machine's own threshold, falling back to the model's).
     * Otherwise the machine must declare only a target category; a value rounding to 1 selects it
     * and a value rounding to 0 selects the model's alternate binary target category.
     *
     * @throws InvalidFeatureException if the machine's category attributes do not fit the mode
     * @throws EvaluationException if, in binary mode, the value rounds to neither 0 nor 1
     */
    private static String selectOneAgainstOneLabel(SupportVectorMachineModel model, SupportVectorMachine machine, String alternateBinaryTargetCategory, Double value) {
        String targetCategory = machine.getTargetCategory();
        String alternateTargetCategory = machine.getAlternateTargetCategory();

        if (alternateBinaryTargetCategory == null) {
            if (targetCategory == null || alternateTargetCategory == null) {
                throw new InvalidFeatureException(machine);
            }
            Double threshold = machine.getThreshold();
            if (threshold == null) {
                threshold = Double.valueOf(model.getThreshold());
            }
            return value.compareTo(threshold) < 0 ? targetCategory : alternateTargetCategory;
        }

        if (targetCategory == null || alternateTargetCategory != null) {
            throw new InvalidFeatureException(machine);
        }
        long rounded = Math.round(value.doubleValue());
        if (rounded == 1) {
            return targetCategory;
        }
        if (rounded == 0) {
            return alternateBinaryTargetCategory;
        }
        throw new EvaluationException("Invalid numeric prediction " + value);
    }

    /** Adds one vote to the tally stored under the given label, starting from zero when absent. */
    private static void addVote(ClassificationMap<String> votes, String label) {
        Double tally = votes.get(label);
        if (tally == null) {
            tally = Double.valueOf(0.0d);
        }
        votes.put(label, Double.valueOf(tally.doubleValue() + 1.0d));
    }

    /**
     * Computes the value of one machine: the sum over its (coefficient, support vector) pairs of
     * the coefficient value times the kernel of the input and the support vector, plus the
     * {@code absoluteValue} of the coefficients element.
     *
     * @throws InvalidFeatureException if a support vector id is missing from the vector dictionary,
     *     or if the numbers of coefficients and support vectors differ
     */
    private Double evaluateSupportVectorMachine(SupportVectorMachine supportVectorMachine, double[] input) {
        KernelType kernelType = ((SupportVectorMachineModel) getModel()).getKernelType();
        Coefficients coefficients = supportVectorMachine.getCoefficients();
        Iterator<Coefficient> coefficientIt = coefficients.iterator();
        Iterator<SupportVector> supportVectorIt = supportVectorMachine.getSupportVectors().iterator();
        Map<String, double[]> vectors = getVectorMap();

        double sum = 0.0d;
        while (coefficientIt.hasNext() && supportVectorIt.hasNext()) {
            Coefficient coefficient = coefficientIt.next();
            SupportVector supportVector = supportVectorIt.next();
            double[] vector = vectors.get(supportVector.getVectorId());
            if (vector == null) {
                throw new InvalidFeatureException(supportVector);
            }
            sum += coefficient.getValue() * KernelTypeUtil.evaluate(kernelType, input, vector);
        }

        // Leftover elements on either side mean the two sequences had different lengths.
        if (coefficientIt.hasNext() || supportVectorIt.hasNext()) {
            throw new InvalidFeatureException(supportVectorMachine);
        }
        return Double.valueOf(sum + coefficients.getAbsoluteValue());
    }

    /**
     * Determines the classification method.
     *
     * <p>The value obtained from {@code PMMLObjectUtil.getAttributeValue} for the
     * {@code classificationMethod} attribute is used when it is non-null. Otherwise the method is
     * inferred: a model with an alternate binary target category must have exactly one machine
     * with a target category and is one-against-one; without it, the first machine decides
     * (one-against-one when it declares an alternate target category, one-against-all otherwise).
     *
     * @throws InvalidFeatureException if the method cannot be inferred
     */
    private SvmClassificationMethodType getClassificationMethod() {
        SupportVectorMachineModel model = (SupportVectorMachineModel) getModel();
        SvmClassificationMethodType declared = (SvmClassificationMethodType) PMMLObjectUtil.getAttributeValue(model, "classificationMethod");
        if (declared != null) {
            return declared;
        }

        List<SupportVectorMachine> machines = model.getSupportVectorMachines();
        if (model.getAlternateBinaryTargetCategory() != null) {
            if (machines.size() != 1) {
                throw new InvalidFeatureException(model);
            }
            SupportVectorMachine onlyMachine = machines.get(0);
            if (onlyMachine.getTargetCategory() == null) {
                throw new InvalidFeatureException(onlyMachine);
            }
            return SvmClassificationMethodType.ONE_AGAINST_ONE;
        }

        Iterator<SupportVectorMachine> machineIt = machines.iterator();
        if (!machineIt.hasNext()) {
            throw new InvalidFeatureException(model);
        }
        SupportVectorMachine firstMachine = machineIt.next();
        if (firstMachine.getTargetCategory() == null) {
            throw new InvalidFeatureException(firstMachine);
        }
        return firstMachine.getAlternateTargetCategory() != null
            ? SvmClassificationMethodType.ONE_AGAINST_ONE
            : SvmClassificationMethodType.ONE_AGAINST_ALL;
    }

    /**
     * Builds the numeric input vector by evaluating every {@link FieldRef} of the vector
     * dictionary's {@link VectorFields}, in declaration order.
     *
     * @throws MissingFieldException if a field reference evaluates to {@code null}
     * @throws InvalidFeatureException if {@code numberOfFields} is declared and differs from the
     *     number of field references
     */
    private double[] createInput(EvaluationContext evaluationContext) {
        VectorFields vectorFields = ((SupportVectorMachineModel) getModel()).getVectorDictionary().getVectorFields();
        List<FieldRef> fieldRefs = vectorFields.getFieldRefs();

        double[] input = new double[fieldRefs.size()];
        for (int i = 0; i < fieldRefs.size(); i++) {
            FieldRef fieldRef = fieldRefs.get(i);
            FieldValue value = ExpressionUtil.evaluate(fieldRef, evaluationContext);
            if (value == null) {
                throw new MissingFieldException(fieldRef.getField(), vectorFields);
            }
            input[i] = value.asNumber().doubleValue();
        }

        Integer numberOfFields = vectorFields.getNumberOfFields();
        if (numberOfFields != null && numberOfFields.intValue() != input.length) {
            throw new InvalidFeatureException(vectorFields);
        }
        return input;
    }

    /** Returns the parsed vector dictionary of this evaluator's model from the shared cache. */
    private Map<String, double[]> getVectorMap() {
        return (Map<String, double[]>) getValue(vectorCache);
    }

    /**
     * Parses the vector dictionary of a model into a map from vector instance id to dense
     * coordinates, preserving declaration order.
     *
     * <p>Every instance must have an id and exactly one of a dense array or a real sparse array.
     * When {@code numberOfFields} or {@code numberOfVectors} is declared, it must match the parsed
     * vector length or the number of distinct ids, respectively.
     *
     * @param supportVectorMachineModel the model whose vector dictionary is parsed
     * @return a new mutable, insertion-ordered map
     * @throws InvalidFeatureException if an instance or the dictionary violates the rules above
     */
    public static Map<String, double[]> parseVectorDictionary(SupportVectorMachineModel supportVectorMachineModel) {
        VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
        VectorFields vectorFields = vectorDictionary.getVectorFields();

        LinkedHashMap<String, double[]> vectors = Maps.newLinkedHashMap();
        for (VectorInstance vectorInstance : vectorDictionary.getVectorInstances()) {
            String id = vectorInstance.getId();
            if (id == null) {
                throw new InvalidFeatureException(vectorInstance);
            }

            Array denseArray = vectorInstance.getArray();
            RealSparseArray sparseArray = vectorInstance.getREALSparseArray();
            double[] values;
            if (denseArray != null && sparseArray == null) {
                values = ArrayUtil.toArray(denseArray);
            } else if (denseArray == null && sparseArray != null) {
                values = SparseArrayUtil.toArray(sparseArray);
            } else {
                // Either both representations are present or neither is.
                throw new InvalidFeatureException(vectorInstance);
            }

            Integer numberOfFields = vectorFields.getNumberOfFields();
            if (numberOfFields != null && numberOfFields.intValue() != values.length) {
                throw new InvalidFeatureException(vectorInstance);
            }
            vectors.put(id, values);
        }

        Integer numberOfVectors = vectorDictionary.getNumberOfVectors();
        if (numberOfVectors != null && numberOfVectors.intValue() != vectors.size()) {
            throw new InvalidFeatureException(vectorDictionary);
        }
        return vectors;
    }
}
