package org.bigml.mimir.forest;

import com.fasterxml.jackson.databind.JsonNode;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.bigml.mimir.Predictor;
import org.bigml.mimir.math.Matrices;
import org.bigml.mimir.math.Vectors;
import org.bigml.mimir.utils.fields.FieldCollection;

/* loaded from: input_file:org/bigml/mimir/forest/ShapForest.class */
/**
 * A tree ensemble that can both predict and explain its predictions with per-field SHAP values.
 *
 * <p>The ensemble is treated as a supervised forest when the first tree's root carries an
 * {@code "objective"} entry, and as an anomaly forest otherwise. Predictions average the per-tree
 * outputs; for anomaly forests the averaged value {@code x} is reported as {@code 2^-x}.
 */
public class ShapForest implements Predictor {
    /**
     * Euler-Mascheroni constant; {@link #findMeanDepth} uses {@code ln(n - 1) + DEPTH_FACTOR} as
     * the harmonic-number approximation {@code H(n - 1)}.
     */
    private static final double DEPTH_FACTOR = 0.5772156649d;

    private static final long serialVersionUID = 1L;

    /** True when the first tree's root has no "objective" entry (anomaly forest). */
    private final boolean _isAnomalyModel;
    /** Length of the per-tree expectation / prediction vectors (taken from the first tree). */
    private final int _nOutputs;
    private final String _resourceId;
    /** Built once from the maximum tree depth and handed to Path.shapForNode for every tree. */
    private final double[][] _startPath;
    private final FieldCollection _fields;
    private final ShapNode[] _trees;

    /**
     * Builds the forest from a BigML model JSON document.
     *
     * <p>Accepts either the wrapped layout ({@code {"object": {"model": ...}}}) or the inner
     * object itself ({@code {"model": ...}}).
     *
     * @param jsonNode the model JSON
     * @throws IllegalArgumentException if no "model" entry is found in either layout, or if the
     *     model has no trees
     */
    public ShapForest(JsonNode jsonNode) {
        JsonNode resource = jsonNode;
        if (jsonNode.has("object") && jsonNode.get("object").has("model")) {
            resource = jsonNode.get("object");
        } else if (!jsonNode.has("model")) {
            throw new IllegalArgumentException("Bad JSON");
        }

        _resourceId = Predictor.createResourceId(resource);
        _fields = FieldCollection.makeShapForestCollection(resource);

        JsonNode treeNodes = resource.get("model").get("trees");
        // Without at least one tree we could neither detect the model type nor size the
        // outputs; fail with a clear message instead of an NPE / index error further down.
        if (treeNodes == null || treeNodes.size() == 0) {
            throw new IllegalArgumentException("Bad JSON: model has no trees");
        }

        // The first tree's root decides the model type.
        double meanDepth;
        if (treeNodes.get(0).get("root").has("objective")) {
            _isAnomalyModel = false;
            meanDepth = 1.0d;
        } else {
            _isAnomalyModel = true;
            meanDepth = findMeanDepth(resource);
        }

        _trees = new ShapNode[treeNodes.size()];
        int maxDepth = -1;
        for (int t = 0; t < _trees.length; t++) {
            _trees[t] = ShapNode.createNode(
                    treeNodes.get(t).get("root"), null, _fields, 1, meanDepth);
            maxDepth = Math.max(maxDepth, computeExpectation(_trees[t]));
        }
        _nOutputs = _trees[0].expectation.length;
        _startPath = Path.newPath(null, maxDepth);
    }

    @Override // org.bigml.mimir.Predictor
    public String getResourceId() {
        return _resourceId;
    }

    /** Returns the field collection used to encode inputs for this forest. */
    public FieldCollection getFields() {
        return _fields;
    }

    /**
     * Predicts from an already-encoded input row.
     *
     * <p>Averages, over all trees, the objective of the node each tree routes the input to. For
     * anomaly forests the first component of the average is converted to {@code 2^-x}.
     *
     * @param input encoded values, as produced by FieldCollection.toDoubles
     * @return a new array of length {@code _nOutputs}
     */
    @Override // org.bigml.mimir.Predictor
    public double[] predict(double[] input) {
        double[] result = new double[_nOutputs];
        for (ShapNode tree : _trees) {
            Vectors.addInPlace(result, nodePredict(tree, input));
        }
        Vectors.divideInPlace(result, _trees.length);
        if (_isAnomalyModel) {
            result[0] = Math.pow(2.0d, -result[0]);
        }
        return result;
    }

    /** Encodes {@code input} with this forest's fields, then delegates to predict(double[]). */
    @Override // org.bigml.mimir.Predictor
    public double[] predict(Map<String, Object> input) {
        return predict(_fields.toDoubles(input));
    }

    /** Encodes {@code input} with this forest's fields, then delegates to predict(double[]). */
    @Override // org.bigml.mimir.Predictor
    public double[] predict(List<Object> input) {
        return predict(_fields.toDoubles(input));
    }

    /**
     * Computes SHAP explanations for one input.
     *
     * @param input input values, in whatever form FieldCollection.toDoubles(Map) accepts
     * @return one list per output. Each list starts with a Double total (base value plus all
     *     contributions, mapped through {@code 2^-x} for anomaly forests), followed by the
     *     sorted ShapleyValue contributions per field
     */
    public List<List<Object>> explain(Map<String, Object> input) {
        double[] encoded = _fields.toDoubles(input);
        // One row per encoded column, plus a final row accumulating the base (expected) value.
        double[][] contributions = new double[encoded.length + 1][_nOutputs];
        for (ShapNode tree : _trees) {
            shapForTree(tree, encoded, contributions);
        }
        Matrices.divideInPlace(contributions, _trees.length);
        return formatShapValues(contributions);
    }

    /**
     * Routes {@code input} down one tree and returns the objective of the node where it stops.
     *
     * <p>A leaf always stops the walk. A non-leaf stops it when nextIndex returns anything other
     * than 0 (go left) or 1 (go right).
     */
    private double[] nodePredict(ShapNode root, double[] input) {
        ShapNode node = root;
        while (!node.isLeaf) {
            int branch = node.nextIndex(input);
            if (branch == 0) {
                node = node.left;
            } else if (branch == 1) {
                node = node.right;
            } else {
                // Negative: the split cannot route this input (presumably a missing value --
                // TODO confirm in ShapNode.nextIndex). Any other value used to spin forever on
                // the same node; stopping here is the only terminating choice.
                break;
            }
        }
        return node.objective;
    }

    /**
     * Accumulates one tree's contribution into {@code contributions}.
     *
     * <p>The tree's expectation is added to the last (base value) row. The per-column
     * attribution is then delegated to Path.shapForNode.
     */
    private void shapForTree(ShapNode tree, double[] input, double[][] contributions) {
        double[] baseRow = contributions[contributions.length - 1];
        for (int k = 0; k < baseRow.length; k++) {
            baseRow[k] += tree.expectation[k];
        }
        Path.shapForNode(input, contributions, tree, 0, _startPath, 1.0d, 1.0d, -1);
    }

    /** Returns {@code base} plus the values of {@code values}, summed in list order. */
    private static double shapleySum(List<ShapleyValue> values, double base) {
        double total = base;
        for (ShapleyValue value : values) {
            total += value.getValue();
        }
        return total;
    }

    /**
     * Maps additive contributions (in the pre-{@code 2^-x} space) to anomaly-score deltas.
     *
     * <p>The first element is the final score {@code 2^-(base + sum)}. Each following
     * ShapleyValue holds the change in score caused by adding that field's contribution, applied
     * cumulatively in list order. The deltas therefore telescope from {@code 2^-base} to the
     * final score.
     */
    private static List<Object> anomalize(List<ShapleyValue> values, double base) {
        List<Object> result = new ArrayList<>(values.size() + 1);
        result.add(Math.pow(2.0d, -shapleySum(values, base)));
        double running = base;
        for (ShapleyValue value : values) {
            double next = running + value.getValue();
            result.add(new ShapleyValue(
                    value.getField(), Math.pow(2.0d, -next) - Math.pow(2.0d, -running)));
            running = next;
        }
        return result;
    }

    /**
     * Converts the raw contribution matrix into the per-output lists returned by explain.
     *
     * <p>Rows whose columns map to the same field id (via FieldCollection.getId) are summed. The
     * per-field values are sorted by ShapleyValue's natural ordering, and the matrix's last row
     * supplies the base value.
     */
    private List<List<Object>> formatShapValues(double[][] contributions) {
        int baseRow = contributions.length - 1;
        List<List<Object>> explanations = new ArrayList<>(_nOutputs);
        for (int out = 0; out < _nOutputs; out++) {
            Map<String, Double> totals = new HashMap<>();
            for (int col = 0; col < baseRow; col++) {
                String id = _fields.getId(col);
                totals.put(id, contributions[col][out] + totals.getOrDefault(id, 0.0d));
            }

            List<ShapleyValue> values = new ArrayList<>(totals.size());
            for (Map.Entry<String, Double> entry : totals.entrySet()) {
                values.add(new ShapleyValue(entry.getKey(), entry.getValue()));
            }
            Collections.sort(values);

            double base = contributions[baseRow][out];
            List<Object> row;
            if (_isAnomalyModel) {
                row = anomalize(values, base);
            } else {
                row = new ArrayList<>(values.size() + 1);
                row.add(shapleySum(values, base));
                row.addAll(values);
            }
            explanations.add(row);
        }
        return explanations;
    }

    /**
     * Returns the depth handed to ShapNode.createNode for anomaly trees.
     *
     * <p>The value is {@code min(mean_depth, c(n))}, where {@code mean_depth} is the model's
     * "mean_depth" entry (or +infinity if absent), {@code n = sample_size}, and
     * {@code c(n) = 2 * (ln(n - 1) + DEPTH_FACTOR) - 2 * (n - 1) / n} is the Isolation Forest
     * normalizer.
     *
     * <p>NOTE(review): {@code sample_size <= 1} yields -infinity or NaN here. This is left
     * unchanged; confirm such models cannot occur.
     */
    private static double findMeanDepth(JsonNode resource) {
        int sampleSize = resource.get("sample_size").asInt();
        JsonNode model = resource.get("model");
        double meanDepth = model.has("mean_depth")
                ? model.get("mean_depth").asDouble()
                : Double.POSITIVE_INFINITY;
        double expected = 2.0d * ((DEPTH_FACTOR + Math.log(sampleSize - 1))
                - ((sampleSize - 1.0d) / sampleSize));
        return Math.min(meanDepth, expected);
    }

    /**
     * Recursively sets {@code expectation} on every node of the subtree.
     *
     * <p>A leaf takes its objective. A multipredicate node copies its left child's expectation.
     * A binary split combines both children's expectations with Vectors.add, weighted by the
     * children's weights; when both weights are zero, 0.5 / 0.5 is used instead.
     *
     * @return the height of the subtree in edges (0 for a leaf)
     */
    private static int computeExpectation(ShapNode node) {
        if (node.isLeaf) {
            node.expectation = node.objective;
            return 0;
        }
        if (node.isMultipredicate) {
            int height = computeExpectation(node.left);
            node.expectation = node.left.expectation;
            return height + 1;
        }
        int leftHeight = computeExpectation(node.left);
        int rightHeight = computeExpectation(node.right);
        double leftWeight = node.left.weight;
        double rightWeight = node.right.weight;
        if (leftWeight == 0.0d && rightWeight == 0.0d) {
            leftWeight = 0.5d;
            rightWeight = 0.5d;
        }
        node.expectation = Vectors.add(
                node.left.expectation, node.right.expectation, leftWeight, rightWeight);
        return Math.max(leftHeight, rightHeight) + 1;
    }
}
