package org.jetbrains.kotlinx.dl.api.core.layer.merge;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import kotlin.Metadata;
import kotlin.NoWhenBranchMatchedException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv1D;
import org.tensorflow.Operand;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.Concat;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.ExpandDims;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Shape;
import org.tensorflow.op.core.Squeeze;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.linalg.Transpose;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.math.Rsqrt;
import org.tensorflow.op.math.Square;

/* compiled from: Dot.kt */
@Metadata(mv = {Conv1D.EXTRA_DIM, 7, Conv1D.EXTRA_DIM}, k = 2, xi = 48, d1 = {"��\u001c\n��\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0015\n\u0002\b\u0002\u001a:\u0010��\u001a\b\u0012\u0004\u0012\u00020\u00020\u00012\b\u0010\u0003\u001a\u0004\u0018\u00010\u00042\f\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00020\u00012\f\u0010\u0006\u001a\b\u0012\u0004\u0012\u00020\u00020\u00012\u0006\u0010\u0007\u001a\u00020\b\u001a.\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u00020\u00012\b\u0010\u0003\u001a\u0004\u0018\u00010\u00042\f\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00020\u00012\b\u0010\u0007\u001a\u0004\u0018\u00010\b¨\u0006\n"}, d2 = {"batchDot", "Lorg/tensorflow/Operand;", "", "scope", "Lorg/tensorflow/op/Scope;", "x", "y", "axis", "", "l2Normalize", "tensorflow"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/layer/merge/DotKt.class */
public final class DotKt {
    // Lower bound for the squared L2 norm so that rsqrt is never applied to zero.
    private static final float L2_NORM_EPSILON = 1.0E-12f;

    private DotKt() {
        // Holder for static graph-building helpers; not instantiable.
    }

    /**
     * Normalizes {@code operand} by its L2 norm along the given axes.
     *
     * <p>Builds {@code x * rsqrt(max(reduceSum(square(x), axes, keepDims = true), 1e-12))}.
     *
     * @param scope   scope the ops are created in
     * @param operand tensor to normalize ({@code x})
     * @param iArr    axes to reduce over; {@code null} reduces over every axis of {@code operand}
     * @return {@code operand} multiplied by the reciprocal square root of its clamped sum of squares
     * @throws NullPointerException     if {@code operand} is null
     * @throws IllegalArgumentException if {@code iArr} is null and the rank of {@code operand} is not
     *                                  known statically
     */
    @NotNull
    public static final Operand<Float> l2Normalize(
            @Nullable Scope scope, @NotNull Operand<Float> operand, @Nullable int[] iArr) {
        Objects.requireNonNull(operand, "x");
        // A null axis array cannot be turned into a constant tensor, so expand "no axes given"
        // into an explicit list of every axis.
        int[] axes = iArr != null ? iArr : allAxes(operand);
        Operand<Float> squareSum = ReduceSum.create(
                scope, Square.create(scope, operand), Constant.create(scope, axes), ReduceSum.keepDims(true));
        Operand<Float> invNorm = Rsqrt.create(
                scope,
                org.tensorflow.op.math.Maximum.create(scope, squareSum, Constant.create(scope, L2_NORM_EPSILON)));
        return Mul.create(scope, operand, invNorm);
    }

    /**
     * Builds a batch-wise dot product of {@code x} and {@code y} over the axes given in {@code iArr}.
     *
     * <p>If the inputs differ in rank, the lower-rank input is reshaped by appending size-1
     * dimensions. When both (padded) inputs are rank 2, the result is an element-wise product reduced
     * with {@code ReduceSum}; otherwise it is a {@code MatMul} with transposition chosen from the axes.
     * The padded dimensions are then squeezed out of the result, and a rank-1 result is expanded to
     * rank 2.
     *
     * @param scope    scope the ops are created in
     * @param operand  first input ({@code x})
     * @param operand2 second input ({@code y})
     * @param iArr     {@code [axisX, axisY]}; only the first two entries are read
     * @return the batched dot product
     * @throws NullPointerException     if any non-null parameter is null
     * @throws IllegalArgumentException if {@code iArr} has fewer than two entries
     */
    @NotNull
    public static final Operand<Float> batchDot(
            @Nullable Scope scope, @NotNull Operand<Float> operand, @NotNull Operand<Float> operand2,
            @NotNull int[] iArr) {
        Objects.requireNonNull(operand, "x");
        Objects.requireNonNull(operand2, "y");
        Objects.requireNonNull(iArr, "axis");
        if (iArr.length < 2) {
            throw new IllegalArgumentException("axis must contain two entries, got " + iArr.length);
        }

        int xRank = rank(operand);
        int yRank = rank(operand2);
        int diff = Math.abs(xRank - yRank);

        // Pad the lower-rank input with trailing size-1 dims so both inputs have equal rank.
        Operand<Float> x = xRank < yRank ? padTrailingOnes(scope, operand, diff) : operand;
        Operand<Float> y = yRank < xRank ? padTrailingOnes(scope, operand2, diff) : operand2;

        int xPaddedRank = rank(x);
        int yPaddedRank = rank(y);
        Operand<Float> out;
        if (xPaddedRank == 2 && yPaddedRank == 2) {
            if (iArr[0] == iArr[1]) {
                out = ReduceSum.create(scope, Mul.create(scope, x, y), Constant.create(scope, iArr[0]));
            } else {
                Operand<Float> xT = Transpose.create(scope, x, Constant.create(scope, new int[] {1, 0}));
                out = ReduceSum.create(scope, Mul.create(scope, xT, y), Constant.create(scope, iArr[1]));
            }
        } else {
            // NOTE(review): a null Boolean option is assumed to leave the attribute at the op's
            // default (no transpose); confirm against the MatMul.Options implementation.
            Boolean transposeX = iArr[0] == xPaddedRank - 1 ? null : Boolean.TRUE;
            Boolean transposeY = iArr[1] == yPaddedRank - 1 ? Boolean.TRUE : null;
            out = MatMul.create(scope, x, y, MatMul.transposeA(transposeX), MatMul.transposeB(transposeY));
        }

        if (diff != 0) {
            // Squeeze `out` itself (not a constant of indices) over the integer axes
            // [start, start + diff), intended to be the size-1 dims introduced by padding.
            int start = xRank > yRank ? xRank + yRank - 3 : xRank - 1;
            List<Long> squeezeAxes = new ArrayList<>(diff);
            for (int i = 0; i < diff; i++) {
                squeezeAxes.add((long) (start + i));
            }
            out = Squeeze.create(scope, out, Squeeze.axis(squeezeAxes));
        }
        if (rank(out) == 1) {
            // Keep the expanded result; the op is useless unless it replaces `out`.
            out = ExpandDims.create(scope, out, Constant.create(scope, 1));
        }
        return out;
    }

    // Static rank of an operand as reported by its output shape (-1 when unknown).
    private static int rank(Operand<?> t) {
        return t.asOutput().shape().numDimensions();
    }

    // Returns {0, 1, ..., rank - 1} for the statically known rank of `t`.
    private static int[] allAxes(Operand<Float> t) {
        int r = rank(t);
        if (r < 0) {
            throw new IllegalArgumentException("axis is null and the rank of x is unknown");
        }
        int[] axes = new int[r];
        for (int i = 0; i < r; i++) {
            axes[i] = i;
        }
        return axes;
    }

    // Reshapes `t` to shape(t) ++ [1] * count. The ones are a single 1-D constant so that every
    // Concat input has the same rank as Shape(t).
    private static Operand<Float> padTrailingOnes(Scope scope, Operand<Float> t, int count) {
        int[] ones = new int[count];
        Arrays.fill(ones, 1);
        List<Operand<Integer>> parts = new ArrayList<>(2);
        parts.add(Shape.create(scope, t));
        parts.add(Constant.create(scope, ones));
        return Reshape.create(scope, t, Concat.create(scope, parts, Constant.create(scope, 0)));
    }
}
