package org.jetbrains.kotlinx.dl.api.core.optimizer;

import java.util.ArrayList;
import java.util.List;
import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.KGraph;
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv1D;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Ops;
import org.tensorflow.op.TrainOps;
import org.tensorflow.op.core.Gradients;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyGradientDescent;

/* compiled from: SGD.kt */
@Metadata(mv = {Conv1D.EXTRA_DIM, 7, Conv1D.EXTRA_DIM}, k = Conv1D.EXTRA_DIM, xi = 48, d1 = {"��J\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0010\u000e\n\u0002\b\u0003\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\u0018��2\u00020\u0001B\u0019\u0012\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J@\u0010\u000f\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u00110\u00102\u0006\u0010\u0012\u001a\u00020\u00132\u0006\u0010\u0014\u001a\u00020\u00152\u0012\u0010\u0016\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u00170\u00102\u0006\u0010\u0018\u001a\u00020\u0019H\u0014R\u0014\u0010\u0007\u001a\u00020\b8PX\u0090\u0004¢\u0006\u0006\u001a\u0004\b\t\u0010\nR\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010\u000b\u001a\u00020\f8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\r\u0010\u000e¨\u0006\u001a"}, d2 = {"Lorg/jetbrains/kotlinx/dl/api/core/optimizer/SGD;", "Lorg/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer;", "learningRate", "", "clipGradient", "Lorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;", "(FLorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;)V", "isRunningOnGPU", "", "isRunningOnGPU$tensorflow", "()Z", "optimizerName", "", "getOptimizerName", "()Ljava/lang/String;", "applyGradients", "", "Lorg/tensorflow/Operand;", "graph", "Lorg/jetbrains/kotlinx/dl/api/core/KGraph;", "tf", "Lorg/tensorflow/op/Ops;", "weights", "Lorg/tensorflow/op/core/Variable;", "gradients", "Lorg/tensorflow/op/core/Gradients;", "tensorflow"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/optimizer/SGD.class */
public final class SGD extends Optimizer {
    /** Learning rate used when a Kotlin caller omits {@code learningRate} (default-args bit 1). */
    private static final float DEFAULT_LEARNING_RATE = 0.2f;

    /** Step size applied to every gradient update; validated in the constructor to be {@code >= 0}. */
    private final float learningRate;

    /**
     * Creates a stochastic-gradient-descent optimizer.
     *
     * @param f                  learning rate; must be {@code >= 0.0} (NaN is rejected because
     *                           {@code NaN >= 0.0f} is false)
     * @param clipGradientAction clipping strategy applied to each gradient before the update
     * @throws IllegalArgumentException if {@code f} is negative or NaN
     */
    public SGD(float f, @NotNull ClipGradientAction clipGradientAction) {
        // Java requires super(...) to come first, so the superclass receives the clip action
        // before the null check below runs (JADX flagged this reordering when decompiling).
        super(clipGradientAction);
        Intrinsics.checkNotNullParameter(clipGradientAction, "clipGradient");
        if (!(f >= 0.0f)) {
            throw new IllegalArgumentException(("Learning rate " + f + " should be >= 0.0.").toString());
        }
        this.learningRate = f;
    }

    /**
     * Kotlin default-arguments bridge: bit 0 of {@code i} selects the default learning rate,
     * bit 1 selects a {@link NoClipGradient} clip action.
     */
    public /* synthetic */ SGD(float f, ClipGradientAction clipGradientAction, int i, DefaultConstructorMarker defaultConstructorMarker) {
        this((i & 1) != 0 ? DEFAULT_LEARNING_RATE : f, (i & 2) != 0 ? new NoClipGradient() : clipGradientAction);
    }

    /**
     * Builds one {@code ApplyGradientDescent} op per trainable variable.
     *
     * <p>Variable {@code i} is updated with the (clipped) gradient {@code gradients.dy(i)} and the
     * configured learning rate, with locking enabled.
     *
     * @param kGraph    graph wrapper (only null-checked here)
     * @param ops       TensorFlow op factory used to create the update ops
     * @param list      variables to update; index {@code i} is paired with {@code gradients.dy(i)}
     * @param gradients gradient outputs for {@code list}
     * @return the update ops, in the same order as {@code list}
     */
    @Override
    @NotNull
    protected List<Operand<Float>> applyGradients(@NotNull KGraph kGraph, @NotNull Ops ops, @NotNull List<Variable<Float>> list, @NotNull Gradients gradients) {
        Intrinsics.checkNotNullParameter(kGraph, "graph");
        Intrinsics.checkNotNullParameter(ops, "tf");
        Intrinsics.checkNotNullParameter(list, "weights");
        Intrinsics.checkNotNullParameter(gradients, "gradients");

        int size = list.size();
        List<Operand<Float>> targets = new ArrayList<>(size);
        if (size == 0) {
            // Nothing to train: add no ops (not even the learning-rate constant) to the graph.
            return targets;
        }

        // The learning rate is identical for every variable, so a single Const node is shared by all
        // updates instead of adding one duplicate Const op to the graph per variable.
        Operand<Float> rate = ops.constant(Float.valueOf(this.learningRate), DtypeConversionUtilKt.getDType());
        ClipGradientAction clipGradient = getClipGradient();
        TrainOps trainOps = ops.train;

        for (int i = 0; i < size; i++) {
            Output<Float> dy = gradients.dy(i);
            Intrinsics.checkNotNullExpressionValue(dy, "gradients.dy(i)");
            ApplyGradientDescent<Float> update = trainOps.applyGradientDescent(
                    list.get(i),
                    rate,
                    clipGradient.clipGradient(ops, dy),
                    ApplyGradientDescent.useLocking(true));
            Intrinsics.checkNotNullExpressionValue(update, "tf.train.applyGradientDe…g(true)\n                )");
            targets.add(update);
        }
        return targets;
    }

    /** Returns the display name of this optimizer, {@code "SGD"}. */
    @Override
    @NotNull
    public String getOptimizerName() {
        return "SGD";
    }

    /**
     * Always returns {@code true}.
     * NOTE(review): presumably tells the framework this optimizer may run on a GPU; confirm
     * against {@code Optimizer}.
     */
    @Override
    public boolean isRunningOnGPU$tensorflow() {
        return true;
    }

    /** Creates an SGD optimizer with the default learning rate and no gradient clipping. */
    public SGD() {
        // Mask 3 (bits 0 and 1) makes the bridge constructor substitute both defaults.
        this(0.0f, null, 3, null);
    }
}
