package org.jpmml.evaluator.visitors;

import com.google.common.collect.Iterables;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLAttributes;
import org.dmg.pmml.PMMLFunctions;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.Target;
import org.dmg.pmml.TargetValue;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.dmg.pmml.general_regression.PCell;
import org.dmg.pmml.general_regression.PCovCell;
import org.dmg.pmml.general_regression.PPCell;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.naive_bayes.TargetValueCount;
import org.dmg.pmml.naive_bayes.TargetValueStat;
import org.dmg.pmml.regression.RegressionTable;
import org.dmg.pmml.rule_set.RuleSet;
import org.dmg.pmml.rule_set.SimpleRule;
import org.dmg.pmml.support_vector_machine.SupportVectorMachine;
import org.dmg.pmml.support_vector_machine.SupportVectorMachineModel;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.TypeUtil;
import org.jpmml.model.XPathUtil;

/* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.5.16.jar:org/jpmml/evaluator/visitors/TargetCategoryParser.class */
/**
 * Visitor that converts target-category values found in model elements (tree node scores,
 * regression table / SVM / GLM target categories, score distributions, rule scores, naive Bayes
 * target value counts and stats, output field values, target values) to the data type of the
 * corresponding target field, using {@link TypeUtil#parseOrCast(DataType, Object)}.
 *
 * <p>While traversing, the visitor keeps a stack with one map of target field data types per
 * enclosing {@link Model}. Values that are not tied to an explicit target field are converted
 * with {@link #dataType}. That field is set only when the innermost model with a non-empty
 * target map has exactly one distinct target data type.
 */
public class TargetCategoryParser extends AbstractParser {

    /**
     * One map (target field name to resolved data type) per model currently on the visitor
     * stack. The head of the deque belongs to the innermost model.
     */
    private Deque<Map<FieldName, DataType>> targetDataTypes = new ArrayDeque<>();

    /**
     * The single target data type of the innermost model, as computed by {@link #getDataType()}
     * when that model was entered. It is {@code null} when unknown or ambiguous.
     */
    private DataType dataType = null;

    @Override
    public void reset() {
        super.reset();

        this.targetDataTypes.clear();
        this.dataType = null;
    }

    /**
     * Pushes the parent object and, if it is a model, records that model's target data types.
     */
    @Override
    public void pushParent(PMMLObject object) {
        super.pushParent(object);

        // MiningModel must be tested first, because it is also a Model.
        if (object instanceof MiningModel) {
            processMiningModel((MiningModel) object);
        } else if (object instanceof Model) {
            processModel((Model) object);
        }
    }

    /**
     * Pops the parent object and, if it is a model, discards that model's target data types.
     * The cached {@link #dataType} is cleared rather than restored to the outer model's value.
     */
    @Override
    public PMMLObject popParent() {
        PMMLObject parent = super.popParent();

        if (parent instanceof Model) {
            this.targetDataTypes.pop();
            this.dataType = null;
        }

        return parent;
    }

    @Override
    public VisitorAction visit(GeneralRegressionModel generalRegressionModel) {
        generalRegressionModel.setTargetReferenceCategory(parseTargetValue(generalRegressionModel.getTargetReferenceCategory()));

        return super.visit(generalRegressionModel);
    }

    /**
     * Converts the node score. If the node hangs directly under a {@link TreeModel} whose mining
     * function is not {@code CLASSIFICATION}, the node is skipped instead.
     */
    @Override
    public VisitorAction visit(Node node) {
        PMMLObject parent = getParent();

        // NOTE(review): presumably only the root Node is a direct child of TreeModel, and
        // returning SKIP there prevents its descendants from being visited.
        if (parent instanceof TreeModel) {
            TreeModel treeModel = (TreeModel) parent;

            switch (treeModel.getMiningFunction()) {
                case CLASSIFICATION:
                    break;
                default:
                    return VisitorAction.SKIP;
            }
        }

        node.setScore(parseTargetValue(node.getScore()));

        return super.visit(node);
    }

    /**
     * Converts the {@code value} attribute of probability, confidence and affinity output fields.
     * The conversion uses the data type of the referenced target field when one is given.
     */
    @Override
    public VisitorAction visit(OutputField outputField) {
        switch (outputField.getResultFeature()) {
            case PROBABILITY:
            case CONFIDENCE:
            case AFFINITY:
                outputField.setValue(parseTargetValue(outputField.getTargetField(), outputField.getValue()));
                break;
            default:
                break;
        }

        return super.visit(outputField);
    }

    @Override
    public VisitorAction visit(PCell pCell) {
        pCell.setTargetCategory(parseTargetValue(pCell.getTargetCategory()));

        return super.visit(pCell);
    }

    @Override
    public VisitorAction visit(PCovCell pCovCell) {
        pCovCell.setTargetCategory(parseTargetValue(pCovCell.getTargetCategory()));

        return super.visit(pCovCell);
    }

    @Override
    public VisitorAction visit(PPCell pPCell) {
        pPCell.setTargetCategory(parseTargetValue(pPCell.getTargetCategory()));

        return super.visit(pPCell);
    }

    @Override
    public VisitorAction visit(RegressionTable regressionTable) {
        regressionTable.setTargetCategory(parseTargetValue(regressionTable.getTargetCategory()));

        return super.visit(regressionTable);
    }

    @Override
    public VisitorAction visit(RuleSet ruleSet) {
        ruleSet.setDefaultScore(parseTargetValue(ruleSet.getDefaultScore()));

        return super.visit(ruleSet);
    }

    /**
     * Converts the required {@code value} attribute of a score distribution.
     *
     * @throws MissingAttributeException if the value is absent
     */
    @Override
    public VisitorAction visit(ScoreDistribution scoreDistribution) {
        Object value = scoreDistribution.getValue();
        if (value == null) {
            throw new MissingAttributeException(scoreDistribution, PMMLAttributes.SCOREDISTRIBUTION_VALUE);
        }

        scoreDistribution.setValue(parseTargetValue(value));

        return super.visit(scoreDistribution);
    }

    /**
     * Converts the required {@code score} attribute of a simple rule.
     *
     * @throws MissingAttributeException if the score is absent
     */
    @Override
    public VisitorAction visit(SimpleRule simpleRule) {
        Object score = simpleRule.getScore();
        if (score == null) {
            throw new MissingAttributeException(simpleRule, org.dmg.pmml.rule_set.PMMLAttributes.SIMPLERULE_SCORE);
        }

        simpleRule.setScore(parseTargetValue(score));

        return super.visit(simpleRule);
    }

    @Override
    public VisitorAction visit(SupportVectorMachine supportVectorMachine) {
        supportVectorMachine.setTargetCategory(parseTargetValue(supportVectorMachine.getTargetCategory()));
        supportVectorMachine.setAlternateTargetCategory(parseTargetValue(supportVectorMachine.getAlternateTargetCategory()));

        return super.visit(supportVectorMachine);
    }

    @Override
    public VisitorAction visit(SupportVectorMachineModel supportVectorMachineModel) {
        supportVectorMachineModel.setAlternateBinaryTargetCategory(parseTargetValue(supportVectorMachineModel.getAlternateBinaryTargetCategory()));

        return super.visit(supportVectorMachineModel);
    }

    /**
     * Converts a target value using the data type of the field referenced by the enclosing
     * {@link Target}. The parent is assumed to be a {@code Target}; otherwise the cast fails.
     */
    @Override
    public VisitorAction visit(TargetValue targetValue) {
        Target target = (Target) getParent();

        targetValue.setValue(parseTargetValue(target.getField(), targetValue.getValue()));

        return super.visit(targetValue);
    }

    /**
     * Converts the required {@code value} attribute of a naive Bayes target value count.
     *
     * @throws MissingAttributeException if the value is absent
     */
    @Override
    public VisitorAction visit(TargetValueCount targetValueCount) {
        Object value = targetValueCount.getValue();
        if (value == null) {
            throw new MissingAttributeException(targetValueCount, org.dmg.pmml.naive_bayes.PMMLAttributes.TARGETVALUECOUNT_VALUE);
        }

        targetValueCount.setValue(parseTargetValue(value));

        return super.visit(targetValueCount);
    }

    /**
     * Converts the required {@code value} attribute of a naive Bayes target value stat.
     *
     * @throws MissingAttributeException if the value is absent
     */
    @Override
    public VisitorAction visit(TargetValueStat targetValueStat) {
        Object value = targetValueStat.getValue();
        if (value == null) {
            throw new MissingAttributeException(targetValueStat, org.dmg.pmml.naive_bayes.PMMLAttributes.TARGETVALUESTAT_VALUE);
        }

        targetValueStat.setValue(parseTargetValue(value));

        return super.visit(targetValueStat);
    }

    /**
     * Records target data types for a mining model. For the SELECT_FIRST, SELECT_ALL and
     * MODEL_CHAIN multiple-model methods, a placeholder map is pushed instead of the model's
     * own mining schema targets.
     */
    private void processMiningModel(MiningModel miningModel) {
        Segmentation segmentation = miningModel.getSegmentation();

        if (segmentation != null) {
            switch (segmentation.getMultipleModelMethod()) {
                case SELECT_FIRST:
                case SELECT_ALL:
                case MODEL_CHAIN:
                    // The placeholder map is non-empty but its only value is null.
                    // getDataType() stops at the first non-empty map, so nested models without
                    // targets of their own resolve to null instead of inheriting an outer type.
                    this.targetDataTypes.push(Collections.singletonMap(Evaluator.DEFAULT_TARGET_NAME, null));
                    this.dataType = null;
                    return;
                default:
                    break;
            }
        }

        processModel(miningModel);
    }

    /**
     * Collects the resolved data types of all PREDICTED/TARGET mining fields of the model. It
     * pushes them onto {@link #targetDataTypes} and recomputes {@link #dataType}.
     *
     * @throws MissingElementException if the model has no MiningSchema
     * @throws MissingAttributeException if a MiningField has no name
     */
    private void processModel(Model model) {
        MiningSchema miningSchema = model.getMiningSchema();
        if (miningSchema == null) {
            // "/" is an XPath step separator here, not the PMML divide function.
            throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(model.getClass()) + "/" + XPathUtil.formatElement(MiningSchema.class)), model);
        }

        // Insertion order follows the declaration order of fields in the mining schema.
        Map<FieldName, DataType> modelTargetDataTypes = new LinkedHashMap<>();

        if (miningSchema.hasMiningFields()) {
            for (MiningField miningField : miningSchema.getMiningFields()) {
                FieldName name = miningField.getName();
                if (name == null) {
                    throw new MissingAttributeException(miningField, PMMLAttributes.MININGFIELD_NAME);
                }

                switch (miningField.getUsageType()) {
                    case PREDICTED:
                    case TARGET:
                        modelTargetDataTypes.put(name, resolveTargetDataType(name));
                        break;
                    default:
                        break;
                }
            }
        }

        this.targetDataTypes.push(modelTargetDataTypes);
        this.dataType = getDataType();
    }

    /**
     * Finds the data type shared by all targets of the innermost model with a non-empty
     * target map.
     *
     * @return that data type, or {@code null} if the targets have more than one distinct type
     *     (a null entry counts as a distinct value) or no non-empty map exists
     */
    private DataType getDataType() {
        // ArrayDeque iterates from head (innermost model) to tail (outermost model).
        for (Map<FieldName, DataType> modelTargetDataTypes : this.targetDataTypes) {
            if (modelTargetDataTypes.isEmpty()) {
                continue;
            }

            Set<DataType> distinctDataTypes = new HashSet<>(modelTargetDataTypes.values());

            return (distinctDataTypes.size() == 1) ? Iterables.getOnlyElement(distinctDataTypes) : null;
        }

        return null;
    }

    /**
     * Converts a value to the current common target data type.
     *
     * @return the converted value, or {@code value} unchanged if it is null or no common target
     *     data type is known
     */
    private Object parseTargetValue(Object value) {
        if (value != null && this.dataType != null) {
            return TypeUtil.parseOrCast(this.dataType, value);
        }

        return value;
    }

    /**
     * Converts a value to the data type of the named target field of the innermost model.
     * Falls back to {@link #parseTargetValue(Object)} when {@code name} is null.
     *
     * @return the converted value, or {@code value} unchanged if it is null or the field's data
     *     type is unknown
     */
    private Object parseTargetValue(FieldName name, Object value) {
        if (name == null) {
            return parseTargetValue(value);
        }

        if (value == null) {
            return null;
        }

        DataType targetDataType = this.targetDataTypes.peekFirst().get(name);
        if (targetDataType != null) {
            return TypeUtil.parseOrCast(targetDataType, value);
        }

        return value;
    }
}
