package org.tensorflow.framework.metrics;

import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ExpandDims;
import org.tensorflow.op.core.Gather;
import org.tensorflow.op.core.ReduceMax;
import org.tensorflow.op.core.Where;
import org.tensorflow.op.math.DivNoNan;
import org.tensorflow.op.math.Greater;
import org.tensorflow.types.family.TNumber;

/**
 * Metric that reports the best recall attainable at any threshold whose precision is at least a
 * target value.
 *
 * <p>The per-threshold true-positive, false-positive and false-negative counts are inherited from
 * {@link SensitivitySpecificityBase}. If no threshold reaches the target precision, the result is
 * {@code 0}.
 *
 * @param <T> the data type used for the metric's internal computations
 */
public class RecallAtPrecision<T extends TNumber> extends SensitivitySpecificityBase<T> {

    /** Number of thresholds used by the constructors that do not take an explicit count. */
    private static final int DEFAULT_NUM_THRESHOLDS = 200;

    /** Target minimum precision, validated to lie in {@code [0, 1]}. */
    private final float precision;

    /**
     * Creates the metric with no explicit name and the default number of thresholds.
     *
     * @param precision target minimum precision, in {@code [0, 1]}
     * @param seed seed forwarded to the base class (presumably for variable initialization — confirm)
     * @param type data type for the metric's computations
     * @throws IllegalArgumentException if {@code precision} is NaN or outside {@code [0, 1]}
     */
    public RecallAtPrecision(float precision, long seed, Class<T> type) {
        this(null, precision, DEFAULT_NUM_THRESHOLDS, seed, type);
    }

    /**
     * Creates a named metric with the default number of thresholds.
     *
     * @param name metric name forwarded to the base class; may be {@code null} (handling of
     *     {@code null} is up to the base class — presumably a default name is chosen)
     * @param precision target minimum precision, in {@code [0, 1]}
     * @param seed seed forwarded to the base class
     * @param type data type for the metric's computations
     * @throws IllegalArgumentException if {@code precision} is NaN or outside {@code [0, 1]}
     */
    public RecallAtPrecision(String name, float precision, long seed, Class<T> type) {
        this(name, precision, DEFAULT_NUM_THRESHOLDS, seed, type);
    }

    /**
     * Creates the metric with no explicit name and a caller-chosen number of thresholds.
     *
     * @param precision target minimum precision, in {@code [0, 1]}
     * @param numThresholds number of thresholds forwarded to the base class
     * @param seed seed forwarded to the base class
     * @param type data type for the metric's computations
     * @throws IllegalArgumentException if {@code precision} is NaN or outside {@code [0, 1]}
     */
    public RecallAtPrecision(float precision, int numThresholds, long seed, Class<T> type) {
        this(null, precision, numThresholds, seed, type);
    }

    /**
     * Creates the metric with every option specified.
     *
     * @param name metric name forwarded to the base class; may be {@code null}
     * @param precision target minimum precision, in {@code [0, 1]}
     * @param numThresholds number of thresholds forwarded to the base class
     * @param seed seed forwarded to the base class
     * @param type data type for the metric's computations
     * @throws IllegalArgumentException if {@code precision} is NaN or outside {@code [0, 1]}
     */
    public RecallAtPrecision(
            String name, float precision, int numThresholds, long seed, Class<T> type) {
        super(name, numThresholds, seed, type);
        // Written as a negated range check so that NaN (for which every comparison is false)
        // is rejected too; a NaN target would otherwise make the metric silently report 0.
        if (!(precision >= 0.0f && precision <= 1.0f)) {
            throw new IllegalArgumentException(
                    "precision must be in the range [0, 1], got " + precision + ".");
        }
        this.precision = precision;
    }

    /**
     * Computes the maximum recall over all thresholds whose precision is at least
     * {@link #getPrecision()}.
     *
     * @param ops the TensorFlow Ops used to build the computation
     * @param cls the data type of the returned operand
     * @param <U> the data type of the returned operand
     * @return the best feasible recall, or {@code 0} when no threshold meets the target precision
     */
    @Override // org.tensorflow.framework.metrics.Metric
    public <U extends TNumber> Operand<U> result(Ops ops, Class<U> cls) {
        // NOTE(review): init is defined in the base class; presumably it sets up the
        // accumulator variables read below — confirm.
        init(ops);

        // Per-threshold precision = TP / (TP + FP) and recall = TP / (TP + FN).
        // divNoNan is used so that a zero denominator does not produce NaN.
        Operand<T> precisions = ops.math.divNoNan(
                this.truePositives, ops.math.add(this.truePositives, this.falsePositives));
        Operand<T> recalls = ops.math.divNoNan(
                this.truePositives, ops.math.add(this.truePositives, this.falseNegatives));

        // Indices of thresholds that reach the target precision.
        Operand<T> target = CastHelper.cast(ops, ops.constant(this.precision), getType());
        Where feasible = ops.where(ops.math.greaterEqual(precisions, target));
        Greater anyFeasible = ops.math.greater(ops.size(feasible), ops.constant(0));

        // Recall values at the feasible thresholds, reduced to their maximum over all axes.
        Operand<T> feasibleRecalls =
                ops.expandDims(ops.gather(recalls, feasible, ops.constant(0)), ops.constant(0));
        Operand<T> bestRecall =
                ops.reduceMax(feasibleRecalls, LossesHelper.allAxes(ops, feasibleRecalls));

        // Fall back to 0 when no threshold is feasible.
        Operand<T> zero = CastHelper.cast(ops, ops.constant(0), getType());
        return CastHelper.cast(ops, ops.select(anyFeasible, bestRecall, zero), cls);
    }

    /**
     * Returns the target minimum precision.
     *
     * @return the target precision, in {@code [0, 1]}
     */
    public float getPrecision() {
        return this.precision;
    }
}
