package org.tensorflow.framework.optimizers;

import java.util.Iterator;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.train.ApplyMomentum;
import org.tensorflow.types.family.TType;

/* loaded from: input_file:org/tensorflow/framework/optimizers/Momentum.class */
/**
 * Optimizer that applies momentum updates by delegating each dense update to the TensorFlow
 * {@code ApplyMomentum} training op.
 *
 * <p>For every output handed to {@link #createSlots(List)}, a slot named {@value #MOMENTUM} is
 * created and initialized to zeros with that output's shape and type. Updates are built with
 * locking requested, and Nesterov momentum can be selected through the constructor.
 */
public class Momentum extends Optimizer {
    /** Learning rate used by the constructors that do not take one explicitly. */
    public static final float LEARNING_RATE_DEFAULT = 0.01f;
    /** Momentum coefficient used by the constructors that do not take one explicitly. */
    public static final float MOMENTUM_DEFAULT = 0.0f;
    /** Nesterov setting used by the constructors that do not take one explicitly. */
    public static final boolean NESTEROV_DEFAULT = false;
    /** Name of the per-variable slot that holds the momentum state passed to the update op. */
    public static final String MOMENTUM = "momentum";

    private final float learningRate;
    private final float momentum;
    private final boolean useNesterov;

    /**
     * Creates a Momentum optimizer using {@link #LEARNING_RATE_DEFAULT}, {@link #MOMENTUM_DEFAULT}
     * and {@link #NESTEROV_DEFAULT}.
     *
     * @param graph the graph passed to the {@link Optimizer} base class
     */
    public Momentum(Graph graph) {
        this(graph, LEARNING_RATE_DEFAULT, MOMENTUM_DEFAULT, NESTEROV_DEFAULT);
    }

    /**
     * Creates a Momentum optimizer with the given learning rate, {@link #MOMENTUM_DEFAULT} and
     * {@link #NESTEROV_DEFAULT}.
     *
     * @param graph the graph passed to the {@link Optimizer} base class
     * @param learningRate the learning rate fed to each momentum update
     */
    public Momentum(Graph graph, float learningRate) {
        this(graph, learningRate, MOMENTUM_DEFAULT, NESTEROV_DEFAULT);
    }

    /**
     * Creates a Momentum optimizer with the given learning rate and momentum, and
     * {@link #NESTEROV_DEFAULT}.
     *
     * @param graph the graph passed to the {@link Optimizer} base class
     * @param learningRate the learning rate fed to each momentum update
     * @param momentum the momentum coefficient fed to each momentum update
     */
    public Momentum(Graph graph, float learningRate, float momentum) {
        this(graph, learningRate, momentum, NESTEROV_DEFAULT);
    }

    /**
     * Creates a Momentum optimizer with fully specified hyper-parameters.
     *
     * @param graph the graph passed to the {@link Optimizer} base class
     * @param learningRate the learning rate fed to each momentum update
     * @param momentum the momentum coefficient fed to each momentum update
     * @param useNesterov whether the update op is asked to use Nesterov momentum
     */
    public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) {
        super(graph);
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.useNesterov = useNesterov;
    }

    /**
     * Creates a named Momentum optimizer with fully specified hyper-parameters.
     *
     * @param graph the graph passed to the {@link Optimizer} base class
     * @param name the name passed to the {@link Optimizer} base class
     * @param learningRate the learning rate fed to each momentum update
     * @param momentum the momentum coefficient fed to each momentum update
     * @param useNesterov whether the update op is asked to use Nesterov momentum
     */
    public Momentum(Graph graph, String name, float learningRate, float momentum, boolean useNesterov) {
        super(graph, name);
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.useNesterov = useNesterov;
    }

    /**
     * Creates one {@value #MOMENTUM} slot for each given output.
     *
     * @param variables the outputs that need a momentum slot (presumably the trainable
     *     variables — confirm against {@link Optimizer})
     */
    @Override
    protected void createSlots(List<Output<? extends TType>> variables) {
        for (Output<? extends TType> variable : variables) {
            createMomentumSlot(variable);
        }
    }

    /**
     * Creates the {@value #MOMENTUM} slot for a single output, initialized to zeros.
     *
     * @param variable the output the slot is attached to; its shape and type define the slot's
     */
    private <T extends TType> void createMomentumSlot(Output<T> variable) {
        // A float zero cast to the variable's type, broadcast to the variable's shape.
        createSlot(
                variable.asOutput(),
                MOMENTUM,
                this.tf.fill(
                        this.tf.shape(variable),
                        this.tf.dtypes.cast(this.tf.constant(0.0f), variable.type())));
    }

    /**
     * Builds the {@code ApplyMomentum} op that updates {@code variable} with {@code gradient}.
     *
     * @param gradient the gradient for {@code variable}
     * @param variable the variable being updated; its {@value #MOMENTUM} slot must already exist
     * @return the update op
     */
    @Override
    protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
        // The hyper-parameters are stored as floats; cast them to the gradient's type so every
        // input of the op shares one type.
        return this.tf.train.applyMomentum(
                variable,
                // Assumes createSlots ran for this variable; .get() fails otherwise.
                getSlot(variable, MOMENTUM).get(),
                this.tf.dtypes.cast(this.tf.constant(this.learningRate), gradient.type()),
                gradient,
                this.tf.dtypes.cast(this.tf.constant(this.momentum), gradient.type()),
                ApplyMomentum.useNesterov(this.useNesterov),
                ApplyMomentum.useLocking(true));
    }

    /** Returns a debug representation listing the configured hyper-parameters. */
    @Override
    public String toString() {
        return "Momentum{learningRate=" + this.learningRate
                + ", momentum=" + this.momentum
                + ", useNesterov=" + this.useNesterov
                + '}';
    }

    /** Returns the fixed optimizer name {@code "Momentum"}. */
    @Override
    public String getOptimizerName() {
        return "Momentum";
    }
}
