package org.jetbrains.kotlinx.dl.api.core.initializer;

import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv1D;
import org.jetbrains.kotlinx.dl.api.core.shape.ShapeFunctionsKt;
import org.jetbrains.kotlinx.dl.api.core.util.DtypeConversionUtilKt;
import org.tensorflow.Operand;
import org.tensorflow.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.linalg.Qr;
import org.tensorflow.op.linalg.Transpose;

/* compiled from: Orthogonal.kt */
@Metadata(mv = {Conv1D.EXTRA_DIM, 8, 0}, k = Conv1D.EXTRA_DIM, xi = 48, d1 = {"��6\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0007\n��\n\u0002\u0010\t\n\u0002\b\u0006\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000e\n\u0002\b\u0002\u0018��2\u00020\u0001B\u0019\u0012\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J<\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\u00030\f2\u0006\u0010\r\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u000e2\u0006\u0010\u0010\u001a\u00020\u00112\f\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u000e0\f2\u0006\u0010\u0013\u001a\u00020\u0014H\u0016J\b\u0010\u0015\u001a\u00020\u0014H\u0016R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\bR\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\t\u0010\n¨\u0006\u0016"}, d2 = {"Lorg/jetbrains/kotlinx/dl/api/core/initializer/Orthogonal;", "Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "gain", "", "seed", "", "(FJ)V", "getGain", "()F", "getSeed", "()J", "initialize", "Lorg/tensorflow/Operand;", "fanIn", "", "fanOut", "tf", "Lorg/tensorflow/op/Ops;", "shape", "name", "", "toString", "tensorflow"})
@SourceDebugExtension({"SMAP\nOrthogonal.kt\nKotlin\n*S Kotlin\n*F\n+ 1 Orthogonal.kt\norg/jetbrains/kotlinx/dl/api/core/initializer/Orthogonal\n+ 2 fake.kt\nkotlin/jvm/internal/FakeKt\n*L\n1#1,75:1\n1#2:76\n*E\n"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/initializer/Orthogonal.class */
/**
 * Initializer that produces a (semi-)orthogonal weight tensor scaled by {@code gain}.
 *
 * <p>A tensor of the requested shape is sampled from a stateless normal distribution seeded with
 * {@code seed}. It is flattened into a 2-D matrix whose column count is the last dimension of the
 * requested shape. That matrix is orthogonalized with a reduced QR decomposition, reshaped back to
 * the requested shape and multiplied by {@code gain}.
 */
public final class Orthogonal extends Initializer {
    /** Multiplicative factor applied to the orthogonal tensor. */
    private final float gain;
    /** First element of the two-element seed {@code [seed, 0]} given to the stateless normal op. */
    private final long seed;

    /**
     * Creates an orthogonal initializer.
     *
     * @param gain multiplicative factor applied to the resulting orthogonal tensor
     * @param seed seed for the stateless random normal sampling
     */
    public Orthogonal(float gain, long seed) {
        this.gain = gain;
        this.seed = seed;
    }

    /**
     * Kotlin default-arguments bridge: bit 0 of {@code mask} selects the default gain (1.0f),
     * bit 1 selects the default seed (12L).
     */
    public /* synthetic */ Orthogonal(float gain, long seed, int mask, DefaultConstructorMarker marker) {
        this((mask & 1) != 0 ? 1.0f : gain, (mask & 2) != 0 ? 12L : seed);
    }

    /** @return the multiplicative factor applied to the orthogonal tensor */
    public final float getGain() {
        return this.gain;
    }

    /** @return the seed used for the stateless random normal sampling */
    public final long getSeed() {
        return this.seed;
    }

    /**
     * Builds the graph ops that produce the orthogonal initial value.
     *
     * <p>NOTE(review): this relies on the static shape of the sampled tensor being fully known at
     * graph-construction time (i.e. {@code shape} is presumably a constant). Unknown dimensions
     * (reported as -1) would corrupt the row/column computation. Confirm against callers.
     *
     * @param fanIn  not used by this initializer
     * @param fanOut not used by this initializer
     * @param tf     op builder
     * @param shape  1-D tensor listing the dimensions of the tensor to initialize
     * @param name   name applied (via {@code tf.withName}) to the sign-corrected Q product, and to
     *               the transpose when one is needed
     * @return {@code gain} times an orthogonal tensor of the requested shape
     * @throws IllegalArgumentException if the requested shape has fewer than two dimensions
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
    @NotNull
    public Operand<Float> initialize(int fanIn, int fanOut, @NotNull Ops tf, @NotNull Operand<Integer> shape, @NotNull String name) {
        Intrinsics.checkNotNullParameter(tf, "tf");
        Intrinsics.checkNotNullParameter(shape, "shape");
        Intrinsics.checkNotNullParameter(name, "name");

        // `shape` is a 1-D tensor of dimensions, so its length is the rank of the target tensor.
        long rank = shape.asOutput().shape().size(0);
        if (rank < 2) {
            throw new IllegalArgumentException("The tensor to initialize must be at least two-dimensional");
        }

        Operand<Float> distOpND = tf.random.statelessRandomNormal(
                shape, tf.constant(new long[]{this.seed, 0L}), DtypeConversionUtilKt.getDType());

        // Flatten to 2-D: rows = product of every dimension except the last, cols = last dimension.
        // Keeping the last dimension intact lets the same logic serve e.g. convolution kernels.
        Shape sampledShape = distOpND.asOutput().shape();
        int lastAxis = (int) (rank - 1);
        long numRows = 1L;
        for (int axis = 0; axis < lastAxis; axis++) {
            numRows *= sampledShape.size(axis);
        }
        // Bug fix: the column count is the LAST dimension. The previous code read the
        // second-to-last one, which broke every non-square shape (element counts stopped matching).
        long numCols = sampledShape.size(lastAxis);

        // QR is taken on a "tall" matrix (more rows than columns); transposed back below if needed.
        Shape flatShape = Shape.make(Math.max(numRows, numCols), Math.min(numRows, numCols));
        Operand<Float> flat = tf.reshape(distOpND, ShapeFunctionsKt.shapeOperand(tf, flatShape));

        Qr<Float> qr = tf.linalg.qr(flat, Qr.fullMatrices(false));
        // Scale each column of Q by the sign of the matching diagonal entry of R, which removes the
        // sign ambiguity of the QR decomposition.
        Operand<Float> diagonal = tf.linalg.tensorDiagPart(qr.r());
        Operand<Float> orthogonal = tf.withName(name).math.mul(qr.q(), tf.math.sign(diagonal));
        if (numRows < numCols) {
            // The flat matrix was laid out as (cols, rows); swap back to (rows, cols).
            Transpose<Float> transposed = tf.withName(name).linalg.transpose(orthogonal, tf.constant(new int[]{1, 0}));
            orthogonal = transposed;
        }

        Operand<Float> reshaped = tf.reshape(orthogonal, shape);
        Cast<Float> gainOp = tf.dtypes.cast(tf.constant(this.gain), DtypeConversionUtilKt.getDType());
        return tf.math.mul(reshaped, gainOp);
    }

    @NotNull
    public String toString() {
        return "Orthogonal(gain=" + this.gain + ", seed=" + this.seed + ')';
    }

    /** Creates an initializer with the default gain (1.0f) and seed (12L). */
    public Orthogonal() {
        this(0.0f, 0L, 3, null);
    }
}
