package org.deeplearning4j.nn.layers.convolution.preprocessor;

import org.deeplearning4j.nn.conf.OutputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/preprocessor/ConvolutionPostProcessor.class */
/**
 * Output pre-processor that reshapes a multi-dimensional activation array into a
 * two-dimensional one by keeping the first dimension and collapsing every remaining
 * dimension into a single one.
 *
 * <p>The target shape may be supplied up front through the constructor. If it is absent,
 * or its element count does not match the array being processed, it is derived from the
 * array's own shape and cached in this instance for subsequent calls.
 *
 * <p>Because {@link #preProcess(INDArray)} writes the cached shape without any
 * synchronization, callers sharing one instance across threads must synchronize
 * externally.
 */
public class ConvolutionPostProcessor implements OutputPreProcessor {
    /** Target shape used for the reshape; {@code null} until supplied or derived. */
    private int[] shape;

    /**
     * Creates a post-processor with a preferred target shape.
     *
     * @param iArr the shape to reshape to; may be {@code null}, in which case the shape is
     *             derived from the first processed array. The array is copied so later
     *             changes made by the caller do not affect this instance.
     */
    public ConvolutionPostProcessor(int[] iArr) {
        this.shape = iArr == null ? null : iArr.clone();
    }

    /** Creates a post-processor whose target shape is derived from the processed array. */
    public ConvolutionPostProcessor() {
    }

    /**
     * Reshapes the given array to the configured target shape, or to
     * {@code [d0, d1 * d2 * ... * dn]} when no usable target shape is available.
     *
     * @param iNDArray the array to reshape
     * @return the result of {@code iNDArray.reshape(...)} with the resolved shape
     * @throws IllegalArgumentException if a shape has to be derived and the array has
     *                                  fewer than two dimensions
     */
    @Override
    public INDArray preProcess(INDArray iNDArray) {
        // A configured/cached shape is only usable when it describes exactly as many
        // elements as the input holds; otherwise derive a fresh one from the input.
        if (this.shape == null || ArrayUtil.prod(this.shape) != iNDArray.length()) {
            this.shape = flattenTrailingDimensions(iNDArray.shape());
        }
        return iNDArray.reshape(this.shape);
    }

    /**
     * Builds the two-element shape {@code [first, product of all remaining dimensions]}.
     * Works for any rank of at least two, so rank-2 input maps onto its own shape.
     */
    private static int[] flattenTrailingDimensions(int[] inputShape) {
        if (inputShape.length < 2) {
            throw new IllegalArgumentException(
                    "Cannot derive a 2d shape from an array of rank " + inputShape.length);
        }
        int[] trailing = new int[inputShape.length - 1];
        System.arraycopy(inputShape, 1, trailing, 0, trailing.length);
        return new int[]{inputShape[0], ArrayUtil.prod(trailing)};
    }
}
