package org.bigml.mimir.forest;

import com.fasterxml.jackson.databind.JsonNode;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.bigml.binding.LocalAnomaly;
import org.bigml.mimir.Predictor;
import org.bigml.mimir.utils.Json;
import org.bigml.mimir.utils.ResourceLoader;
import org.bigml.mimir.utils.TestUtils;
import org.bigml.mimir.utils.fields.Field;
import org.bigml.mimir.utils.fields.FieldCollection;
import org.bigml.mimir.utils.fields.NanableNumericField;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/bigml/mimir/forest/ShapForestTest.class */
public class ShapForestTest {
    /** Absolute tolerance used when comparing Mimir outputs against reference values. */
    private static final double TOLERANCE = 1.0E-8d;

    /**
     * Converts one input row into the json-simple representation consumed by the BigML
     * {@link LocalAnomaly} bindings.
     *
     * <p>Every field in {@code fieldCollection} gets an entry keyed by the field's name. Values that
     * are JSON {@code null} or absent from {@code jsonNode} map to {@code null}; values of
     * {@link NanableNumericField}s are parsed into {@link Double}s; all other values are passed as
     * their text form.
     *
     * @param jsonNode the input row, looked up by field name
     * @param fieldCollection the fields of the model the row will be scored against
     * @return a new {@link JSONObject} with one entry per field
     */
    @SuppressWarnings("unchecked") // json-simple's JSONObject extends a raw HashMap
    public JSONObject toJsonObject(JsonNode jsonNode, FieldCollection fieldCollection) {
        JSONObject result = new JSONObject();
        Iterator<Field> it = fieldCollection.iterator();
        while (it.hasNext()) {
            Field field = it.next();
            String name = field._name;
            JsonNode value = jsonNode.get(name);
            // JsonNode.get returns Java null for a missing key; treat that like an explicit JSON
            // null rather than failing with a NullPointerException.
            if (value == null || value.isNull()) {
                result.put(name, null);
            } else if (field instanceof NanableNumericField) {
                result.put(name, Double.valueOf(Double.parseDouble(value.asText())));
            } else {
                result.put(name, value.asText());
            }
        }
        return result;
    }

    /**
     * Checks that a Mimir {@link ShapForest} reproduces the anomaly scores of the BigML bindings'
     * {@link LocalAnomaly} for the same model, and prints per-prediction timings for both.
     *
     * <p>Each row of the test file is either the input itself or an object whose {@code "point"}
     * member holds the input. When a row also has a {@code "prediction"} member, the first element of
     * that array is the reference score; otherwise {@link LocalAnomaly#score} is the reference.
     *
     * @param str resource name of the anomaly model JSON
     * @param str2 resource name of the test rows JSON array
     * @throws Exception if a resource cannot be loaded or parsed
     */
    public void anomalyTest(String str, String str2) throws Exception {
        // NOTE(review): the stream returned by ResourceLoader.streamForFile is never closed here;
        // confirm whether Predictor.predictorFromStream closes it.
        ShapForest forest = (ShapForest) Predictor.predictorFromStream(ResourceLoader.streamForFile(str));
        JsonNode rows = Json.readObject(str2);
        LocalAnomaly bindings = new LocalAnomaly((JSONObject) new JSONParser().parse(Json.readObject(str).toString()));
        FieldCollection fields = forest.getFields();

        int count = rows.size();
        JSONObject[] bindingInputs = new JSONObject[count];
        List<Map<String, Object>> mimirInputs = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            JsonNode point = rows.get(i);
            if (point.has("point")) {
                point = point.get("point");
            }
            mimirInputs.add(TestUtils.makeInstance(point));
            bindingInputs[i] = toJsonObject(point, fields);
        }

        long start = System.nanoTime();
        for (JSONObject input : bindingInputs) {
            bindings.score(input);
        }
        double bindingsMsec = nanosToMillis(System.nanoTime() - start) / count;

        start = System.nanoTime();
        for (Map<String, Object> input : mimirInputs) {
            forest.predict(input);
        }
        double mimirMsec = nanosToMillis(System.nanoTime() - start) / count;

        System.out.println(String.format("BigML Bindings: %.2f msec/prediction", Double.valueOf(bindingsMsec)) + "\n" + String.format("Mimir: %.2f msec/prediction", Double.valueOf(mimirMsec)));

        for (int i = 0; i < count; i++) {
            double actual = forest.predict(mimirInputs.get(i))[0];
            JsonNode row = rows.get(i);
            double expected = row.has("prediction") ? Json.get1DArray(row.get("prediction"))[0] : bindings.score(bindingInputs[i]);
            Assert.assertEquals("Anomaly score mismatch for row " + i, expected, actual, TOLERANCE);
        }
    }

    /**
     * Checks {@link ShapForest#explain} against reference Shapley explanations and prints the
     * average time per explanation.
     *
     * <p>At most {@code i} rows of the test file are processed. For each row, the first element of
     * the first explanation must equal the forest's prediction. When the row has an
     * {@code "explanation"} member, that leading value and every following (field, value) pair are
     * compared to the reference, in order.
     *
     * @param str resource name of the anomaly model JSON
     * @param str2 resource name of the test rows JSON array
     * @param i maximum number of rows to check
     * @return the forest built from {@code str}, for further assertions by the caller
     */
    public ShapForest explanationTest(String str, String str2, int i) {
        ShapForest forest = new ShapForest(Json.readObject(str));
        JsonNode rows = Json.readObject(str2);
        int limit = Math.min(i, rows.size());

        long start = System.nanoTime();
        for (int r = 0; r < limit; r++) {
            JsonNode row = rows.get(r);
            Map<String, Object> instance = TestUtils.makeInstance(row.has("point") ? row.get("point") : row);
            List<Object> explanation = forest.explain(instance).get(0);
            double explainedScore = ((Double) explanation.get(0)).doubleValue();

            if (row.has("explanation")) {
                JsonNode reference = row.get("explanation").get(0);
                Assert.assertEquals(reference.get(0).asDouble(), explainedScore, TOLERANCE);
                for (int k = 1; k < reference.size(); k++) {
                    JsonNode pair = reference.get(k);
                    ShapleyValue actual = (ShapleyValue) explanation.get(k);
                    Assert.assertEquals(pair.get(0).asText(), actual.getField());
                    Assert.assertEquals(pair.get(1).asDouble(), actual.getValue(), TOLERANCE);
                }
            }
            // The explanation's leading value must agree with a plain prediction.
            Assert.assertEquals(explainedScore, forest.predict(instance)[0], TOLERANCE);
        }
        System.out.println(String.format("Shapley: %.2f msec/explanation", Double.valueOf(nanosToMillis(System.nanoTime() - start) / limit)));
        return forest;
    }

    /** Converts a {@link System#nanoTime()} difference into fractional milliseconds. */
    private static double nanosToMillis(long nanos) {
        return nanos / 1.0E6d;
    }

    @Test
    public void irisAnomalyTest() throws Exception {
        anomalyTest("iris_anomaly.json.gz", "iris_anomaly_test.json.gz");
    }

    @Test
    public void lendingClubTest() throws Exception {
        anomalyTest("lending_club_anomaly.json.gz", "lending_club_anomaly_test.json.gz");
    }

    @Test
    public void flightTest() throws Exception {
        anomalyTest("flightmodel.json.gz", "flightdata.json.gz");
    }

    @Test
    public void lendingClubExplanationTest() {
        ShapForest forest = explanationTest("lending_club_anomaly.json.gz", "lending_club_explanation_test.json.gz", 256);
        Assert.assertEquals("loan_amnt", forest.getFields().getName("000002"));
    }

    @Test
    public void flightExplanationTest() throws Exception {
        explanationTest("flightmodel.json.gz", "flightdata.json.gz", 32);
    }
}
