package org.jetbrains.kotlinx.dl.api.core.optimizer;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.KGraph;
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv1D;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.jetbrains.kotlinx.dl.api.core.util.NameConventionsKt;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Gradients;
import org.tensorflow.op.core.Variable;

/* compiled from: Optimizer.kt */
@Metadata(mv = {Conv1D.EXTRA_DIM, 7, Conv1D.EXTRA_DIM}, k = Conv1D.EXTRA_DIM, xi = 48, d1 = {"��`\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0010\u000e\n\u0002\b\u0003\n\u0002\u0010%\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n\u0002\b\t\b&\u0018��2\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J@\u0010\u0013\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00150\u00142\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u00192\u0012\u0010\u001a\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\u00142\u0006\u0010\u001b\u001a\u00020\u001cH$J2\u0010\u001d\u001a\u00020\u001c2\u0006\u0010\u0018\u001a\u00020\u00192\f\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u00120\u00152\u0012\u0010\u001a\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\u0014H\u0002J#\u0010\u001f\u001a\u00020\f2\f\u0010 \u001a\b\u0012\u0004\u0012\u00020\u00120!2\u0006\u0010\"\u001a\u00020\fH\u0010¢\u0006\u0002\b#J<\u0010$\u001a\u00020%2\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u00192\f\u0010 
\u001a\b\u0012\u0004\u0012\u00020\u00120!2\u0006\u0010\"\u001a\u00020\f2\f\u0010&\u001a\b\u0012\u0004\u0012\u00020\u00120\u0015H\u0014J,\u0010'\u001a\u00020%2\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u00192\u0012\u0010(\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120!0\u0014H\u0014J\u001e\u0010)\u001a\b\u0012\u0004\u0012\u00020\u00120\u00112\u0006\u0010*\u001a\u00020\f2\u0006\u0010\"\u001a\u00020\fH\u0004JK\u0010+\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00150\u00142\u0006\u0010\u0016\u001a\u00020\u00172\u0012\u0010\u001a\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\u00142\u0006\u0010\u0018\u001a\u00020\u00192\f\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u00120\u0015H��¢\u0006\u0002\b,J(\u0010-\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120!0\u00142\u0012\u0010(\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\u0014H\u0002R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0005\u0010\u0006R\u0012\u0010\u0007\u001a\u00020\bX \u0004¢\u0006\u0006\u001a\u0004\b\t\u0010\nR\u0012\u0010\u000b\u001a\u00020\fX¦\u0004¢\u0006\u0006\u001a\u0004\b\r\u0010\u000eR,\u0010\u000f\u001a \u0012\u0004\u0012\u00020\f\u0012\u0016\u0012\u0014\u0012\u0004\u0012\u00020\f\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\u00100\u0010X\u0082.¢\u0006\u0002\n��¨\u0006."}, d2 = {"Lorg/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer;", "", "clipGradient", "Lorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;", "(Lorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;)V", "getClipGradient", "()Lorg/jetbrains/kotlinx/dl/api/core/optimizer/ClipGradientAction;", "isRunningOnGPU", "", "isRunningOnGPU$tensorflow", "()Z", "optimizerName", "", "getOptimizerName", "()Ljava/lang/String;", "slots", "", "Lorg/tensorflow/op/core/Variable;", "", "applyGradients", "", "Lorg/tensorflow/Operand;", "graph", "Lorg/jetbrains/kotlinx/dl/api/core/KGraph;", "tf", 
"Lorg/tensorflow/op/Ops;", "weights", "gradients", "Lorg/tensorflow/op/core/Gradients;", "computeGradients", "loss", "createName", "variable", "Lorg/tensorflow/Output;", "slotName", "createName$tensorflow", "createSlot", "", "initializer", "createSlots", "variables", "getSlot", "varName", "prepareTargets", "prepareTargets$tensorflow", "variablesToOutputs", "tensorflow"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer.class */
public abstract class Optimizer {

    /** Gradient clipping strategy supplied at construction time; returned by {@link #getClipGradient()}. */
    @NotNull
    private final ClipGradientAction clipGradient;

    /**
     * Slot variables created by {@link #createSlot}, keyed first by slot name and then by the op name
     * of the variable the slot belongs to.
     *
     * <p>Behaves like a Kotlin {@code lateinit} property: it is {@code null} until
     * {@link #prepareTargets$tensorflow} assigns a fresh map, and reading it earlier throws
     * {@code UninitializedPropertyAccessException}.
     */
    private Map<String, Map<String, Variable<Float>>> slots;

    /**
     * @param clipGradientAction gradient clipping strategy; must not be {@code null}
     */
    public Optimizer(@NotNull ClipGradientAction clipGradientAction) {
        Intrinsics.checkNotNullParameter(clipGradientAction, "clipGradient");
        this.clipGradient = clipGradientAction;
    }

    /** Returns the gradient clipping strategy this optimizer was constructed with. */
    @NotNull
    public final ClipGradientAction getClipGradient() {
        return this.clipGradient;
    }

    /**
     * Builds the training targets for one optimizer setup.
     *
     * <p>Steps, in order: reset the slot registry, compute the gradients of {@code operand} with
     * respect to {@code list}, let the subclass create its slots via {@link #createSlots}, and
     * delegate to {@link #applyGradients}.
     *
     * <p>The {@code $tensorflow} suffix appears to be Kotlin's name mangling for an {@code internal}
     * member of the {@code tensorflow} module.
     *
     * @param kGraph graph passed to slot creation and to {@link #applyGradients}
     * @param list weights to differentiate against
     * @param ops TensorFlow op builder
     * @param operand the loss operand
     * @return the operations produced by {@link #applyGradients}
     */
    @NotNull
    public final List<Operand<Float>> prepareTargets$tensorflow(@NotNull KGraph kGraph, @NotNull List<Variable<Float>> list, @NotNull Ops ops, @NotNull Operand<Float> operand) {
        Intrinsics.checkNotNullParameter(kGraph, "graph");
        Intrinsics.checkNotNullParameter(list, "weights");
        Intrinsics.checkNotNullParameter(ops, "tf");
        Intrinsics.checkNotNullParameter(operand, "loss");
        // Every call starts with an empty registry, so slots from a previous setup are forgotten.
        this.slots = new LinkedHashMap<>();
        Gradients gradients = computeGradients(ops, operand, list);
        createSlots(kGraph, ops, variablesToOutputs(list));
        return applyGradients(kGraph, ops, list, gradients);
    }

    /** Converts each variable to its output handle, preserving order. */
    private List<Output<Float>> variablesToOutputs(List<Variable<Float>> variables) {
        List<Output<Float>> outputs = new ArrayList<>(variables.size());
        for (Variable<Float> variable : variables) {
            Output<Float> output = variable.asOutput();
            Intrinsics.checkNotNullExpressionValue(output, "variables[i].asOutput()");
            outputs.add(output);
        }
        return outputs;
    }

    /**
     * Produces the update operations for {@code list} from the precomputed {@code gradients}.
     *
     * @param kGraph graph being trained
     * @param ops TensorFlow op builder
     * @param list weights to update, in the same order used to compute {@code gradients}
     * @param gradients result of differentiating the loss with respect to {@code list}
     * @return the update operations
     */
    @NotNull
    protected abstract List<Operand<Float>> applyGradients(@NotNull KGraph kGraph, @NotNull Ops ops, @NotNull List<Variable<Float>> list, @NotNull Gradients gradients);

    /** Builds the symbolic gradients of {@code loss} with respect to {@code weights}. */
    private Gradients computeGradients(Ops ops, Operand<Float> loss, List<Variable<Float>> weights) {
        Gradients gradients = ops.gradients(loss, weights, new Gradients.Options[0]);
        Intrinsics.checkNotNullExpressionValue(gradients, "tf.gradients(loss, weights)");
        return gradients;
    }

    /**
     * Hook for subclasses to create their slot variables, typically via {@link #createSlot}.
     * The default implementation creates none.
     *
     * @param kGraph graph being trained
     * @param ops TensorFlow op builder
     * @param list outputs of the weights being optimized
     */
    protected void createSlots(@NotNull KGraph kGraph, @NotNull Ops ops, @NotNull List<Output<Float>> list) {
        Intrinsics.checkNotNullParameter(kGraph, "graph");
        Intrinsics.checkNotNullParameter(ops, "tf");
        Intrinsics.checkNotNullParameter(list, "variables");
    }

    /** Returns the name of this optimizer. */
    @NotNull
    public abstract String getOptimizerName();

    /**
     * Creates a slot variable with the same shape as {@code output}.
     *
     * <p>The slot's assign op is registered with {@code kGraph} as an optimizer variable
     * initializer, and the slot itself as an optimizer variable. The slot is then recorded so that
     * {@link #getSlot} can find it by ({@code output}'s op name, {@code str}).
     *
     * <p>Declared {@code protected} in the Kotlin source; the decompiled Java form is public.
     *
     * @param kGraph graph the slot variable and its initializer are registered with
     * @param ops TensorFlow op builder
     * @param output weight the slot belongs to; its op name is the lookup key
     * @param str slot name
     * @param operand value the initializer op assigns to the slot
     * @throws RuntimeException Kotlin's {@code UninitializedPropertyAccessException} if
     *     {@link #prepareTargets$tensorflow} has not run yet
     */
    public void createSlot(@NotNull KGraph kGraph, @NotNull Ops ops, @NotNull Output<Float> output, @NotNull String str, @NotNull Operand<Float> operand) {
        Intrinsics.checkNotNullParameter(kGraph, "graph");
        Intrinsics.checkNotNullParameter(ops, "tf");
        Intrinsics.checkNotNullParameter(output, "variable");
        Intrinsics.checkNotNullParameter(str, "slotName");
        Intrinsics.checkNotNullParameter(operand, "initializer");

        // Compute the slot's name once; the assign op name is derived from it.
        String slotVariableName = createName$tensorflow(output, str);
        Variable<Float> variable = ops.withName(slotVariableName).variable(output.shape(), DtypeConversionUtilKt.getDType(), new Variable.Options[0]);
        Intrinsics.checkNotNullExpressionValue(variable, "tf.withName(createName).…able.shape(), getDType())");
        Assign<?> assign = ops.withName(NameConventionsKt.defaultAssignOpName(slotVariableName)).assign((Operand) variable, operand, new Assign.Options[0]);
        Intrinsics.checkNotNullExpressionValue(assign, "tf.withName(assignName).assign(slot, initializer)");
        kGraph.addOptimizerVariableInitializer(assign);
        kGraph.addOptimizerVariable(variable);

        String name = output.op().name();
        Map<String, Variable<Float>> slotsByVariable = requireSlots().computeIfAbsent(str, key -> new LinkedHashMap<>());
        Intrinsics.checkNotNullExpressionValue(name, "varName");
        slotsByVariable.put(name, variable);
    }

    /**
     * Looks up a slot previously created by {@link #createSlot}.
     *
     * <p>Declared {@code protected} in the Kotlin source; the decompiled Java form is public.
     *
     * @param str op name of the variable the slot belongs to
     * @param str2 slot name
     * @return the slot variable
     * @throws NullPointerException if no slot named {@code str2} exists for variable {@code str}
     * @throws RuntimeException Kotlin's {@code UninitializedPropertyAccessException} if
     *     {@link #prepareTargets$tensorflow} has not run yet
     */
    @NotNull
    public final Variable<Float> getSlot(@NotNull String str, @NotNull String str2) {
        Intrinsics.checkNotNullParameter(str, "varName");
        Intrinsics.checkNotNullParameter(str2, "slotName");
        Map<String, Variable<Float>> slotsByVariable = requireSlots().get(str2);
        if (slotsByVariable == null) {
            // Same exception type as before; the message now says what is missing.
            throw new NullPointerException("No slot named '" + str2 + "' has been created");
        }
        Variable<Float> slot = slotsByVariable.get(str);
        if (slot == null) {
            throw new NullPointerException("Slot '" + str2 + "' has not been created for variable '" + str + "'");
        }
        return slot;
    }

    /**
     * Builds the slot variable name from the variable's op name and the slot name, joined by
     * {@code '-'} and passed through {@code NameConventionsKt.defaultOptimizerVariableName}.
     *
     * @param output variable the slot belongs to
     * @param str slot name
     * @return the slot variable name
     */
    @NotNull
    public String createName$tensorflow(@NotNull Output<Float> output, @NotNull String str) {
        Intrinsics.checkNotNullParameter(output, "variable");
        Intrinsics.checkNotNullParameter(str, "slotName");
        return NameConventionsKt.defaultOptimizerVariableName(output.op().name() + '-' + str);
    }

    /** Reports whether this optimizer is running on a GPU; implemented by subclasses. */
    public abstract boolean isRunningOnGPU$tensorflow();

    /**
     * Returns the slot registry.
     *
     * <p>Fails the same way a Kotlin {@code lateinit} read would if
     * {@link #prepareTargets$tensorflow} has not run yet.
     */
    private Map<String, Map<String, Variable<Float>>> requireSlots() {
        Map<String, Map<String, Variable<Float>>> registry = this.slots;
        if (registry == null) {
            // Always throws UninitializedPropertyAccessException; never returns normally.
            Intrinsics.throwUninitializedPropertyAccessException("slots");
        }
        return registry;
    }
}
