package org.deeplearning4j.nn.layers.convolution.upsampling;

import java.util.Arrays;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling1D.class */
/**
 * One-dimensional upsampling layer built on top of {@link Upsampling2D}.
 *
 * <p>Rank-3 input of shape {@code [minibatchSize, features, length]} is temporarily reshaped to
 * rank 4 ({@code [minibatchSize, features, length, 1]}) so the 2D implementation can be reused.
 * The trailing singleton dimension is then dropped from the result. The original {@link #input}
 * reference is always restored afterwards, including when an exception is thrown.
 */
public class Upsampling1D extends Upsampling2D {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) Upsampling1D.class);

    public Upsampling1D(NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        super(neuralNetConfiguration, dataType);
    }

    /**
     * Computes the gradient with respect to this layer's input.
     *
     * <p>The incoming epsilon is lifted to rank 4 and repeated {@code size[0]} times along the
     * trailing axis, and the 2D {@code upsampling_bp} op is applied. The rank-4 result is sliced
     * back to rank 3 and divided by {@code size[0]}. The division presumably normalizes the
     * gradient that the op aggregated across the repeated slices — confirm against the op
     * definition.
     *
     * @param epsilon gradient w.r.t. this layer's output; expected to be rank 3 (its first three
     *     dimensions are reshaped into a rank-4 array with a trailing 1)
     * @param workspaceMgr workspace manager used to allocate the returned input gradient
     * @return an empty {@link DefaultGradient} paired with the rank-3 input gradient
     */
    @Override // org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);
        // Use this class's own accessor, which reads the size from the 1D layer configuration.
        int[] size = getSize();
        // [mb, features, upsampledLength] -> [mb, features, upsampledLength, 1] -> repeat on axis 3,
        // so the 2D backprop op sees an upsampled extent in both spatial dimensions.
        INDArray epsilon4d = epsilon
                .reshape(epsilon.size(0), epsilon.size(1), epsilon.size(2), 1)
                .repeat(3, size[0]);

        INDArray originalInput = this.input;
        try {
            this.input = this.input.castTo(this.dataType)
                    .reshape(this.input.size(0), this.input.size(1), this.input.size(2), 1);

            long miniBatch = this.input.size(0);
            long channels = this.input.size(1);
            long length = this.input.size(2);
            long width = this.input.size(3);

            INDArray inputGradient4d = workspaceMgr
                    .createUninitialized(ArrayType.ACTIVATION_GRAD, this.input.dataType(),
                            miniBatch * channels * length * width)
                    .reshape('c', miniBatch, channels, length, width);

            // Integer argument 1 presumably selects the NCHW layout for the op — confirm.
            Nd4j.getExecutioner().exec(DynamicCustomOp.builder("upsampling_bp")
                    .addIntegerArguments(1)
                    .addInputs(this.input, epsilon4d)
                    .addOutputs(inputGradient4d)
                    .callInplace(false)
                    .build());

            // Drop the trailing singleton axis: [mb, channels, length, 1] -> [mb, channels, length].
            INDArray inputGradient = inputGradient4d.slice(0L, 3);
            return new Pair<>(new DefaultGradient(), inputGradient.divi(size[0]));
        } finally {
            // Restore the caller-visible rank-3 input even if the op or a reshape fails.
            this.input = originalInput;
        }
    }

    /**
     * Returns the upsampling size from the {@code Upsampling1D} layer configuration.
     *
     * @return the configured size array; only element 0 is used by this class
     */
    @Override // org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D
    protected int[] getSize() {
        return ((org.deeplearning4j.nn.conf.layers.Upsampling1D) this.conf.getLayer()).getSize();
    }

    /**
     * Runs the forward pass on rank-3 input.
     *
     * @param training forwarded unchanged to {@link Upsampling2D#activate}
     * @param workspaceMgr forwarded unchanged to {@link Upsampling2D#activate}
     * @return the upsampled activations reshaped to rank 3
     * @throws DL4JInvalidInputException if the input is not rank 3
     */
    @Override // org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        if (this.input.rank() != 3) {
            throw new DL4JInvalidInputException("Got rank " + this.input.rank() + " array as input to Upsampling1DLayer with shape " + Arrays.toString(this.input.shape()) + ". Expected rank 3 array with shape [minibatchSize, features, length]. " + layerId());
        }
        INDArray originalInput = this.input;
        try {
            // Present a [mb, features, length, 1] view to the 2D implementation.
            this.input = this.input.castTo(this.dataType)
                    .reshape(this.input.size(0), this.input.size(1), this.input.size(2), 1);
            INDArray activations = super.activate(training, workspaceMgr);
            return activations.reshape(activations.size(0), activations.size(1), activations.size(2));
        } finally {
            this.input = originalInput;
        }
    }

    /**
     * Computes the pre-activation output by delegating to {@link Upsampling2D#preOutput} on a
     * rank-4 view of the input, then slicing index 0 of the trailing axis.
     *
     * @param training forwarded unchanged to the superclass
     * @param forBackprop forwarded unchanged to the superclass
     * @param workspaceMgr forwarded unchanged to the superclass
     * @return the superclass output with its trailing (size-1) axis removed
     */
    @Override // org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D
    protected INDArray preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
        INDArray originalInput = this.input;
        try {
            this.input = this.input.reshape(this.input.size(0), this.input.size(1), this.input.size(2), 1);
            INDArray output = super.preOutput(training, forBackprop, workspaceMgr);
            return output.slice(0L, 3);
        } finally {
            this.input = originalInput;
        }
    }
}
