package org.tensorflow.framework.losses;

import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.AbstractLoss;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/**
 * Loss that computes the crossentropy between labels and predictions by delegating to
 * {@link Losses#sparseCategoricalCrossentropy}. It then applies the sample weights and this loss's
 * {@link Reduction} through {@link LossesHelper#computeWeightedLoss}.
 *
 * <p>Labels are forwarded unchanged. Presumably they hold integer class indices rather than one-hot
 * vectors (hence "sparse"). TODO confirm against {@code Losses.sparseCategoricalCrossentropy}.
 */
public class SparseCategoricalCrossentropy extends AbstractLoss {
    /**
     * Default for {@code fromLogits}.
     *
     * <p>When {@code false}, predictions are range-checked to {@code [0, 1]} before the loss is
     * computed.
     */
    public static final boolean FROM_LOGITS_DEFAULT = false;

    /**
     * Default axis forwarded to {@link Losses#sparseCategoricalCrossentropy}.
     *
     * <p>Presumably {@code -1} means the last axis of the predictions. TODO confirm.
     */
    public static final int AXIS_DEFAULT = -1;

    /**
     * Whether predictions are interpreted as logits.
     *
     * <p>When {@code false}, a {@code [0, 1]} range check is applied to the predictions.
     */
    private final boolean fromLogits;

    /** Axis forwarded to {@link Losses#sparseCategoricalCrossentropy}. */
    private final int axis;

    /**
     * Creates a loss that uses {@code null} as its name, together with {@link #FROM_LOGITS_DEFAULT},
     * {@code REDUCTION_DEFAULT} and {@link #AXIS_DEFAULT}.
     *
     * <p>Presumably {@link AbstractLoss} substitutes a default name when given {@code null}. Confirm.
     */
    public SparseCategoricalCrossentropy() {
        this(null, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT);
    }

    /**
     * Creates a loss with the given name and default {@code fromLogits}, reduction and axis.
     *
     * @param name the name of this loss
     */
    public SparseCategoricalCrossentropy(String name) {
        this(name, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT);
    }

    /**
     * Creates a loss with the given reduction.
     *
     * <p>The name is {@code null}; {@code fromLogits} and the axis use their defaults.
     *
     * @param reduction the reduction to apply to the weighted loss
     */
    public SparseCategoricalCrossentropy(Reduction reduction) {
        this(null, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT);
    }

    /**
     * Creates a loss with the given name and reduction, and default {@code fromLogits} and axis.
     *
     * @param name the name of this loss
     * @param reduction the reduction to apply to the weighted loss
     */
    public SparseCategoricalCrossentropy(String name, Reduction reduction) {
        this(name, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT);
    }

    /**
     * Creates a loss with the given name and {@code fromLogits} flag, and default reduction and axis.
     *
     * @param name the name of this loss
     * @param fromLogits whether predictions are logits; when {@code false} they are range-checked
     *     to {@code [0, 1]}
     */
    public SparseCategoricalCrossentropy(String name, boolean fromLogits) {
        this(name, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT);
    }

    /**
     * Creates a loss with the given {@code fromLogits} flag.
     *
     * <p>The name is {@code null}; the reduction and axis use their defaults.
     *
     * @param fromLogits whether predictions are logits; when {@code false} they are range-checked
     *     to {@code [0, 1]}
     */
    public SparseCategoricalCrossentropy(boolean fromLogits) {
        this(null, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT);
    }

    /**
     * Creates a loss with the given {@code fromLogits} flag and reduction.
     *
     * <p>The name is {@code null}; the axis uses its default.
     *
     * @param fromLogits whether predictions are logits; when {@code false} they are range-checked
     *     to {@code [0, 1]}
     * @param reduction the reduction to apply to the weighted loss
     */
    public SparseCategoricalCrossentropy(boolean fromLogits, Reduction reduction) {
        this(null, fromLogits, reduction, AXIS_DEFAULT);
    }

    /**
     * Creates a fully configured loss.
     *
     * @param name the name of this loss, passed to {@link AbstractLoss}
     * @param fromLogits whether predictions are logits; when {@code false} they are range-checked
     *     to {@code [0, 1]}
     * @param reduction the reduction to apply to the weighted loss
     * @param axis the axis forwarded to {@link Losses#sparseCategoricalCrossentropy}
     */
    public SparseCategoricalCrossentropy(
            String name, boolean fromLogits, Reduction reduction, int axis) {
        super(name, reduction);
        this.fromLogits = fromLogits;
        this.axis = axis;
    }

    /**
     * Builds the weighted, reduced sparse categorical crossentropy loss.
     *
     * @param tf the TensorFlow Ops used to build the graph operations
     * @param labels the labels, forwarded unchanged to {@link Losses#sparseCategoricalCrossentropy}
     * @param predictions the predictions; range-checked to {@code [0, 1]} unless {@code fromLogits}
     * @param sampleWeights the sample weights passed to {@link LossesHelper#computeWeightedLoss}
     * @param <T> the numeric type of the predictions and the result
     * @return the loss after weighting and applying {@link #getReduction()}
     */
    @Override // org.tensorflow.framework.losses.Loss
    public <T extends TNumber> Operand<T> call(
            Ops tf,
            Operand<? extends TNumber> labels,
            Operand<T> predictions,
            Operand<T> sampleWeights) {
        // Logits are unbounded, so the [0, 1] check only makes sense for non-logit predictions.
        Operand<T> checkedPredictions =
                fromLogits ? predictions : checkPredictionRange(tf, predictions);
        Operand<T> losses =
                Losses.sparseCategoricalCrossentropy(
                        tf, labels, checkedPredictions, fromLogits, axis);
        return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights);
    }

    /**
     * Wraps {@code predictions} in a {@link LossesHelper#rangeCheck} with bounds 0 and 1.
     *
     * <p>The bounds are cast to the predictions' type. The exact failure behavior (for example,
     * when and how an out-of-range value is reported) is defined by {@code rangeCheck}; not visible
     * here.
     */
    private static <T extends TNumber> Operand<T> checkPredictionRange(
            Ops tf, Operand<T> predictions) {
        return LossesHelper.rangeCheck(
                tf,
                "predictions range check [0-1]",
                predictions,
                CastHelper.cast(tf, tf.constant(0), predictions.type()),
                CastHelper.cast(tf, tf.constant(1), predictions.type()));
    }
}
