package org.bigml.mimir.deepnet.network;

import com.fasterxml.jackson.databind.JsonNode;
import java.io.Serializable;
import org.bigml.mimir.deepnet.layers.twod.AbstractGlobalPool2D;
import org.bigml.mimir.deepnet.layers.twod.Layer2D;
import org.bigml.mimir.image.ImageReader;

/* loaded from: input_file:org/bigml/mimir/deepnet/network/Network2D.class */
/**
 * A 2D (image) network built from a JSON layer description.
 *
 * <p>If the last layer produced by {@link Layer2D#makeLayers} is an {@link AbstractGlobalPool2D},
 * it is held separately as the pooling operation and is not part of {@link #_layers}.
 *
 * <p>{@link #initialize()} must be called before {@link #predict(Object, int)}. This includes
 * after deserialization, because the image reader is {@code transient}.
 */
public class Network2D implements Serializable {
    protected Layer2D[] _layers;
    protected AbstractGlobalPool2D _poolingOperation;
    protected ImageNetworkMetadata _metadata;
    protected transient ImageReader _reader;
    private static final long serialVersionUID = 1;

    /**
     * Builds the network's layers from JSON.
     *
     * @param jsonNode JSON description of the layers, passed to {@link Layer2D#makeLayers}
     * @param imageNetworkMetadata metadata providing the input shape, output count and reader
     */
    public Network2D(JsonNode jsonNode, ImageNetworkMetadata imageNetworkMetadata) {
        this._metadata = imageNetworkMetadata;
        this._poolingOperation = null;
        Layer2D[] makeLayers = Layer2D.makeLayers(jsonNode, true);
        int length = makeLayers.length;
        // Guard against an empty layer list; indexing [length - 1] would otherwise throw.
        if (length > 0 && makeLayers[length - 1] instanceof AbstractGlobalPool2D) {
            // A trailing global-pool layer is split off and applied by predict(float[], int).
            this._poolingOperation = (AbstractGlobalPool2D) makeLayers[length - 1];
            length--;
        }
        this._layers = new Layer2D[length];
        System.arraycopy(makeLayers, 0, this._layers, 0, length);
    }

    /**
     * Validates the input shape, acquires the image reader and initializes every layer in order.
     *
     * <p>Each layer's {@code initialize} receives the shape returned by the previous one. The
     * pooling operation, if any, receives the shape returned by the last layer.
     *
     * @throws IllegalArgumentException if the metadata shape is missing, has fewer than three
     *     entries, or does not declare exactly 3 channels in {@code shape[2]}
     */
    public void initialize() {
        int[] shape = this._metadata.getShape();
        if (shape == null || shape.length < 3) {
            throw new IllegalArgumentException("Network input shape must have at least 3 dimensions");
        }
        if (shape[2] != 3) {
            throw new IllegalArgumentException("Networks only accept 3-channel input");
        }
        this._reader = this._metadata.getReader();
        // NOTE(review): the first two dimensions are swapped and the channel count is 4, although
        // shape[2] must be 3. Presumably the reader emits a 4-channel buffer; preserved as-is,
        // TODO confirm against ImageReader.
        int[] currentShape = {shape[1], shape[0], 4};
        for (Layer2D layer : this._layers) {
            currentShape = layer.initialize(currentShape);
        }
        if (this._poolingOperation != null) {
            this._poolingOperation.initialize(currentShape);
        }
    }

    /**
     * Converts {@code obj} with the image reader and returns the pooled prediction.
     *
     * @param obj input accepted by {@code ImageReader.objectTo1DArray}
     * @param i value forwarded unchanged to {@link #predict(float[], int)}
     * @return the pooled output vector
     * @throws IllegalStateException if {@link #initialize()} has not been called on this instance
     */
    public float[] predict(Object obj, int i) {
        if (this._reader == null) {
            throw new IllegalStateException(
                    "Network2D.initialize() must be called before predicting from an image");
        }
        return predict(this._reader.objectTo1DArray(obj), i);
    }

    /**
     * Runs the layers on {@code fArr}, then reduces the resulting grid with the pooling operation.
     *
     * @param fArr flattened input data
     * @param i value forwarded unchanged to {@link #predictGrid(float[], int)}
     * @return the pooled output vector
     * @throws IllegalStateException if the network has no trailing global pooling layer
     */
    public float[] predict(float[] fArr, int i) {
        if (this._poolingOperation == null) {
            throw new IllegalStateException(
                    "Network has no global pooling layer; use predictGrid instead");
        }
        return this._poolingOperation.toPooledVector(predictGrid(fArr, i));
    }

    /**
     * Propagates {@code fArr} through the (non-pooling) layers via {@link Layer2D#propagate2D}.
     *
     * @param fArr flattened input data
     * @param i value forwarded unchanged to {@code propagate2D}; its meaning is defined there
     * @return the un-pooled output of the layer stack
     */
    public float[] predictGrid(float[] fArr, int i) {
        return Layer2D.propagate2D(this._layers, fArr, i);
    }

    /**
     * Returns the layer selected by {@code i}, as resolved by {@link Layer2D#getLayer}.
     *
     * @param i layer selector forwarded to {@code Layer2D.getLayer}
     * @return the selected layer
     */
    public Layer2D getLayer(int i) {
        return Layer2D.getLayer(i, this._layers);
    }

    /**
     * Returns the metadata supplied at construction.
     *
     * @return the metadata instance
     */
    public ImageNetworkMetadata getMetadata() {
        return this._metadata;
    }

    /**
     * Returns the output count reported by the metadata.
     *
     * @return {@code getMetadata().getOutputs()}
     */
    public int getNumberOfOutputs() {
        return this._metadata.getOutputs();
    }

    /**
     * Returns the output shape of the layer selected by {@code i}.
     *
     * @param i layer selector forwarded to {@link #getLayer(int)}
     * @return that layer's output shape
     */
    public int[] getOutputShape(int i) {
        return getLayer(i).getOutputShape();
    }
}
