package org.deeplearning4j.nn.layers.feedforward.embedding;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayer.class */
public class EmbeddingLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.EmbeddingLayer> {
    public EmbeddingLayer(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    /**
     * Computes the gradients of this layer's parameters "W" and "b".
     *
     * <p>The incoming {@code epsilon} is multiplied IN PLACE by the derivative of the activation
     * function evaluated at the pre-activations, so the caller's array is modified. If a mask array
     * is set, each row of the resulting delta is scaled by the corresponding mask entry. (This
     * assumes the mask is a column vector with one entry per example; TODO confirm with callers.)
     *
     * @param epsilon error signal from the layer above, one row per example
     * @return a pair of (parameter gradients, {@code null}). No epsilon is propagated to a previous
     *     layer because the input consists of integer indices.
     * @throws IllegalStateException if the input is not a single column, or if an index is outside
     *     the rows of the weight gradient view
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        INDArray activationDerivative = Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory()
                        .createTransform(conf().getLayer().getActivationFunction(), preOutput(this.input))
                        .derivative());
        INDArray delta = epsilon.muli(activationDerivative);
        if (this.maskArray != null) {
            delta.muliColumnVector(this.maskArray);
        }

        INDArray weightGradients = this.gradientViews.get("W");
        weightGradients.assign(0);
        int[] indices = readInputIndices(weightGradients.rows());
        for (int example = 0; example < indices.length; example++) {
            // addi (not assign): an index that appears for several examples accumulates all of
            // their deltas into the same row of the weight gradient.
            weightGradients.getRow(indices[example]).addi(delta.getRow(example));
        }

        INDArray biasGradients = this.gradientViews.get("b");
        // The bias is added to every looked-up row, so its gradient is the delta summed over examples.
        biasGradients.assign(delta.sum(new int[]{0}));

        DefaultGradient gradient = new DefaultGradient();
        gradient.gradientForVariable().put("W", weightGradients);
        gradient.gradientForVariable().put("b", biasGradients);
        return new Pair<>(gradient, null);
    }

    /**
     * Computes the pre-activations. Row {@code i} of the result is row {@code input[i, 0]} of the
     * weight matrix "W", plus the bias row vector "b".
     *
     * @param training unused by this layer
     * @return a newly allocated, 'c'-ordered array of shape [numExamples, W.size(1)]
     * @throws IllegalStateException if the input has more than one column, or if an index is
     *     outside the rows of "W"
     */
    @Override
    public INDArray preOutput(boolean training) {
        if (this.input.columns() != 1) {
            throw new IllegalStateException("Cannot do forward pass for embedding layer with input more than one column. Expected input shape: [numExamples,1] with each entry being an integer index");
        }
        INDArray weights = getParam("W");
        INDArray bias = getParam("b");
        int[] indices = readInputIndices(weights.rows());

        INDArray embedded = Nd4j.createUninitialized(new int[]{indices.length, weights.size(1)}, 'c');
        for (int example = 0; example < indices.length; example++) {
            embedded.putRow(example, weights.getRow(indices[example]));
        }
        embedded.addiRowVector(bias);
        return embedded;
    }

    /**
     * Computes the layer output. The configured activation function is applied to
     * {@link #preOutput(boolean)}, then each row is scaled by the mask entry for that example when
     * a mask array is set.
     *
     * @param training forwarded to {@link #preOutput(boolean)}
     * @return the activations, one row per example
     */
    @Override
    public INDArray activate(boolean training) {
        INDArray activations = Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), preOutput(training)));
        if (this.maskArray != null) {
            activations.muliColumnVector(this.maskArray);
        }
        return activations;
    }

    /**
     * Always throws: this layer does not support dropout.
     *
     * <p>NOTE(review): the decompiler reports this override was originally {@code protected}. It is
     * kept {@code public} here so the visible interface does not change.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void applyDropOutIfNecessary(boolean training) {
        throw new UnsupportedOperationException("Dropout not supported with EmbeddingLayer");
    }

    /**
     * Reads the integer index stored in column 0 of each input row and checks that it lies in
     * {@code [0, numRows)}.
     *
     * <p>Assumes the single-column input shape has already been enforced; {@code input.length()} is
     * used as the number of examples. The {@code getInt} call truncates any fractional value.
     *
     * @param numRows number of rows in the matrix the indices will address
     * @return one index per example
     * @throws IllegalStateException if any index is negative or {@code >= numRows}
     */
    private int[] readInputIndices(int numRows) {
        int[] indices = new int[this.input.length()];
        for (int example = 0; example < indices.length; example++) {
            int index = this.input.getInt(new int[]{example, 0});
            if (index < 0 || index >= numRows) {
                throw new IllegalStateException("Embedding layer input index out of range for example "
                        + example + ": got " + index + ", expected a value in [0, " + numRows + ")");
            }
            indices[example] = index;
        }
        return indices;
    }
}
