package org.bigml.mimir.deepnet;

import com.fasterxml.jackson.databind.JsonNode;
import java.util.List;
import java.util.Map;
import org.bigml.mimir.deepnet.network.NetworkPredictor;
import org.bigml.mimir.deepnet.network.TensorflowWrappedPredictor;
import org.bigml.mimir.utils.Json;
import org.bigml.mimir.utils.ResourceLoader;
import org.bigml.mimir.utils.TestUtils;
import org.junit.Test;

/* loaded from: input_file:org/bigml/mimir/deepnet/TensorflowPredictorTest.class */
/**
 * Cross-checks {@link TensorflowWrappedPredictor} against the reference
 * {@link NetworkPredictor} built from the same model JSON: for every input
 * row, the two predictors' outputs are compared within {@link #TOLERANCE}.
 */
public class TensorflowPredictorTest {
    /** Tolerance passed to {@code TestUtils.aboutEquals} when comparing outputs. */
    private static final double TOLERANCE = 1.0E-5d;

    /**
     * Runs every row of a data file through both predictors and compares the results.
     *
     * @param str  name of the data file handed to {@code TestUtils.readData}
     * @param str2 name of the model JSON file; used to build both predictors
     * @param str3 name of the bundle file streamed into the
     *             {@code TensorflowWrappedPredictor} constructor
     * @throws Exception if loading, predicting, comparing, or closing fails
     */
    public void bundleTest(String str, String str2, String str3) throws Exception {
        List<Map<String, Object>> rows = TestUtils.readData(str);
        JsonNode modelJson = Json.parseStream(ResourceLoader.streamForFile(str2));
        NetworkPredictor reference = NetworkPredictor.createNetwork(modelJson);
        // NOTE(review): the meaning of the trailing constructor argument (1) is
        // not visible from this file -- confirm against TensorflowWrappedPredictor.
        TensorflowWrappedPredictor wrapped =
                new TensorflowWrappedPredictor(modelJson, ResourceLoader.streamForFile(str3), 1);
        try {
            for (Map<String, Object> row : rows) {
                // NOTE(review): presumably aboutEquals fails the test on mismatch;
                // if it merely returns a boolean, this loop verifies nothing -- confirm.
                TestUtils.aboutEquals(
                        NetworkPredictor.toProbabilities(reference.predictions(row)),
                        wrapped.predict(row),
                        TOLERANCE);
            }
        } finally {
            // Always release the wrapped predictor, even when a comparison or
            // prediction throws, so a failing test does not leak it.
            wrapped.close();
        }
    }

    /** Compares both predictors on the iris fixtures. */
    @Test
    public void irisTest() throws Exception {
        bundleTest("iris_data.json.gz", "iris_model.json.gz", "iris_model.smbundle.gz");
    }

    /** Compares both predictors on the titanic fixtures. */
    @Test
    public void titanicTest() throws Exception {
        bundleTest("titanic_data.json.gz", "titanic_model.json.gz", "titanic_model.smbundle.gz");
    }

    /** Compares both predictors on the titanic data using the "titanic_search" model fixtures. */
    @Test
    public void titanicSearchTest() throws Exception {
        bundleTest("titanic_data.json.gz", "titanic_search_model.json.gz", "titanic_search_model.smbundle.gz");
    }
}
