package org.jetbrains.kotlinx.dl.api.core.metric;

import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv1D;
import org.jetbrains.kotlinx.dl.api.core.loss.ReductionType;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.math.Equal;
import org.tensorflow.op.math.Mean;

/* compiled from: Metric.kt */
@Metadata(mv = {Conv1D.EXTRA_DIM, 7, Conv1D.EXTRA_DIM}, k = Conv1D.EXTRA_DIM, xi = 48, d1 = {"��\u001e\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\u0018��2\u00020\u0001B\u0005¢\u0006\u0002\u0010\u0002JB\u0010\u0003\u001a\b\u0012\u0004\u0012\u00020\u00050\u00042\u0006\u0010\u0006\u001a\u00020\u00072\f\u0010\b\u001a\b\u0012\u0004\u0012\u00020\u00050\u00042\f\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u00050\u00042\u000e\u0010\n\u001a\n\u0012\u0004\u0012\u00020\u0005\u0018\u00010\u0004H\u0016¨\u0006\u000b"}, d2 = {"Lorg/jetbrains/kotlinx/dl/api/core/metric/Accuracy;", "Lorg/jetbrains/kotlinx/dl/api/core/metric/Metric;", "()V", "apply", "Lorg/tensorflow/Operand;", "", "tf", "Lorg/tensorflow/op/Ops;", "yPred", "yTrue", "numberOfLabels", "tensorflow"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/metric/Accuracy.class */
/**
 * Accuracy metric: the mean, taken along axis 0, of an element-wise "arg-max indices are equal"
 * indicator between predictions and labels.
 *
 * <p>NOTE(review): the arg-max is taken along axis 1, so both inputs are presumably shaped
 * {@code (batch, classes)} with one-hot or score-like rows; confirm against callers.
 */
public final class Accuracy extends Metric {
    /** Creates the metric, passing {@link ReductionType#SUM_OVER_BATCH_SIZE} to the base {@link Metric}. */
    public Accuracy() {
        super(ReductionType.SUM_OVER_BATCH_SIZE);
    }

    /**
     * Builds the graph operand that evaluates this metric.
     *
     * @param ops      TensorFlow op builder (named {@code tf} in the null-check message); must not be null
     * @param operand  predictions (named {@code yPred} in the null-check message); must not be null
     * @param operand2 ground-truth labels (named {@code yTrue} in the null-check message); must not be null
     * @param operand3 number of labels; may be null and is not read by this implementation
     * @return the mean along axis 0 of the equality indicator, cast to the type returned by
     *         {@code DtypeConversionUtilKt.getDType()}
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.metric.Metric
    @NotNull
    public Operand<Float> apply(@NotNull Ops ops, @NotNull Operand<Float> operand, @NotNull Operand<Float> operand2, @Nullable Operand<Float> operand3) {
        Intrinsics.checkNotNullParameter(ops, "tf");
        Intrinsics.checkNotNullParameter(operand, "yPred");
        Intrinsics.checkNotNullParameter(operand2, "yTrue");

        // Index of the largest value along axis 1 for the predictions.
        Operand<Long> predictedIndex = ops.math.argMax(operand, ops.constant(1));
        Intrinsics.checkNotNullExpressionValue(predictedIndex, "tf.math.argMax(yPred, tf.constant(1))");

        // Same reduction for the labels, so the two index tensors can be compared element-wise.
        Operand<Long> expectedIndex = ops.math.argMax(operand2, ops.constant(1));
        Intrinsics.checkNotNullExpressionValue(expectedIndex, "tf.math.argMax(yTrue, tf.constant(1))");

        // Boolean indicator of matching indices; cast to a numeric type so it can be averaged.
        Equal matches = ops.math.equal(predictedIndex, expectedIndex, new Equal.Options[0]);
        Operand<Float> mean = ops.math.mean(ops.dtypes.cast(matches, DtypeConversionUtilKt.getDType(), new Cast.Options[0]), ops.constant(0), new Mean.Options[0]);
        Intrinsics.checkNotNullExpressionValue(mean, "tf.math.mean(tf.dtypes.c…DType()), tf.constant(0))");
        return mean;
    }
}
