package org.bigml.mimir.deepnet;

import com.fasterxml.jackson.databind.JsonNode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.bigml.mimir.deepnet.network.NetworkPredictor;
import org.bigml.mimir.utils.Json;
import org.junit.Assert;
import org.junit.Test;

/**
 * Regression tests that replay stored deepnet predictions.
 *
 * <p>Each fixture read by {@link #regressionTest(String)} is a JSON array whose entries carry a
 * {@code "model"} (handed to {@link NetworkPredictor#createNetwork}) and a {@code "validation"}
 * array of {@code "input"}/{@code "output"} pairs. Every stored output must be reproduced by the
 * predictor to within {@link #TOLERANCE} per element.
 */
public class RegressionTest {
    /** Largest absolute per-element difference accepted between expected and predicted values. */
    private static final double TOLERANCE = 1.0E-5d;

    /**
     * Normalizes an index that may be counted from the end of a sequence.
     *
     * @param index the index to normalize; may be {@code null}
     * @param size the length of the sequence the index refers to
     * @return {@code index} itself when it is {@code null} or non-negative, otherwise
     *     {@code size + index}
     */
    public static Integer positivize(Integer index, int size) {
        if (index == null || index.intValue() >= 0) {
            return index;
        }
        return Integer.valueOf(size + index.intValue());
    }

    /**
     * Returns a copy of {@code list} without the entries at {@code classIndex} and
     * {@code weightIndex}, leaving only the values passed to the predictor.
     *
     * @param list the full row of values; not modified
     * @param classIndex position to drop (negative values count from the end), or {@code null}
     * @param weightIndex position to drop (negative values count from the end), or {@code null}
     * @return a new list holding the remaining entries in their original order
     */
    public static List<Object> removeNonInput(
            List<Object> list, Integer classIndex, Integer weightIndex) {
        int size = list.size();
        Integer skipClass = positivize(classIndex, size);
        Integer skipWeight = positivize(weightIndex, size);

        List<Object> inputs = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            boolean isClass = skipClass != null && i == skipClass.intValue();
            boolean isWeight = skipWeight != null && i == skipWeight.intValue();
            if (!isClass && !isWeight) {
                inputs.add(list.get(i));
            }
        }
        return inputs;
    }

    /**
     * Checks whether two arrays have the same length and agree element-wise within a tolerance.
     *
     * <p>NaN only matches NaN: a NaN on one side and a number on the other is a mismatch.
     *
     * @param expected the reference values
     * @param actual the values under test
     * @param tolerance maximum allowed absolute difference per element
     * @return {@code true} if the lengths match and every pair is equal or within
     *     {@code tolerance}
     */
    public boolean aboutEquals(double[] expected, double[] actual, double tolerance) {
        if (expected.length != actual.length) {
            return false;
        }
        for (int i = 0; i < expected.length; i++) {
            double a = expected[i];
            double b = actual[i];
            // Test for "within tolerance" rather than "beyond tolerance": every comparison with
            // NaN is false, so the old "> tolerance" test silently accepted NaN predictions.
            // Double.compare still lets identical values (NaN/NaN, equal infinities) match.
            if (Double.compare(a, b) != 0 && !(Math.abs(a - b) <= tolerance)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Replays every stored validation case in a fixture and asserts each prediction matches.
     *
     * @param fixture name of the fixture passed to {@code Json.readObject} (e.g.
     *     {@code "deepnet_regression.json.gz"})
     */
    public void regressionTest(String fixture) {
        JsonNode entries = Json.readObject(fixture);
        for (int m = 0; m < entries.size(); m++) {
            JsonNode model = entries.get(m).get("model");
            JsonNode validation = entries.get(m).get("validation");
            NetworkPredictor predictor = NetworkPredictor.createNetwork(model);
            Integer classIndex = optionalIndex(model, "class_index");
            Integer weightIndex = optionalIndex(model, "weight_index");

            for (int v = 0; v < validation.size(); v++) {
                JsonNode example = validation.get(v);
                List<Object> row = Json.getObjectList(example.get("input"));
                double[] expected = Json.get1DArray(example.get("output"));
                double[] predicted =
                        predictor.predict(removeNonInput(row, classIndex, weightIndex));
                Assert.assertTrue(
                        "model " + m + ", case " + v + ": "
                                + Arrays.toString(expected) + " != " + Arrays.toString(predicted),
                        aboutEquals(expected, predicted, TOLERANCE));
            }
        }
    }

    /**
     * Reads an optional integer field from a model node.
     *
     * @param model the model JSON object
     * @param field the field name to read
     * @return the field's integer value, or {@code null} when the field is absent or JSON null
     */
    private static Integer optionalIndex(JsonNode model, String field) {
        // JsonNode.get returns Java null for a missing field, which the old code dereferenced.
        JsonNode node = model.get(field);
        if (node == null || node.isNull()) {
            return null;
        }
        return Integer.valueOf(node.asInt());
    }

    @Test
    public void regressionTest() {
        regressionTest("deepnet_regression.json.gz");
    }

    @Test
    public void legacyRegressionTest() {
        regressionTest("legacy_deepnet_regression.json.gz");
    }

    @Test
    public void learnedNetworkImageRegressionTest() {
        regressionTest("simple_image_regression.json.gz");
    }

    @Test
    public void searchRegressionTest() {
        regressionTest("search_regression.json.gz");
    }
}
