package org.jetbrains.kotlinx.dl.api.core.optimizer;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.KGraph;
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv1D;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.jetbrains.kotlinx.dl.api.core.util.NameConventionsKt;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.Shape;
import org.tensorflow.op.MathOps;
import org.tensorflow.op.Ops;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Fill;
import org.tensorflow.op.core.Gradients;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyAdaMax;

/* compiled from: Adamax.kt */
@Metadata(mv = {Conv1D.EXTRA_DIM, 7, Conv1D.EXTRA_DIM}, k = Conv1D.EXTRA_DIM, xi = 48, d1 = {"��d\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0007\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u000b\n\u0002\b\u0004\n\u0002\u0010\u000e\n\u0002\b\u0003\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��2\u00020\u0001B7\u0012\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0005\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0006\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0007\u001a\u00020\b¢\u0006\u0002\u0010\tJ@\u0010\u0019\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u001b0\u001a2\u0006\u0010\u001c\u001a\u00020\u001d2\u0006\u0010\u001e\u001a\u00020\u001f2\u0012\u0010 \u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\r0\u001a2\u0006\u0010!\u001a\u00020\"H\u0014J&\u0010#\u001a\u00020$2\u0006\u0010\u001c\u001a\u00020\u001d2\u0006\u0010\u001e\u001a\u00020\u001f2\f\u0010%\u001a\b\u0012\u0004\u0012\u00020\u00030&H\u0002J,\u0010'\u001a\u00020$2\u0006\u0010\u001c\u001a\u00020\u001d2\u0006\u0010\u001e\u001a\u00020\u001f2\u0012\u0010(\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030&0\u001aH\u0014R\u000e\u0010\u0004\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0005\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00030\u000bX\u0082.¢\u0006\u0002\n��R\u0014\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u00030\rX\u0082.¢\u0006\u0002\n��R\u0014\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\u00030\u000bX\u0082.¢\u0006\u0002\n��R\u000e\u0010\u0006\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000f\u001a\b\u0012\u0004\u0012\u00020\u00030\u000bX\u0082.¢\u0006\u0002\n��R\u0014\u0010\u00
10\u001a\u00020\u00118PX\u0090\u0004¢\u0006\u0006\u001a\u0004\b\u0012\u0010\u0013R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00030\u000bX\u0082.¢\u0006\u0002\n��R\u0014\u0010\u0015\u001a\u00020\u00168VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u0017\u0010\u0018¨\u0006)"}, d2 = {"Lorg/jetbrains/kotlinx/dl/api/core/optimizer/Adamax;", "Lorg/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer;", "learningRate", "", "beta1", "beta2", "epsilon", "clipGradient", "Lorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;", "(FFFFLorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;)V", "betaOneConst", "Lorg/tensorflow/op/core/Constant;", "betaOnePower", "Lorg/tensorflow/op/core/Variable;", "betaTwoConst", "epsilonConstant", "isRunningOnGPU", "", "isRunningOnGPU$tensorflow", "()Z", "learningRateConst", "optimizerName", "", "getOptimizerName", "()Ljava/lang/String;", "applyGradients", "", "Lorg/tensorflow/Operand;", "graph", "Lorg/jetbrains/kotlinx/dl/api/core/KGraph;", "tf", "Lorg/tensorflow/op/Ops;", "weights", "gradients", "Lorg/tensorflow/op/core/Gradients;", "createAdamaxSlot", "", "v", "Lorg/tensorflow/Output;", "createSlots", "variables", "tensorflow"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/optimizer/Adamax.class */
/*
 * Adamax optimizer that emits TensorFlow's ApplyAdaMax op for every trainable weight.
 *
 * Per weight it keeps two zero-initialised slots, "m" and "v". One shared scalar variable named
 * AdamaxKt.FIRST_BETA_POWER_NAME is initialised to beta1 and fed to every ApplyAdaMax op as its
 * beta1-power input. Hyper-parameters are validated eagerly in the constructor.
 */
public final class Adamax extends Optimizer {
    /** Slot name of the per-weight first-moment accumulator. */
    private static final String FIRST_MOMENT = "m";
    /** Slot name of the per-weight second-moment accumulator. */
    private static final String SECOND_MOMENT = "v";

    private final float learningRate;
    private final float beta1;
    private final float beta2;
    private final float epsilon;

    // Graph constants; null until applyGradients has run.
    private Constant<Float> epsilonConstant;
    private Constant<Float> learningRateConst;
    private Constant<Float> betaOneConst;
    private Constant<Float> betaTwoConst;

    // Shared scalar beta1-power variable; null until createSlots has run.
    private Variable<Float> betaOnePower;

    /**
     * Creates an Adamax optimizer.
     *
     * @param learningRate learning rate; must be {@code >= 0}
     * @param beta1 beta1 hyper-parameter; must lie in the open interval (0, 1)
     * @param beta2 beta2 hyper-parameter; must lie in the open interval (0, 1)
     * @param epsilon epsilon hyper-parameter; must be {@code >= 0}
     * @param clipGradient clipping strategy applied to every gradient before the update
     * @throws IllegalArgumentException if a hyper-parameter is out of range (NaN is rejected too)
     * @throws NullPointerException if {@code clipGradient} is null
     */
    public Adamax(float learningRate, float beta1, float beta2, float epsilon, @NotNull ClipGradientAction clipGradient) {
        // Java requires super(...) first; Kotlin performed this null check before delegating.
        super(clipGradient);
        Intrinsics.checkNotNullParameter(clipGradient, "clipGradient");
        this.learningRate = learningRate;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.epsilon = epsilon;
        // Conditions are phrased as !(valid) so that NaN inputs are rejected as well.
        if (!(learningRate >= 0.0f)) {
            throw new IllegalArgumentException("Learning rate " + learningRate + " should be >= 0.0.");
        }
        if (!(beta1 > 0.0f && beta1 < 1.0f)) {
            throw new IllegalArgumentException("Beta1 " + beta1 + " should be in range (0.0; 1.0).");
        }
        if (!(beta2 > 0.0f && beta2 < 1.0f)) {
            throw new IllegalArgumentException("Beta2 " + beta2 + " should be in range (0.0; 1.0).");
        }
        if (!(epsilon >= 0.0f)) {
            throw new IllegalArgumentException("Epsilon " + epsilon + " should be >= 0.0.");
        }
    }

    /**
     * Kotlin default-arguments bridge. Each bit set in {@code mask} replaces the matching argument
     * with its default: 0.001, 0.9, 0.999, 1e-7 and a new {@link NoClipGradient}.
     */
    public /* synthetic */ Adamax(float learningRate, float beta1, float beta2, float epsilon, ClipGradientAction clipGradient, int mask, DefaultConstructorMarker marker) {
        this((mask & 1) != 0 ? 0.001f : learningRate,
                (mask & 2) != 0 ? 0.9f : beta1,
                (mask & 4) != 0 ? 0.999f : beta2,
                (mask & 8) != 0 ? 1.0E-7f : epsilon,
                (mask & 16) != 0 ? new NoClipGradient() : clipGradient);
    }

    /** Creates an Adamax optimizer with every hyper-parameter at its default value. */
    public Adamax() {
        // 31 == 0b11111: take the default for all five arguments.
        this(0.0f, 0.0f, 0.0f, 0.0f, null, 31, null);
    }

    /**
     * Builds one ApplyAdaMax update op per weight.
     *
     * <p>Must run after {@link #createSlots}, which creates the "m"/"v" slots and the beta1-power
     * variable consumed here.
     *
     * @param kGraph graph receiving the ops and the beta1-power initializer
     * @param ops TensorFlow op builder
     * @param list trainable weights; {@code list.get(i)} is paired with {@code gradients.dy(i)}
     * @param gradients gradients of the loss with respect to {@code list}
     * @return the ApplyAdaMax ops, in the same order as {@code list}
     * @throws kotlin.UninitializedPropertyAccessException if createSlots has not been called
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.optimizer.Optimizer
    @NotNull
    protected List<Operand<Float>> applyGradients(@NotNull KGraph kGraph, @NotNull Ops ops, @NotNull List<Variable<Float>> list, @NotNull Gradients gradients) {
        Intrinsics.checkNotNullParameter(kGraph, "graph");
        Intrinsics.checkNotNullParameter(ops, "tf");
        Intrinsics.checkNotNullParameter(list, "weights");
        Intrinsics.checkNotNullParameter(gradients, "gradients");

        // Keep this creation order: unnamed ops derive their graph names from creation order.
        Constant<Float> betaOne = floatConstant(ops, this.beta1, "tf.constant(beta1, getDType())");
        Constant<Float> betaTwo = floatConstant(ops, this.beta2, "tf.constant(beta2, getDType())");
        Constant<Float> rate = floatConstant(ops, this.learningRate, "tf.constant(learningRate, getDType())");
        Constant<Float> eps = floatConstant(ops, this.epsilon, "tf.constant(epsilon, getDType())");
        this.betaOneConst = betaOne;
        this.betaTwoConst = betaTwo;
        this.learningRateConst = rate;
        this.epsilonConstant = eps;

        Variable<Float> betaPower = initialized(this.betaOnePower, "betaOnePower");
        Scope scope = new Scope(kGraph.getTfGraph$tensorflow());

        int size = list.size();
        List<Operand<Float>> targets = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            Variable<Float> weight = list.get(i);
            String varName = weight.ref().op().name();
            Intrinsics.checkNotNullExpressionValue(varName, "varName");
            Operand<Float> firstMoment = getSlot(varName, FIRST_MOMENT);
            Operand<Float> secondMoment = getSlot(varName, SECOND_MOMENT);
            Output<Float> dy = gradients.dy(i);
            Intrinsics.checkNotNullExpressionValue(dy, "gradients.dy(i)");
            ApplyAdaMax<Float> update = ApplyAdaMax.create(
                    scope,
                    weight,
                    firstMoment,
                    secondMoment,
                    betaPower,
                    rate,
                    betaOne,
                    betaTwo,
                    eps,
                    getClipGradient().clipGradient(ops, dy),
                    ApplyAdaMax.useLocking(true));
            Intrinsics.checkNotNullExpressionValue(update, "create(\n                …g(true)\n                )");
            targets.add(update);
        }

        // NOTE(review): the beta1^t update below is registered via addOptimizerVariableInitializer
        // and is NOT part of the returned per-step targets. If initializers run only once rather
        // than every training step, beta1^t never advances — confirm against KGraph/Optimizer.
        Assign<?> betaOnePowerInit = ops.assign(betaPower, ops.math.mul(betaPower, betaOne));
        Intrinsics.checkNotNullExpressionValue(betaOnePowerInit, "betaOnePowerInit");
        kGraph.addOptimizerVariableInitializer(betaOnePowerInit);
        kGraph.addOptimizerVariable(betaPower);
        return targets;
    }

    /** Creates the zero-initialised "m" and "v" slots for one weight. */
    private void createAdamaxSlot(KGraph kGraph, Ops ops, Output<Float> output) {
        createZeroSlot(kGraph, ops, output, FIRST_MOMENT, "firstMomentInitializer");
        createZeroSlot(kGraph, ops, output, SECOND_MOMENT, "secondMomentInitializer");
    }

    /**
     * Creates slot {@code slotName} for {@code output}. The slot is initialised by a Fill op of
     * zeros shaped like {@code output}, named after the slot's default initializer name.
     *
     * @param initializerLabel text reported if the initializer op is unexpectedly null
     */
    private void createZeroSlot(KGraph kGraph, Ops ops, Output<Float> output, String slotName, String initializerLabel) {
        Operand<Float> initializer = ops
                .withName(NameConventionsKt.defaultInitializerOpName(createName$tensorflow(output, slotName)))
                .fill(ops.shape(output), ops.constant(Float.valueOf(0.0f), DtypeConversionUtilKt.getDType()));
        Output<Float> slotOwner = output.asOutput();
        Intrinsics.checkNotNullExpressionValue(slotOwner, "v.asOutput()");
        Intrinsics.checkNotNullExpressionValue(initializer, initializerLabel);
        createSlot(kGraph, ops, slotOwner, slotName, initializer);
    }

    /**
     * Creates the per-weight "m"/"v" slots and the shared beta1-power variable.
     *
     * <p>The beta1-power variable is a scalar named {@code AdamaxKt.FIRST_BETA_POWER_NAME}. An assign
     * op setting it to {@code beta1} is registered through addOptimizerVariableInitializer.
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.optimizer.Optimizer
    protected void createSlots(@NotNull KGraph kGraph, @NotNull Ops ops, @NotNull List<Output<Float>> list) {
        Intrinsics.checkNotNullParameter(kGraph, "graph");
        Intrinsics.checkNotNullParameter(ops, "tf");
        Intrinsics.checkNotNullParameter(list, "variables");
        for (Output<Float> variable : list) {
            Output<Float> output = variable.asOutput();
            Intrinsics.checkNotNullExpressionValue(output, "v.asOutput()");
            createAdamaxSlot(kGraph, ops, output);
        }

        String powerName = AdamaxKt.FIRST_BETA_POWER_NAME;
        Variable<Float> power = ops.withName(powerName).variable(Shape.scalar(), DtypeConversionUtilKt.getDType(), new Variable.Options[0]);
        Intrinsics.checkNotNullExpressionValue(power, "tf.withName(FIRST_BETA_P…ape.scalar(), getDType())");
        this.betaOnePower = power;

        Constant<Float> initialPower = ops
                .withName(NameConventionsKt.defaultInitializerOpName(powerName))
                .constant(Float.valueOf(this.beta1), DtypeConversionUtilKt.getDType());
        Assign<?> assign = ops
                .withName(NameConventionsKt.defaultAssignOpName(powerName))
                .assign(power, initialPower, new Assign.Options[0]);
        Intrinsics.checkNotNullExpressionValue(assign, "tf.withName(betaOnePower…getDType())\n            )");
        kGraph.addOptimizerVariableInitializer(assign);
    }

    /** Returns the display name of this optimizer, {@code "Adamax"}. */
    @Override // org.jetbrains.kotlinx.dl.api.core.optimizer.Optimizer
    @NotNull
    public String getOptimizerName() {
        return "Adamax";
    }

    /** Always returns {@code false} for this optimizer. */
    @Override // org.jetbrains.kotlinx.dl.api.core.optimizer.Optimizer
    public boolean isRunningOnGPU$tensorflow() {
        return false;
    }

    /** Creates a constant holding {@code value} with the optimizer dtype and null-checks it. */
    private static Constant<Float> floatConstant(Ops ops, float value, String expression) {
        Constant<Float> constant = ops.constant(Float.valueOf(value), DtypeConversionUtilKt.getDType());
        Intrinsics.checkNotNullExpressionValue(constant, expression);
        return constant;
    }

    /**
     * Returns {@code value}, or throws Kotlin's UninitializedPropertyAccessException naming
     * {@code propertyName} when it is still null.
     */
    private static <T> T initialized(T value, String propertyName) {
        if (value == null) {
            Intrinsics.throwUninitializedPropertyAccessException(propertyName);
        }
        return value;
    }
}
