package org.jetbrains.kotlinx.dl.api.core.loss;

import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv1D;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.math.Mean;

/* compiled from: Losses.kt */
@Metadata(mv = {Conv1D.EXTRA_DIM, 7, Conv1D.EXTRA_DIM}, k = 2, xi = 48, d1 = {"��\"\n��\n\u0002\u0018\u0002\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0007\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0005\u001a$\u0010��\u001a\b\u0012\u0004\u0012\u00020\u00020\u00012\u0006\u0010\u0003\u001a\u00020\u00042\f\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00060\u0001H��\u001a<\u0010\u0007\u001a\b\u0012\u0004\u0012\u00020\u00060\u00012\u0006\u0010\u0003\u001a\u00020\u00042\u0006\u0010\b\u001a\u00020\t2\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00060\u00012\u000e\u0010\u000b\u001a\n\u0012\u0004\u0012\u00020\u0006\u0018\u00010\u0001H��\u001a2\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u00060\u00012\u0006\u0010\u0003\u001a\u00020\u00042\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00060\u00012\f\u0010\r\u001a\b\u0012\u0004\u0012\u00020\u00060\u0001H��¨\u0006\u000e"}, d2 = {"allAxes", "Lorg/tensorflow/Operand;", "", "tf", "Lorg/tensorflow/op/Ops;", "op", "", "meanOfLosses", "reductionType", "Lorg/jetbrains/kotlinx/dl/api/core/loss/ReductionType;", "loss", "numberOfLosses", "safeMean", "numElements", "tensorflow"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/loss/LossesKt.class */
/**
 * Static helpers that build TensorFlow ops for reducing element-wise loss tensors to a
 * single loss value.
 */
public final class LossesKt {

    /** Static helpers only; not meant to be instantiated. */
    private LossesKt() {
    }

    /**
     * Reduces an element-wise loss tensor according to {@code reductionType}.
     *
     * <ul>
     *   <li>{@link ReductionType#SUM_OVER_BATCH_SIZE}: the sum of every element of {@code loss}
     *       divided by {@code numberOfLosses}, see {@link #safeMean(Ops, Operand, Operand)}.</li>
     *   <li>any other reduction type: {@code loss} is averaged over its last axis and those
     *       averages are summed over all remaining axes, giving a scalar.</li>
     * </ul>
     *
     * <p>NOTE(review): the {@code SUM_OVER_BATCH_SIZE} branch sums the raw {@code loss} (not its
     * last-axis mean), so {@code numberOfLosses} presumably counts every element of {@code loss}
     * — confirm against callers.
     *
     * @param tf TensorFlow op factory used to create the ops
     * @param reductionType how the element-wise losses are combined
     * @param loss element-wise loss values
     * @param numberOfLosses divisor used only for {@code SUM_OVER_BATCH_SIZE}; ignored (and may be
     *     {@code null}) for every other reduction type
     * @return the reduced loss operand
     * @throws IllegalStateException if {@code reductionType} is {@code SUM_OVER_BATCH_SIZE} and
     *     {@code numberOfLosses} is {@code null}
     */
    @NotNull
    public static final Operand<Float> meanOfLosses(
            @NotNull Ops tf,
            @NotNull ReductionType reductionType,
            @NotNull Operand<Float> loss,
            @Nullable Operand<Float> numberOfLosses) {
        Intrinsics.checkNotNullParameter(tf, "tf");
        Intrinsics.checkNotNullParameter(reductionType, "reductionType");
        Intrinsics.checkNotNullParameter(loss, "loss");
        if (reductionType == ReductionType.SUM_OVER_BATCH_SIZE) {
            // Validate before creating any ops so a bad call leaves no dead nodes behind.
            if (numberOfLosses == null) {
                throw new IllegalStateException("Operand numberOfLosses must be not null.");
            }
            return safeMean(tf, loss, numberOfLosses);
        }
        return sumOfLastAxisMeans(tf, loss);
    }

    /** Averages {@code loss} over its last axis, then sums those averages over all remaining axes. */
    private static Operand<Float> sumOfLastAxisMeans(Ops tf, Operand<Float> loss) {
        Operand<Float> meanLoss = tf.math.mean(loss, tf.constant(-1), Mean.keepDims(false));
        return tf.reduceSum(meanLoss, allAxes(tf, meanLoss), ReduceSum.keepDims(false));
    }

    /**
     * Sums every element of {@code loss} and divides the total by {@code numElements}.
     *
     * <p>The division uses {@code divNoNan}, which yields 0 instead of NaN/Inf when
     * {@code numElements} is zero.
     *
     * @param tf TensorFlow op factory used to create the ops
     * @param loss element-wise loss values
     * @param numElements divisor for the summed loss
     * @return {@code sum(loss) / numElements}, or 0 where {@code numElements} is zero
     */
    @NotNull
    public static final Operand<Float> safeMean(
            @NotNull Ops tf, @NotNull Operand<Float> loss, @NotNull Operand<Float> numElements) {
        Intrinsics.checkNotNullParameter(tf, "tf");
        Intrinsics.checkNotNullParameter(loss, "loss");
        Intrinsics.checkNotNullParameter(numElements, "numElements");
        Operand<Float> totalLoss = tf.reduceSum(loss, allAxes(tf, loss));
        return tf.math.divNoNan(totalLoss, numElements);
    }

    /**
     * Builds the list of every axis of {@code op}, i.e. {@code [0, 1, ..., rank - 1]}, suitable as
     * the axis argument of a full reduction.
     *
     * <p>When the static rank of {@code op} is known, the list is a constant; otherwise it is
     * computed from {@code rank(op)} when the op runs. A rank-0 operand yields an empty list.
     * The element type of {@code op} is irrelevant, so any operand is accepted.
     *
     * @param tf TensorFlow op factory used to create the ops
     * @param op operand whose axes are enumerated
     * @return 1-D integer operand holding the axis indices of {@code op}
     */
    @NotNull
    public static final Operand<Integer> allAxes(@NotNull Ops tf, @NotNull Operand<?> op) {
        Intrinsics.checkNotNullParameter(tf, "tf");
        Intrinsics.checkNotNullParameter(op, "op");
        int rank = op.asOutput().shape().numDimensions();
        if (rank == -1) {
            // Rank is unknown while building the graph: derive the axis list at run time.
            return tf.range(tf.constant(0), tf.rank(op), tf.constant(1));
        }
        int[] axes = new int[rank];
        for (int axis = 0; axis < rank; axis++) {
            axes[axis] = axis;
        }
        return tf.constant(axes);
    }
}
