package org.nd4j.linalg.api.ops.impl.transforms.custom;

import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.class */
/**
 * Log-softmax operation. It is executed by the native custom op {@code "log_softmax"} and is
 * mapped to the TensorFlow op {@code "LogSoftmax"}.
 *
 * <p>An optional dimension may be supplied at construction time. When present, it is kept in
 * {@link #dimension}, where {@link #doDiff(List)} uses it to build the matching derivative op. It
 * is also passed to the native op as its integer argument. When absent, no integer argument is
 * added and the native op's default is used. NOTE(review): presumably that default is the last
 * dimension; confirm against the native op.
 */
public class LogSoftMax extends DynamicCustomOp {
    /** Dimension passed to the op, or {@code null} when the op's default is used. */
    private Integer dimension;

    /**
     * Creates a SameDiff log-softmax over {@code input}, without an explicit dimension.
     *
     * @param sameDiff the owning SameDiff instance
     * @param input    the input variable
     */
    public LogSoftMax(SameDiff sameDiff, SDVariable input) {
        super(sameDiff, input);
        this.dimension = null;
    }

    /** No-arg constructor, without an explicit dimension. */
    public LogSoftMax() {
        this.dimension = null;
    }

    /**
     * Creates an eager log-softmax over {@code x}, without an explicit dimension.
     *
     * @param x the input array
     * @param z the output array; may be {@code null}
     */
    public LogSoftMax(INDArray x, INDArray z) {
        // Casts keep overload resolution of the super constructor unambiguous.
        super((String) null, x, z, (List<Double>) null, (int[]) null);
        this.dimension = null;
    }

    /**
     * Creates an eager log-softmax that uses {@code x} as both input and output, so the result
     * overwrites {@code x}.
     *
     * @param x the input array, which is also the output array
     */
    public LogSoftMax(INDArray x) {
        this(x, x);
    }

    /**
     * Creates an eager log-softmax over {@code x} using an explicit dimension. The output array
     * is left {@code null}.
     *
     * @param x         the input array
     * @param dimension the dimension, forwarded to the native op as its integer argument
     */
    public LogSoftMax(INDArray x, int dimension) {
        this(x, (INDArray) null);
        this.dimension = dimension;
        // The native op only receives the dimension through its integer arguments. Without this
        // call the value was recorded for doDiff but ignored at execution time. The SameDiff
        // constructor below already registers it the same way.
        addIArgument(dimension);
    }

    /**
     * Creates a SameDiff log-softmax over {@code input} using an explicit dimension.
     *
     * @param sameDiff  the owning SameDiff instance
     * @param input     the input variable
     * @param dimension the dimension, forwarded to the native op as its integer argument
     */
    public LogSoftMax(SameDiff sameDiff, SDVariable input, int dimension) {
        this(sameDiff, input);
        this.dimension = dimension;
        addIArgument(dimension);
    }

    /** Returns the native op name, {@code "log_softmax"}. */
    @Override
    public String opName() {
        return "log_softmax";
    }

    /** Returns the TensorFlow op name this op maps to, {@code "LogSoftmax"}. */
    @Override
    public String tensorflowName() {
        return "LogSoftmax";
    }

    /**
     * Builds the backward pass with a {@link LogSoftMaxDerivative} op.
     *
     * @param gradients the incoming gradients; only element 0 is used
     * @return the outputs of the derivative op. The derivative op receives this op's input, the
     *     incoming gradient and, if one was set, the same dimension.
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        SDVariable gradOut = gradients.get(0);
        LogSoftMaxDerivative derivative = (dimension == null)
                ? new LogSoftMaxDerivative(sameDiff, arg(), gradOut)
                : new LogSoftMaxDerivative(sameDiff, arg(), gradOut, dimension.intValue());
        return derivative.outputs();
    }

    /**
     * Computes the output data type. The input type is returned if it is floating point.
     * Otherwise the result is {@link Nd4j#defaultFloatingPointType()}.
     *
     * @param dataTypes the input data types; exactly one is expected
     * @return a singleton list holding the output data type
     * @throws IllegalStateException if {@code dataTypes} is {@code null} or does not contain
     *     exactly one element (via {@code Preconditions.checkState})
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected 1 input datatype for %s, got %s", getClass(), dataTypes);
        DataType inputType = dataTypes.get(0);
        DataType outputType = inputType.isFPType() ? inputType : Nd4j.defaultFloatingPointType();
        return Collections.singletonList(outputType);
    }
}
