package org.deeplearning4j.nn.layers.convolution.preprocessor;

import java.util.Arrays;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.OutputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/preprocessor/ConvolutionInputPreProcessor.class */
/**
 * Pre-processor that reshapes flat arrays into 4-dimensional arrays.
 *
 * <p>An array that already has four dimensions is passed through untouched.
 * Any other array must have exactly {@code rows * cols} columns and is
 * reshaped to {@code [input.rows(), 1, rows, cols]}.
 *
 * <p>NOTE(review): the resulting layout is presumably
 * (examples, channels, height, width) for a convolution layer; confirm
 * against the consuming layer.
 */
public class ConvolutionInputPreProcessor implements OutputPreProcessor, InputPreProcessor {
    /** Size of the third dimension of the reshaped array. */
    private int rows;
    /** Size of the fourth dimension of the reshaped array. */
    private int cols;

    /**
     * Creates a pre-processor for the given target dimensions.
     *
     * @param rows size of the third dimension of the reshaped output
     * @param cols size of the fourth dimension of the reshaped output
     */
    public ConvolutionInputPreProcessor(int rows, int cols) {
        this.rows = rows;
        this.cols = cols;
    }

    /**
     * Reshapes {@code input} to four dimensions unless it already has four.
     *
     * @param input the array to pre-process
     * @return {@code input} itself when it has four dimensions, otherwise the
     *         result of reshaping it to {@code [input.rows(), 1, rows, cols]}
     * @throws IllegalArgumentException if {@code input} does not have four
     *         dimensions and its column count differs from {@code rows * cols}
     */
    @Override // org.deeplearning4j.nn.conf.OutputPreProcessor
    public INDArray preProcess(INDArray input) {
        final boolean alreadyFourDimensional = input.shape().length == 4;
        if (!alreadyFourDimensional) {
            final int expectedColumns = this.rows * this.cols;
            if (input.columns() == expectedColumns) {
                // One entry per input row, with a singleton second dimension.
                final int[] targetShape = {input.rows(), 1, this.rows, this.cols};
                return input.reshape(targetShape);
            }
            throw new IllegalArgumentException("Output columns must be equal to rows " + this.rows
                    + " x columns " + this.cols
                    + " but was instead " + Arrays.toString(input.shape()));
        }
        // Nothing to do: the array is already in its 4-dimensional form.
        return input;
    }
}
