package org.bigml.mimir.deepnet;

import com.fasterxml.jackson.databind.JsonNode;
import org.bigml.mimir.deepnet.layers.twod.Flatten;
import org.bigml.mimir.deepnet.layers.twod.Layer2D;
import org.bigml.mimir.math.Matrices;
import org.bigml.mimir.utils.Json;
import org.bigml.mimir.utils.TestUtils;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/bigml/mimir/deepnet/LayersTest.class */
public class LayersTest {
    /**
     * Checks a single 2D-layer test case.
     *
     * <p>Builds the layer described by the case's {@code "parameters"} node, initializes it
     * with the shape of the {@code "input"} array, propagates the input, and asserts that
     * the result is approximately equal to the {@code "output"} array.
     *
     * @param jsonNode one test case with {@code "parameters"}, {@code "input"} and
     *     {@code "output"} fields
     */
    public void layerTest2D(JsonNode jsonNode) {
        JsonNode params = jsonNode.get("parameters");
        double[][][] input = Json.get3DArray(jsonNode.get("input"));
        double[][][] expected = Json.get3DArray(jsonNode.get("output"));

        Layer2D layer = Layer2D.makeLayer(params);
        layer.initialize(Matrices.shape(input));

        // The serialized parameters identify the failing case in the assertion message.
        Assert.assertTrue(params.toString(), TestUtils.aboutEquals(expected, layer.propagateArray(input)));
    }

    /**
     * Checks a single flatten test case.
     *
     * <p>Asserts that the case's parameters declare {@code "type": "flatten"}. It then
     * asserts that {@link Flatten#flatten} applied to the 3D {@code "input"} array is
     * approximately equal to the 1D {@code "output"} array.
     *
     * @param jsonNode one test case with {@code "parameters"}, {@code "input"} and
     *     {@code "output"} fields
     */
    public void layerTestFlatten(JsonNode jsonNode) {
        JsonNode params = jsonNode.get("parameters");
        // path() yields a missing node instead of null, so a case without "type" fails
        // this assertion with a readable message rather than throwing an NPE.
        Assert.assertEquals(params.toString(), "flatten", params.path("type").asText());

        double[][][] input = Json.get3DArray(jsonNode.get("input"));
        // Use toString(), not asText(): asText() is empty for object nodes, which left
        // failures without any diagnostic message.
        Assert.assertTrue(
                params.toString(),
                TestUtils.aboutEquals(Json.get1DArray(jsonNode.get("output")), Flatten.flatten(input)));
    }

    /**
     * Runs every test case in the given fixture file.
     *
     * <p>The fixture is read with {@link Json#readObject} and indexed as a sequence of
     * cases. Fixture names containing {@code "flatten"} are checked with
     * {@link #layerTestFlatten}. All other fixtures are checked with {@link #layerTest2D}.
     *
     * @param str fixture path/name passed to {@link Json#readObject}
     */
    public void runTestAtPath(String str) {
        JsonNode testCases = Json.readObject(str);
        // Guard against a fixture that was not read or holds no cases; either would
        // otherwise make the calling test pass without checking anything.
        Assert.assertNotNull("Could not read test fixture " + str, testCases);
        Assert.assertTrue("No test cases found in " + str, testCases.size() > 0);

        // The fixture name, not the individual case, selects the checker.
        boolean isFlatten = str.contains("flatten");
        for (int i = 0; i < testCases.size(); i++) {
            JsonNode testCase = testCases.get(i);
            if (isFlatten) {
                layerTestFlatten(testCase);
            } else {
                layerTest2D(testCase);
            }
        }
    }

    /** Padding layer cases from {@code test_pad.json}. */
    @Test
    public void paddingTest() {
        runTestAtPath("test_pad.json");
    }

    /** Convolution layer cases from {@code test_convolution.json}. */
    @Test
    public void convolutionTest() {
        runTestAtPath("test_convolution.json");
    }

    /** Separable convolution layer cases from {@code test_separable_convolution.json}. */
    @Test
    public void seperableConvolutionTest() {
        runTestAtPath("test_separable_convolution.json");
    }

    /** Depthwise convolution layer cases from {@code test_depthwise_convolution.json}. */
    @Test
    public void depthwiseConvolutionTest() {
        runTestAtPath("test_depthwise_convolution.json");
    }

    /** Batch normalization layer cases from {@code test_batchnorm.json}. */
    @Test
    public void batchNormalizationTest() {
        runTestAtPath("test_batchnorm.json");
    }

    /** Pooling layer cases from {@code test_pool.json}. */
    @Test
    public void poolingTest() {
        runTestAtPath("test_pool.json");
    }

    /** Activation layer cases from {@code test_activate.json}. */
    @Test
    public void activationTest() {
        runTestAtPath("test_activate.json");
    }

    /** Flatten cases from {@code test_flatten.json}, routed to {@link #layerTestFlatten}. */
    @Test
    public void flattenTest() {
        runTestAtPath("test_flatten.json");
    }
}
