package org.bigml.mimir.image;

import com.fasterxml.jackson.databind.JsonNode;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.bigml.mimir.cache.CNNCache;
import org.bigml.mimir.deepnet.network.Network;
import org.bigml.mimir.deepnet.network.yolo.BoundingBoxes;
import org.bigml.mimir.deepnet.network.yolo.TensorflowBoundingBoxPredictor;
import org.bigml.mimir.image.featurize.TensorflowFeaturizer;
import org.bigml.mimir.utils.Json;
import org.bigml.mimir.utils.ResourceLoader;
import org.bigml.mimir.utils.TestUtils;
import org.bigml.mimir.utils.fields.FieldCollection;
import org.bigml.mimir.utils.fields.ImageField;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/bigml/mimir/image/ImageTest.class */
public class ImageTest {
    public static JsonNode readoutLayers(String str) {
        return Json.parseObject(ResourceLoader.stringForFile(str + "_pretrained_" + CNNCache.metadataForNetwork(str).getVersion() + ".json.gz"));
    }

    public static Network createPretrained(String str) {
        ImageField imageField = new ImageField("000000", "image", new TensorflowFeaturizer(str, 2));
        ArrayList arrayList = new ArrayList();
        arrayList.add(imageField);
        return Network.createNetwork(readoutLayers(str), new FieldCollection(arrayList));
    }

    @Test
    public void loadDog() {
        double[][][] objectTo3DArray = new ChannelwiseCenteringReader(new int[]{224, 224, 3}).objectTo3DArray(TestUtils.getTestFile("dog.jpg"));
        Assert.assertTrue(-7.938999999999993d == objectTo3DArray[2][12][0]);
        Assert.assertTrue(-74.779d == objectTo3DArray[12][112][1]);
        Assert.assertTrue(-103.68d == objectTo3DArray[112][112][2]);
    }

    public void testImages(String str, double d, double d2) {
        Network createPretrained = createPretrained(str);
        File testFile = TestUtils.getTestFile("dog.jpg");
        File testFile2 = TestUtils.getTestFile("bus.jpg");
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        hashMap.put("000000", testFile);
        hashMap2.put("000000", testFile2);
        double[] predict = createPretrained.predict(hashMap);
        double[] predict2 = createPretrained.predict(hashMap2);
        long currentTimeMillis = System.currentTimeMillis();
        double[] predict3 = createPretrained.predict(hashMap);
        double[] predict4 = createPretrained.predict(hashMap2);
        long currentTimeMillis2 = System.currentTimeMillis() - currentTimeMillis;
        Assert.assertTrue(TestUtils.aboutEquals(predict, predict3));
        Assert.assertTrue(TestUtils.aboutEquals(predict2, predict4));
        System.out.println("CPU time: " + (currentTimeMillis2 / 2000.0d) + " sec. / image");
        for (int i = 0; i < predict.length; i++) {
            if (i != 254) {
                Assert.assertTrue(predict[i] + " > 0.03", predict[i] < 0.03d);
            } else {
                double d3 = predict[i];
                Assert.assertTrue(d3 + " < " + d3, predict[i] > d);
            }
        }
        for (int i2 = 0; i2 < predict2.length; i2++) {
            if (i2 != 779) {
                Assert.assertTrue(predict2[i2], predict2[i2] < 0.01d);
            } else {
                Assert.assertTrue(predict2[i2], predict2[i2] > d2);
            }
        }
    }

    @Test
    public void testMobileNet() {
        testImages("mobilenet", 0.97d, 0.99d);
    }

    @Test
    public void testMobileNetV2() {
        testImages("mobilenetv2", 0.87d, 0.98d);
    }

    @Test
    public void testResnet18() {
        testImages("resnet18", 0.95d, 0.99d);
    }

    @Test
    public void testResnet50() {
        testImages("resnet50", 0.99d, 0.99d);
    }

    @Test
    public void testXception() {
        testImages("xception", 0.87d, 0.89d);
    }

    public void testBoundingBox(String str, int i, Set<String> set) throws Exception {
        TensorflowBoundingBoxPredictor tensorflowBoundingBoxPredictor = new TensorflowBoundingBoxPredictor(CNNCache.getFeaturizer(str));
        File testFile = TestUtils.getTestFile("pizza_people.jpg");
        tensorflowBoundingBoxPredictor.predict(testFile);
        long currentTimeMillis = System.currentTimeMillis();
        BoundingBoxes predict = tensorflowBoundingBoxPredictor.predict(testFile);
        System.out.println("CPU time: " + ((System.currentTimeMillis() - currentTimeMillis) / 1000.0d) + " sec. / image");
        HashSet hashSet = new HashSet(predict.classList());
        Assert.assertTrue("Boxes: " + predict.size(), predict.size() == i);
        Assert.assertTrue("Classes: " + hashSet.toString(), hashSet.equals(set));
        tensorflowBoundingBoxPredictor.close();
    }

    @Test
    public void testYolo() throws Exception {
        HashSet hashSet = new HashSet();
        hashSet.add("person");
        hashSet.add("cup");
        hashSet.add("pizza");
        hashSet.add("dining table");
        testBoundingBox("yolov4", 8, hashSet);
    }

    @Test
    public void testTinyYolo() throws Exception {
        HashSet hashSet = new HashSet();
        hashSet.add("person");
        hashSet.add("pizza");
        testBoundingBox("tinyyolov4", 4, hashSet);
    }

    @Test
    public void shapeDetectorTest() throws Exception {
        File testFile = TestUtils.getTestFile("circle_object.png");
        File testFile2 = TestUtils.getTestFile("square_object.png");
        TensorflowBoundingBoxPredictor tensorflowBoundingBoxPredictor = new TensorflowBoundingBoxPredictor(ResourceLoader.streamForFile("shape_detector.smbundle.gz"), 1, 0.5d, 0.5d);
        BoundingBoxes predict = tensorflowBoundingBoxPredictor.predict(testFile);
        Assert.assertTrue(predict.size() == 1);
        Assert.assertTrue(predict.get(0).getClassName().equals("circle"));
        Assert.assertTrue(predict.get(0).getScore() > 0.8d);
        BoundingBoxes predict2 = tensorflowBoundingBoxPredictor.predict(testFile2);
        Assert.assertTrue(predict2.size() == 1);
        Assert.assertTrue(predict2.get(0).getClassName().equals("square"));
        Assert.assertTrue(predict2.get(0).getScore() > 0.8d);
        ArrayList arrayList = new ArrayList();
        arrayList.add(null);
        Assert.assertTrue(tensorflowBoundingBoxPredictor.predict((List<Object>) arrayList).size() == 0);
        tensorflowBoundingBoxPredictor.close();
        TensorflowBoundingBoxPredictor tensorflowBoundingBoxPredictor2 = new TensorflowBoundingBoxPredictor(ResourceLoader.streamForFile("shape_detector.smbundle.gz"), 1, 0.0d, 1.0d);
        Assert.assertTrue(tensorflowBoundingBoxPredictor2.predict(testFile).size() == 60);
        tensorflowBoundingBoxPredictor2.close();
    }

    @Test
    public void testNetworkList() {
        Map<String, Integer> imageExtractors = CNNCache.imageExtractors();
        Assert.assertTrue(imageExtractors.size() > 4);
        for (String str : imageExtractors.keySet()) {
            Assert.assertFalse(str.toLowerCase().contains("yolo"));
            Assert.assertTrue(imageExtractors.get(str).intValue() > 1);
        }
    }
}
