package org.bigml.mimir.deepnet.layers.twod;

import com.fasterxml.jackson.databind.JsonNode;
import java.util.Arrays;
import java.util.HashMap;
import org.bigml.mimir.deepnet.layers.Activation;
import org.bigml.mimir.image.FeaturizeTest;
import org.bigml.mimir.math.Matrices;
import org.bigml.mimir.math.gpu.GPURunnable;
import org.bigml.mimir.math.gpu.Program;

/**
 * A 2D block that feeds the same input through two parallel paths of {@link Layer2D}s, sums the
 * two path outputs element-wise and applies an activation function to the sum.
 *
 * <p>The JSON block type name (for example {@code "resnet_block"} or {@code "xception_block"})
 * selects which JSON keys hold the two paths; see {@link #makeSumBlock(JsonNode, String)}.
 */
public class SumBlock2D implements Layer2D, GPURunnable {
    private int _index;
    /** Input shape recorded by {@link #initialize(int[])}; {@code null} until then. */
    protected int[] _inputShape;
    protected Layer2D[] _path1;
    protected Layer2D[] _path2;
    protected Activation.ActivationFn _afn;
    /** Output buffer created by {@link #initialize(int[])}; transient, so not serialized. */
    protected transient OutputTensor _output;

    /** Maps each recognised JSON block type name to the kind of sum block it describes. */
    private static final HashMap<String, SumBlockType> _BLOCK_TYPES = new HashMap<>();
    private static final long serialVersionUID = 1L;

    /** Kinds of sum block; each kind reads its two paths from a different pair of JSON keys. */
    public enum SumBlockType {
        RESIDUAL,
        XCEPTION
    }

    /**
     * Returns whether {@code str} names a block type this class can build.
     *
     * @param str a JSON block type name
     * @return {@code true} if {@link #makeSumBlock(JsonNode, String)} accepts {@code str}
     */
    public static boolean isSumBlockType(String str) {
        return _BLOCK_TYPES.containsKey(str);
    }

    /**
     * Builds a sum block from its JSON description.
     *
     * <p>Residual blocks read their paths from {@code "convolution_path"} and
     * {@code "identity_path"}; Xception blocks from {@code "separable_convolution_path"} and
     * {@code "single_convolution_path"}. An optional {@code "activation_function"} key names the
     * activation applied to the sum.
     *
     * @param jsonNode the JSON object describing the block
     * @param str the block type name, as accepted by {@link #isSumBlockType(String)}
     * @return the constructed block
     * @throws IllegalArgumentException if {@code str} is not a known block type
     */
    public static SumBlock2D makeSumBlock(JsonNode jsonNode, String str) {
        SumBlockType blockType = _BLOCK_TYPES.get(str);
        // Check before switching: switching on a null enum would throw NullPointerException
        // instead of the documented IllegalArgumentException.
        if (blockType == null) {
            throw new IllegalArgumentException("Block type '" + str + "' unknown!");
        }

        String path1Key;
        String path2Key;
        switch (blockType) {
            case RESIDUAL:
                path1Key = "convolution_path";
                path2Key = "identity_path";
                break;
            case XCEPTION:
                path1Key = "separable_convolution_path";
                path2Key = "single_convolution_path";
                break;
            default:
                throw new IllegalArgumentException("Block type '" + str + "' unknown!");
        }

        Layer2D[] path1 = Layer2D.makeLayers(jsonNode.get(path1Key), false);
        Layer2D[] path2 = Layer2D.makeLayers(jsonNode.get(path2Key), false);

        String activationName = null;
        JsonNode activationNode = jsonNode.get("activation_function");
        if (activationNode != null) {
            activationName = activationNode.asText();
        }
        return new SumBlock2D(path1, path2, activationName);
    }

    /**
     * Creates a sum block from two layer paths.
     *
     * @param layer2DArr the first path
     * @param layer2DArr2 the second path
     * @param str the activation function name passed to {@link Activation#getActivator(String)};
     *     {@link #makeSumBlock(JsonNode, String)} passes {@code null} when the JSON has no
     *     {@code "activation_function"} key
     */
    public SumBlock2D(Layer2D[] layer2DArr, Layer2D[] layer2DArr2, String str) {
        this._path1 = layer2DArr;
        this._path2 = layer2DArr2;
        this._afn = Activation.getActivator(str);
    }

    @Override
    public int getIndex() {
        return this._index;
    }

    @Override
    public void setIndex(int i) {
        this._index = i;
    }

    /**
     * Initializes both paths for the given input shape and allocates the output buffer.
     *
     * <p>When assertions are enabled, the first three dimensions of the two path output shapes
     * must match; the assertion message contains both shapes.
     *
     * @param iArr the input shape
     * @return the output shape (the first path's output shape)
     */
    @Override
    public int[] initialize(int[] iArr) {
        this._inputShape = iArr;
        int[] path1Shape = Layer2D.initialize(this._path1, iArr);
        int[] path2Shape = Layer2D.initialize(this._path2, iArr);
        assert path1Shape[0] == path2Shape[0]
                && path1Shape[1] == path2Shape[1]
                && path1Shape[2] == path2Shape[2]
            : "Sum block path shapes differ: "
                + Arrays.toString(path1Shape)
                + " vs "
                + Arrays.toString(path2Shape);
        this._output = new OutputTensor(path1Shape);
        return path1Shape;
    }

    /** Returns the output shape; requires {@link #initialize(int[])} to have been called. */
    @Override
    public int[] getOutputShape() {
        return this._output.getShape();
    }

    /** Returns the input shape recorded by {@link #initialize(int[])}, or {@code null} before. */
    @Override
    public int[] getInputShape() {
        return this._inputShape;
    }

    /** Always returns {@code null}; this block declares no program type of its own. */
    @Override
    public Program.Type getProgramType() {
        return null;
    }

    /**
     * Runs both paths on {@code fArr}, writes their element-wise sum into the buffer obtained
     * from the output tensor and returns the activated result.
     *
     * @param fArr the input values
     * @param i passed unchanged to {@link Layer2D#propagate2D} for both paths
     *     ({@link #propagate2D(float[])} passes {@code -1})
     * @return the result of {@link Activation#activate} on the summed buffer
     */
    @Override
    public float[] run(float[] fArr, int i) {
        float[] sum = this._output.get();
        float[] out1 = Layer2D.propagate2D(this._path1, fArr, i);
        float[] out2 = Layer2D.propagate2D(this._path2, fArr, i);
        for (int k = 0; k < out1.length; k++) {
            sum[k] = out1[k] + out2[k];
        }
        return Activation.activate(sum, this._afn);
    }

    @Override
    public float[] propagate2D(float[] fArr) {
        return run(fArr, -1);
    }

    /** Unrolls {@code dArr}, propagates it and reshapes the result to the output shape. */
    @Override
    public double[][][] propagateArray(double[][][] dArr) {
        return Matrices.reshape(propagate2D(Matrices.unroll(dArr)), this._output.getShape());
    }

    @Override
    public float[] getLastOutput() {
        return this._output.get();
    }

    static {
        _BLOCK_TYPES.put("resnet_block", SumBlockType.RESIDUAL);
        _BLOCK_TYPES.put("resnet18_block", SumBlockType.RESIDUAL);
        _BLOCK_TYPES.put("mobilenet_residual_block", SumBlockType.RESIDUAL);
        _BLOCK_TYPES.put("darknet_residual_block", SumBlockType.RESIDUAL);
        _BLOCK_TYPES.put("xception_block", SumBlockType.XCEPTION);
    }
}
