package org.bigml.mimir.deepnet.layers.twod;

import com.fasterxml.jackson.databind.JsonNode;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.bigml.mimir.deepnet.layers.Activation;
import org.bigml.mimir.deepnet.layers.BatchNormalize;
import org.bigml.mimir.math.gpu.GPURunnable;
import org.bigml.mimir.utils.Json;

/* loaded from: input_file:org/bigml/mimir/deepnet/layers/twod/Layer2D.class */
/**
 * A deepnet layer operating on two-dimensional inputs.
 *
 * <p>The static methods build layer stacks from their JSON descriptions and run
 * them; the abstract methods are implemented by each concrete layer type.
 */
public interface Layer2D extends Serializable {
    /**
     * Builds the layer stack described by a JSON array of layer descriptions.
     *
     * <p>Whenever a {@code Convolution2D} followed by a {@code BatchNormalize} is
     * followed by an {@code Activation}, the three are fused into one
     * {@code ConvolutionBlock2D}. Each layer added to the stack receives the
     * position of its JSON element as its index (for a fused block, the position
     * of the activation).
     *
     * @param jsonNode JSON array of layer descriptions; {@code null} yields an empty stack
     * @param initialConvolution if {@code true}, convolution layers are built as
     *     {@code InitialConvolution2D} until the first {@code AbstractConvolution2D}
     *     has been added to the stack
     * @return the layers, in order
     */
    static Layer2D[] makeLayers(JsonNode jsonNode, boolean initialConvolution) {
        if (jsonNode == null) {
            return new Layer2D[0];
        }
        List<Layer2D> layers = new ArrayList<>(jsonNode.size());
        boolean useInitial = initialConvolution;
        for (int i = 0; i < jsonNode.size(); i++) {
            Layer2D layer = makeLayer(jsonNode.get(i), layers, useInitial);
            int size = layers.size();
            if (size > 1
                    && layer instanceof Activation
                    && layers.get(size - 1) instanceof BatchNormalize
                    && layers.get(size - 2) instanceof Convolution2D) {
                // Remove from the end first: removing index size - 2 first shifts the
                // batch normalization down, and remove(size - 1) would then be out of bounds.
                BatchNormalize batchNorm = (BatchNormalize) layers.remove(size - 1);
                Convolution2D convolution = (Convolution2D) layers.remove(size - 2);
                layer = new ConvolutionBlock2D(convolution, batchNorm, (Activation) layer);
            }
            layer.setIndex(i);
            layers.add(layer);
            if (layer instanceof AbstractConvolution2D) {
                useInitial = false;
            }
        }
        return toArray(layers);
    }

    /**
     * Copies a list of layers into a new array.
     *
     * @param list the layers; {@code null} is treated as an empty list
     * @return a new array holding the list's elements in order
     */
    static Layer2D[] toArray(List<Layer2D> list) {
        if (list == null) {
            return new Layer2D[0];
        }
        return list.toArray(new Layer2D[list.size()]);
    }

    /**
     * Builds a single layer from its JSON description, dispatching on its
     * {@code "type"} field.
     *
     * @param jsonNode the layer description
     * @param list the layers built so far (used by {@code "concatenate"} to resolve
     *     its inputs); may be {@code null}
     * @param z if {@code true}, a {@code "convolution_2d"} layer is built as an
     *     {@code InitialConvolution2D}
     * @return the new layer
     * @throws IllegalArgumentException if the type is missing or unknown
     */
    static Layer2D makeLayer(JsonNode jsonNode, List<Layer2D> list, boolean z) {
        JsonNode typeNode = jsonNode.get("type");
        if (typeNode == null) {
            throw new IllegalArgumentException("Layer description has no 'type' field");
        }
        String type = typeNode.asText();
        if (SumBlock2D.isSumBlockType(type)) {
            return SumBlock2D.makeSumBlock(jsonNode, type);
        }
        switch (type) {
            case "padding_2d":
                return makePadding(jsonNode);
            case "upsampling_2d":
                return makeUpSampling(jsonNode);
            case "concatenate":
                return makeConcatenate(jsonNode, list);
            case "convolution_2d":
                return makeConvolution(jsonNode, z);
            case "separable_convolution_2d":
                return makeSeparableConvolution(jsonNode);
            case "depthwise_convolution_2d":
                return makeDepthwiseConvolution(jsonNode);
            case "batch_normalization":
                return makeBatchnorm(jsonNode);
            case "max_pool_2d":
                return makeMaxPool(jsonNode);
            case "average_pool_2d":
                return makeAveragePool(jsonNode);
            case "global_max_pool_2d":
                return makeGlobalMaxPool(jsonNode);
            case "global_average_pool_2d":
                return makeGlobalAveragePool(jsonNode);
            case "activation":
                return makeActivation(jsonNode);
            default:
                throw new IllegalArgumentException("'" + type + "' is not a valid layer type");
        }
    }

    /**
     * Builds a single layer with no preceding layers and no initial convolution.
     *
     * @param jsonNode the layer description
     * @return the new layer
     */
    static Layer2D makeLayer(JsonNode jsonNode) {
        return makeLayer(jsonNode, null, false);
    }

    /** Builds a {@code Padding2D} from the 2-D int array in {@code "padding"}. */
    static Layer2D makePadding(JsonNode jsonNode) {
        return new Padding2D(Json.get2DIntArray(jsonNode.get("padding")));
    }

    /** Builds an {@code UpSampling2D} from the int array in {@code "size"}. */
    static Layer2D makeUpSampling(JsonNode jsonNode) {
        return new UpSampling2D(Json.get1DIntArray(jsonNode.get("size")));
    }

    /**
     * Builds a {@code Concatenate2D} joining two previously built layers.
     *
     * <p>The first two entries of {@code "inputs"} are resolved with
     * {@link #getLayer(int, Layer2D[])} against {@code list}.
     *
     * @param jsonNode the layer description
     * @param list the layers built so far; {@code null} is treated as empty
     * @return the concatenation layer
     * @throws IllegalArgumentException if fewer than two inputs are given or an
     *     input index cannot be resolved
     */
    static Layer2D makeConcatenate(JsonNode jsonNode, List<Layer2D> list) {
        int[] inputs = Json.get1DIntArray(jsonNode.get("inputs"));
        if (inputs == null || inputs.length < 2) {
            throw new IllegalArgumentException("Concatenate layer needs two 'inputs' indices");
        }
        Layer2D[] previous = toArray(list);
        return new Concatenate2D(getLayer(inputs[0], previous), getLayer(inputs[1], previous));
    }

    /**
     * Reads the {@code "bias"} array, defaulting to zeros.
     *
     * @param jsonNode the layer description
     * @param i length of the zero array used when no bias is present
     * @return the biases, or a new zero-filled array of length {@code i}
     */
    static double[] getBiases(JsonNode jsonNode, int i) {
        double[] biases = Json.get1DArray(jsonNode.get("bias"));
        return biases == null ? new double[i] : biases;
    }

    /**
     * Builds a standard convolution.
     *
     * <p>The default bias length is the size of the kernel's fourth dimension
     * (presumably the number of filters — confirm against the exporter).
     *
     * @param jsonNode the layer description
     * @param z if {@code true}, builds an {@code InitialConvolution2D} instead of a
     *     {@code Convolution2D}
     * @return the convolution layer
     */
    static Layer2D makeConvolution(JsonNode jsonNode, boolean z) {
        double[][][][] kernel = Json.get4DArray(jsonNode.get("kernel"));
        double[] biases = getBiases(jsonNode, kernel[0][0][0].length);
        int[] strides = Json.get1DIntArray(jsonNode.get("strides"));
        boolean samePadding = jsonNode.get("padding").asText().equals("same");
        return z
                ? new InitialConvolution2D(kernel, biases, strides, samePadding)
                : new Convolution2D(kernel, biases, strides, samePadding);
    }

    /**
     * Builds a separable convolution from its depth and point kernels.
     *
     * @throws UnsupportedOperationException if {@code "depth_multiplier"} is not 1
     */
    static Layer2D makeSeparableConvolution(JsonNode jsonNode) {
        if (jsonNode.get("depth_multiplier").asInt() != 1) {
            throw new UnsupportedOperationException("Depth multipliers > 1 are not yet implemented!");
        }
        double[][][][] depthKernel = Json.get4DArray(jsonNode.get("depth_kernel"));
        double[][][][] pointKernel = Json.get4DArray(jsonNode.get("point_kernel"));
        return new SeparableConvolution2D(
                depthKernel,
                pointKernel,
                getBiases(jsonNode, pointKernel[0][0][0].length),
                Json.get1DIntArray(jsonNode.get("strides")),
                jsonNode.get("padding").asText().equals("same"));
    }

    /**
     * Builds a depthwise convolution.
     *
     * <p>The default bias length is the size of the kernel's third dimension.
     *
     * @throws UnsupportedOperationException if {@code "depth_multiplier"} is not 1
     */
    static Layer2D makeDepthwiseConvolution(JsonNode jsonNode) {
        if (jsonNode.get("depth_multiplier").asInt() != 1) {
            throw new UnsupportedOperationException("Depth multipliers > 1 are not yet implemented!");
        }
        double[][][][] kernel = Json.get4DArray(jsonNode.get("kernel"));
        return new DepthwiseConvolution2D(
                kernel,
                getBiases(jsonNode, kernel[0][0].length),
                Json.get1DIntArray(jsonNode.get("strides")),
                jsonNode.get("padding").asText().equals("same"));
    }

    /** Builds a {@code BatchNormalize} from {@code mean}, {@code variance}, {@code beta} and {@code gamma}. */
    static Layer2D makeBatchnorm(JsonNode jsonNode) {
        return new BatchNormalize(
                Json.get1DArray(jsonNode.get("mean")),
                Json.get1DArray(jsonNode.get("variance")),
                Json.get1DArray(jsonNode.get("beta")),
                Json.get1DArray(jsonNode.get("gamma")));
    }

    /** Builds a {@code MaxPool2D}; {@code "padding"} equal to {@code "same"} selects same-padding. */
    static Layer2D makeMaxPool(JsonNode jsonNode) {
        return new MaxPool2D(
                Json.get1DIntArray(jsonNode.get("strides")),
                Json.get1DIntArray(jsonNode.get("pool_size")),
                jsonNode.get("padding").asText().equals("same"));
    }

    /** Builds an {@code AveragePool2D}; {@code "padding"} equal to {@code "same"} selects same-padding. */
    static Layer2D makeAveragePool(JsonNode jsonNode) {
        return new AveragePool2D(
                Json.get1DIntArray(jsonNode.get("strides")),
                Json.get1DIntArray(jsonNode.get("pool_size")),
                jsonNode.get("padding").asText().equals("same"));
    }

    /** Builds an {@code Activation} named by {@code "activation_function"}. */
    static Layer2D makeActivation(JsonNode jsonNode) {
        return new Activation(jsonNode.get("activation_function").asText());
    }

    /** Builds a {@code GlobalAveragePool2D}; the description carries no parameters. */
    static Layer2D makeGlobalAveragePool(JsonNode jsonNode) {
        return new GlobalAveragePool2D();
    }

    /** Builds a {@code GlobalMaxPool2D}; the description carries no parameters. */
    static Layer2D makeGlobalMaxPool(JsonNode jsonNode) {
        return new GlobalMaxPool2D();
    }

    /**
     * Looks up a layer by its index.
     *
     * @param i the layer index, or {@code -1} for the last layer in the array
     * @param layer2DArr the candidate layers
     * @return the last layer when {@code i == -1}, otherwise the first non-null
     *     layer whose {@code getIndex()} equals {@code i}
     * @throws IllegalArgumentException if {@code i} is negative other than -1, if
     *     {@code i == -1} and the array is empty, or if no layer has index {@code i}
     */
    static Layer2D getLayer(int i, Layer2D[] layer2DArr) {
        if (i == -1) {
            if (layer2DArr.length == 0) {
                throw new IllegalArgumentException("Index -1 requested but there are no previous layers");
            }
            return layer2DArr[layer2DArr.length - 1];
        }
        if (i < 0) {
            throw new IllegalArgumentException("Invalid index: " + i);
        }
        for (Layer2D layer : layer2DArr) {
            if (layer != null && layer.getIndex() == i) {
                return layer;
            }
        }
        throw new IllegalArgumentException("Index " + i + " not found");
    }

    /**
     * Feeds an input through every layer in order.
     *
     * <p>Layers implementing {@code GPURunnable} are run with
     * {@code run(input, i)}; all others use {@link #propagate2D(float[])}.
     *
     * @param layer2DArr the layers to run
     * @param fArr the input to the first layer
     * @param i passed unchanged to {@code GPURunnable.run} (meaning defined by
     *     that interface — confirm)
     * @return the output of the last layer, or {@code fArr} if there are no layers
     */
    static float[] propagate2D(Layer2D[] layer2DArr, float[] fArr, int i) {
        float[] activations = fArr;
        for (Layer2D layer : layer2DArr) {
            activations = layer instanceof GPURunnable
                    ? ((GPURunnable) layer).run(activations, i)
                    : layer.propagate2D(activations);
        }
        return activations;
    }

    /**
     * Initializes every layer in order, feeding each layer's result into the next.
     *
     * @param layer2DArr the layers to initialize
     * @param iArr the array given to the first layer (presumably its input shape)
     * @return the array returned by the last layer, or {@code iArr} if there are no layers
     */
    static int[] initialize(Layer2D[] layer2DArr, int[] iArr) {
        int[] shape = iArr;
        for (Layer2D layer : layer2DArr) {
            shape = layer.initialize(shape);
        }
        return shape;
    }

    /**
     * Initializes this layer. The returned array is passed to the next layer by
     * {@link #initialize(Layer2D[], int[])} (presumably input shape in, output
     * shape out — confirm with implementations).
     */
    int[] initialize(int[] iArr);

    /** Returns the index assigned via {@link #setIndex(int)}. */
    int getIndex();

    /** Assigns this layer's index, used by {@link #getLayer(int, Layer2D[])}. */
    void setIndex(int i);

    int[] getInputShape();

    int[] getOutputShape();

    /** Runs this layer on a single input. */
    float[] propagate2D(float[] fArr);

    double[][][] propagateArray(double[][][] dArr);

    float[] getLastOutput();
}
