package org.jetbrains.kotlinx.dl.api.core.layer.normalization;

import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer;
import org.jetbrains.kotlinx.dl.api.core.initializer.Ones;
import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros;
import org.jetbrains.kotlinx.dl.api.core.layer.KVariable;
import org.jetbrains.kotlinx.dl.api.core.layer.KVariableKt;
import org.jetbrains.kotlinx.dl.api.core.layer.Layer;
import org.jetbrains.kotlinx.dl.api.core.layer.NoGradients;
import org.jetbrains.kotlinx.dl.api.core.layer.ParametrizedLayer;
import org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayerKt;
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv1D;
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer;
import org.jetbrains.kotlinx.dl.api.core.util.NameConventionsKt;
import org.tensorflow.Operand;
import org.tensorflow.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Mul;

/* compiled from: BatchNorm.kt */
@Metadata(mv = {Conv1D.EXTRA_DIM, 7, Conv1D.EXTRA_DIM}, k = Conv1D.EXTRA_DIM, xi = 48, d1 = {"��b\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0010\b\n��\n\u0002\u0010\u0006\n��\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u000e\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b \n\u0002\u0018\u0002\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\b\u0018��2\u00020\u00012\u00020\u00022\u00020\u0003B\u0087\u0001\u0012\u000e\b\u0002\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\b\b\u0002\u0010\u0007\u001a\u00020\b\u0012\b\b\u0002\u0010\t\u001a\u00020\n\u0012\b\b\u0002\u0010\u000b\u001a\u00020\b\u0012\b\b\u0002\u0010\f\u001a\u00020\n\u0012\b\b\u0002\u0010\r\u001a\u00020\u000e\u0012\b\b\u0002\u0010\u000f\u001a\u00020\u000e\u0012\n\b\u0002\u0010\u0010\u001a\u0004\u0018\u00010\u0011\u0012\n\b\u0002\u0010\u0012\u001a\u0004\u0018\u00010\u0011\u0012\b\b\u0002\u0010\u0013\u001a\u00020\u000e\u0012\b\b\u0002\u0010\u0014\u001a\u00020\u000e\u0012\b\b\u0002\u0010\u0015\u001a\u00020\u0016¢\u0006\u0002\u0010\u0017Jn\u0010;\u001a\b\u0012\u0004\u0012\u00020=0<2\u0006\u0010>\u001a\u00020?2\f\u0010@\u001a\b\u0012\u0004\u0012\u00020=0<2\u000e\u0010(\u001a\n\u0012\u0004\u0012\u00020=\u0018\u00010A2\u000e\u0010\u001a\u001a\n\u0012\u0004\u0012\u00020=\u0018\u00010<2\f\u00100\u001a\b\u0012\u0004\u0012\u00020=0<2\f\u0010B\u001a\b\u0012\u0004\u0012\u00020=0<2\f\u0010C\u001a\b\u0012\u0004\u0012\u00020=0<H\u0002JB\u0010D\u001a\b\u0012\u0004\u0012\u00020=0<2\u0006\u0010>\u001a\u00020?2\f\u0010E\u001a\b\u0012\u0004\u0012\u00020=0<2\f\u0010F\u001a\b\u0012\u0004\u0012\u00020\n0<2\u000e\u0010G\u001a\n\u0012\u0004\u0012\u00020=\u0018\u00010<H\u0016J\b\u0010H\u001a\u00020\u0016H\u0016R\u0017\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005¢\u0006\b\n��\u001a\u0004\b\u0018\u0010\u0019R\u001c\u0010\u001a\u001a\u0004\
u0018\u00010\u001bX\u0080\u000e¢\u0006\u000e\n��\u001a\u0004\b\u001c\u0010\u001d\"\u0004\b\u001e\u0010\u001fR\u0011\u0010\u000f\u001a\u00020\u000e¢\u0006\b\n��\u001a\u0004\b \u0010!R\u0013\u0010\u0012\u001a\u0004\u0018\u00010\u0011¢\u0006\b\n��\u001a\u0004\b\"\u0010#R\u0011\u0010\t\u001a\u00020\n¢\u0006\b\n��\u001a\u0004\b$\u0010%R\u0011\u0010\u000b\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b&\u0010'R\u001c\u0010(\u001a\u0004\u0018\u00010\u001bX\u0080\u000e¢\u0006\u000e\n��\u001a\u0004\b)\u0010\u001d\"\u0004\b*\u0010\u001fR\u0011\u0010\r\u001a\u00020\u000e¢\u0006\b\n��\u001a\u0004\b+\u0010!R\u0013\u0010\u0010\u001a\u0004\u0018\u00010\u0011¢\u0006\b\n��\u001a\u0004\b,\u0010#R\u0014\u0010-\u001a\u00020\n8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b.\u0010%R\u0011\u0010\u0007\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b/\u0010'R\u001a\u00100\u001a\u00020\u001bX\u0080.¢\u0006\u000e\n��\u001a\u0004\b1\u0010\u001d\"\u0004\b2\u0010\u001fR\u0011\u0010\u0013\u001a\u00020\u000e¢\u0006\b\n��\u001a\u0004\b3\u0010!R\u001a\u00104\u001a\u00020\u001bX\u0080.¢\u0006\u000e\n��\u001a\u0004\b5\u0010\u001d\"\u0004\b6\u0010\u001fR\u0011\u0010\u0014\u001a\u00020\u000e¢\u0006\b\n��\u001a\u0004\b7\u0010!R\u0011\u0010\f\u001a\u00020\n¢\u0006\b\n��\u001a\u0004\b8\u0010%R\u001a\u00109\u001a\b\u0012\u0004\u0012\u00020\u001b0\u00058VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b:\u0010\u0019¨\u0006I"}, d2 = {"Lorg/jetbrains/kotlinx/dl/api/core/layer/normalization/BatchNorm;", "Lorg/jetbrains/kotlinx/dl/api/core/layer/Layer;", "Lorg/jetbrains/kotlinx/dl/api/core/layer/NoGradients;", "Lorg/jetbrains/kotlinx/dl/api/core/layer/ParametrizedLayer;", "axis", "", "", "momentum", "", "center", "", "epsilon", "scale", "gammaInitializer", "Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "betaInitializer", "gammaRegularizer", "Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;", "betaRegularizer", "movingMeanInitializer", "movingVarianceInitializer", "name", "", 
"(Ljava/util/List;DZDZLorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;Ljava/lang/String;)V", "getAxis", "()Ljava/util/List;", "beta", "Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;", "getBeta$tensorflow", "()Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;", "setBeta$tensorflow", "(Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;)V", "getBetaInitializer", "()Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "getBetaRegularizer", "()Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;", "getCenter", "()Z", "getEpsilon", "()D", "gamma", "getGamma$tensorflow", "setGamma$tensorflow", "getGammaInitializer", "getGammaRegularizer", "hasActivation", "getHasActivation", "getMomentum", "movingMean", "getMovingMean$tensorflow", "setMovingMean$tensorflow", "getMovingMeanInitializer", "movingVariance", "getMovingVariance$tensorflow", "setMovingVariance$tensorflow", "getMovingVarianceInitializer", "getScale", "variables", "getVariables", "batchNorm", "Lorg/tensorflow/Operand;", "", "tf", "Lorg/tensorflow/op/Ops;", "x", "Lorg/tensorflow/op/core/Variable;", "movingVar", "eps", "build", "input", "isTraining", "numberOfLosses", "toString", "tensorflow"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/layer/normalization/BatchNorm.class */
/*
 * Batch-normalization layer (decompiled from BatchNorm.kt).
 *
 * build() creates up to four variables (moving mean, moving variance and, depending on the
 * scale/center flags, gamma and beta) and emits
 *
 *     out = (x - movingMean) * rsqrt(movingVariance + epsilon) [* gamma] [+ beta]
 *
 * As written here, the graph always uses the stored moving statistics: the isTraining operand is
 * only null-checked, and the momentum value is stored and exposed but not read by build().
 */
public final class BatchNorm extends Layer implements NoGradients, ParametrizedLayer {

    /** Axis list; only its first entry is read (in {@link #build}) to size the layer variables. */
    @NotNull
    private final List<Integer> axis;
    /** Momentum setting; kept for configuration/printing, not used by the graph built in this class. */
    private final double momentum;
    /** When true, a beta offset variable is created and added to the normalized output. */
    private final boolean center;
    /** Small constant added to the moving variance before the reciprocal square root. */
    private final double epsilon;
    /** When true, a gamma variable is created and multiplied into the normalization factor. */
    private final boolean scale;

    @NotNull
    private final Initializer gammaInitializer;

    @NotNull
    private final Initializer betaInitializer;

    @Nullable
    private final Regularizer gammaRegularizer;

    @Nullable
    private final Regularizer betaRegularizer;

    @NotNull
    private final Initializer movingMeanInitializer;

    @NotNull
    private final Initializer movingVarianceInitializer;

    /** Scale variable; stays null until build() runs with {@code scale == true}. */
    @Nullable
    private KVariable gamma;

    /** Offset variable; stays null until build() runs with {@code center == true}. */
    @Nullable
    private KVariable beta;

    /** Late-initialized: null until build() assigns it. Prefer the accessor, which fails fast. */
    public KVariable movingMean;

    /** Late-initialized: null until build() assigns it. Prefer the accessor, which fails fast. */
    public KVariable movingVariance;

    /**
     * Primary constructor.
     *
     * @param list         axis list (stored as {@code axis})
     * @param d            momentum
     * @param z            center flag
     * @param d2           epsilon
     * @param z2           scale flag
     * @param initializer  gamma initializer
     * @param initializer2 beta initializer
     * @param regularizer  gamma regularizer, may be null
     * @param regularizer2 beta regularizer, may be null
     * @param initializer3 moving-mean initializer
     * @param initializer4 moving-variance initializer
     * @param str          layer name, passed to the {@code Layer} super constructor
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public BatchNorm(@NotNull List<Integer> list, double d, boolean z, double d2, boolean z2, @NotNull Initializer initializer, @NotNull Initializer initializer2, @Nullable Regularizer regularizer, @Nullable Regularizer regularizer2, @NotNull Initializer initializer3, @NotNull Initializer initializer4, @NotNull String str) {
        super(str);
        Intrinsics.checkNotNullParameter(list, "axis");
        Intrinsics.checkNotNullParameter(initializer, "gammaInitializer");
        Intrinsics.checkNotNullParameter(initializer2, "betaInitializer");
        Intrinsics.checkNotNullParameter(initializer3, "movingMeanInitializer");
        Intrinsics.checkNotNullParameter(initializer4, "movingVarianceInitializer");
        Intrinsics.checkNotNullParameter(str, "name");
        this.axis = list;
        this.momentum = d;
        this.center = z;
        this.epsilon = d2;
        this.scale = z2;
        this.gammaInitializer = initializer;
        this.betaInitializer = initializer2;
        this.gammaRegularizer = regularizer;
        this.betaRegularizer = regularizer2;
        this.movingMeanInitializer = initializer3;
        this.movingVarianceInitializer = initializer4;
    }

    /**
     * Compiler-generated bridge for Kotlin default arguments. Each set bit in {@code i} means the
     * matching positional argument was omitted and its default is substituted below.
     */
    public /* synthetic */ BatchNorm(List list, double d, boolean z, double d2, boolean z2, Initializer initializer, Initializer initializer2, Regularizer regularizer, Regularizer regularizer2, Initializer initializer3, Initializer initializer4, String str, int i, DefaultConstructorMarker defaultConstructorMarker) {
        this(
                (i & 1) != 0 ? CollectionsKt.arrayListOf(new Integer[]{3}) : list, // axis = [3]
                (i & 2) != 0 ? 0.99d : d,                                          // momentum
                (i & 4) != 0 ? true : z,                                           // center
                (i & 8) != 0 ? 0.001d : d2,                                        // epsilon
                (i & 16) != 0 ? true : z2,                                         // scale
                (i & 32) != 0 ? new Ones() : initializer,                          // gammaInitializer
                (i & 64) != 0 ? new Zeros() : initializer2,                        // betaInitializer
                (i & 128) != 0 ? null : regularizer,                               // gammaRegularizer
                (i & 256) != 0 ? null : regularizer2,                              // betaRegularizer
                (i & 512) != 0 ? new Zeros() : initializer3,                       // movingMeanInitializer
                (i & 1024) != 0 ? new Ones() : initializer4,                       // movingVarianceInitializer
                (i & 2048) != 0 ? "" : str);                                       // name
    }

    @NotNull
    public final List<Integer> getAxis() {
        return this.axis;
    }

    public final double getMomentum() {
        return this.momentum;
    }

    public final boolean getCenter() {
        return this.center;
    }

    public final double getEpsilon() {
        return this.epsilon;
    }

    public final boolean getScale() {
        return this.scale;
    }

    @NotNull
    public final Initializer getGammaInitializer() {
        return this.gammaInitializer;
    }

    @NotNull
    public final Initializer getBetaInitializer() {
        return this.betaInitializer;
    }

    @Nullable
    public final Regularizer getGammaRegularizer() {
        return this.gammaRegularizer;
    }

    @Nullable
    public final Regularizer getBetaRegularizer() {
        return this.betaRegularizer;
    }

    @NotNull
    public final Initializer getMovingMeanInitializer() {
        return this.movingMeanInitializer;
    }

    @NotNull
    public final Initializer getMovingVarianceInitializer() {
        return this.movingVarianceInitializer;
    }

    /** @return the gamma variable, or null if it has not been created */
    @Nullable
    public final KVariable getGamma$tensorflow() {
        return this.gamma;
    }

    public final void setGamma$tensorflow(@Nullable KVariable kVariable) {
        this.gamma = kVariable;
    }

    /** @return the beta variable, or null if it has not been created */
    @Nullable
    public final KVariable getBeta$tensorflow() {
        return this.beta;
    }

    public final void setBeta$tensorflow(@Nullable KVariable kVariable) {
        this.beta = kVariable;
    }

    /**
     * Returns the moving-mean variable, reporting an uninitialized-property access through the
     * Kotlin intrinsics when it has not been assigned yet.
     */
    @NotNull
    public final KVariable getMovingMean$tensorflow() {
        KVariable kVariable = this.movingMean;
        if (kVariable != null) {
            return kVariable;
        }
        Intrinsics.throwUninitializedPropertyAccessException("movingMean");
        // Only here to satisfy the compiler; the intrinsic above is expected to throw.
        return null;
    }

    public final void setMovingMean$tensorflow(@NotNull KVariable kVariable) {
        Intrinsics.checkNotNullParameter(kVariable, "<set-?>");
        this.movingMean = kVariable;
    }

    /**
     * Returns the moving-variance variable, reporting an uninitialized-property access through the
     * Kotlin intrinsics when it has not been assigned yet.
     */
    @NotNull
    public final KVariable getMovingVariance$tensorflow() {
        KVariable kVariable = this.movingVariance;
        if (kVariable != null) {
            return kVariable;
        }
        Intrinsics.throwUninitializedPropertyAccessException("movingVariance");
        // Only here to satisfy the compiler; the intrinsic above is expected to throw.
        return null;
    }

    public final void setMovingVariance$tensorflow(@NotNull KVariable kVariable) {
        Intrinsics.checkNotNullParameter(kVariable, "<set-?>");
        this.movingVariance = kVariable;
    }

    /**
     * Lists the layer variables in the order gamma, beta, moving mean, moving variance, skipping
     * gamma/beta when they are null. Goes through the moving-statistics accessors, so it fails
     * on a layer whose moving variables have not been assigned.
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.layer.ParametrizedLayer
    @NotNull
    public List<KVariable> getVariables() {
        return CollectionsKt.listOfNotNull(new KVariable[]{this.gamma, this.beta, getMovingMean$tensorflow(), getMovingVariance$tensorflow()});
    }

    /**
     * Creates the layer variables and returns the normalized output operand.
     *
     * @param ops      TensorFlow op builder
     * @param operand  layer input
     * @param operand2 "isTraining" flag; only null-checked here
     * @param operand3 "numberOfLosses"; not used here
     * @return the normalized operand built by {@code batchNorm}
     * @throws RuntimeException if the layer name is empty
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.layer.Layer
    @NotNull
    public Operand<Float> build(@NotNull Ops ops, @NotNull Operand<Float> operand, @NotNull Operand<Boolean> operand2, @Nullable Operand<Float> operand3) {
        Intrinsics.checkNotNullParameter(ops, "tf");
        Intrinsics.checkNotNullParameter(operand, "input");
        Intrinsics.checkNotNullParameter(operand2, "isTraining");

        // All four variables share a 1-D shape sized by the input dimension at axis[0].
        Shape weightShape = Shape.make(operand.asOutput().shape().size(this.axis.get(0).intValue()), new long[0]);
        if (getName().length() == 0) {
            throw new RuntimeException("Cannot build BatchNorm layer, because of empty name");
        }

        String movingMeanName = NameConventionsKt.batchNormMovingMeanVarName(getName());
        Intrinsics.checkNotNullExpressionValue(weightShape, "weightShape");
        // NOTE(review): the two Integer.MIN_VALUE arguments look like "not applicable" placeholders
        // (presumably fan-in/fan-out) - confirm against KVariableKt.createVariable.
        setMovingMean$tensorflow(KVariableKt.createVariable(ops, movingMeanName, weightShape, Integer.MIN_VALUE, Integer.MIN_VALUE, this.movingMeanInitializer, null));
        setMovingVariance$tensorflow(KVariableKt.createVariable(ops, NameConventionsKt.batchNormMovingVarianceVarName(getName()), weightShape, Integer.MIN_VALUE, Integer.MIN_VALUE, this.movingVarianceInitializer, null));
        if (this.scale) {
            this.gamma = KVariableKt.createVariable(ops, NameConventionsKt.batchNormGammaVarName(getName()), weightShape, Integer.MIN_VALUE, Integer.MIN_VALUE, this.gammaInitializer, this.gammaRegularizer);
        }
        if (this.center) {
            this.beta = KVariableKt.createVariable(ops, NameConventionsKt.batchNormBetaVarName(getName()), weightShape, Integer.MIN_VALUE, Integer.MIN_VALUE, this.betaInitializer, this.betaRegularizer);
        }

        Ops namedOps = ops.withName("BatchNorm");
        Intrinsics.checkNotNullExpressionValue(namedOps, "tf");

        KVariable gammaVar = this.gamma;
        Variable<Float> gammaOp = gammaVar != null ? gammaVar.getVariable() : null;
        KVariable betaVar = this.beta;
        Operand<Float> betaOp = betaVar != null ? betaVar.getVariable() : null;
        Operand<Float> movingMeanOp = getMovingMean$tensorflow().getVariable();
        Operand<Float> movingVarianceOp = getMovingVariance$tensorflow().getVariable();
        Operand<Float> eps = namedOps.constant((float) this.epsilon);
        Intrinsics.checkNotNullExpressionValue(eps, "tf.constant(epsilon.toFloat())");
        return batchNorm(namedOps, operand, gammaOp, betaOp, movingMeanOp, movingVarianceOp, eps);
    }

    /**
     * Emits {@code (x - movingMean) * rsqrt(movingVar + eps)}, multiplying the factor by gamma when
     * {@code scale} is set and adding beta when {@code center} is set. gamma/beta are only read
     * under their respective flags, so they may be null otherwise.
     */
    private final Operand<Float> batchNorm(Ops tf, Operand<Float> x, Variable<Float> gamma, Operand<Float> beta, Operand<Float> movingMean, Operand<Float> movingVar, Operand<Float> eps) {
        Operand<Float> inv = tf.math.rsqrt(tf.math.add(movingVar, eps));
        Intrinsics.checkNotNullExpressionValue(inv, "tf.math.rsqrt(tf.math.add(movingVar, eps))");
        if (this.scale) {
            Operand<Float> scaled = tf.math.mul(inv, gamma);
            Intrinsics.checkNotNullExpressionValue(scaled, "tf.math.mul(inv, gamma)");
            inv = scaled;
        }
        Operand<Float> xNorm = tf.math.mul(tf.math.sub(x, movingMean), inv);
        if (!this.center) {
            Intrinsics.checkNotNullExpressionValue(xNorm, "xNorm");
            return xNorm;
        }
        Operand<Float> shifted = tf.math.add(xNorm, beta);
        Intrinsics.checkNotNullExpressionValue(shifted, "tf.math.add(xNorm, beta)");
        return shifted;
    }

    /**
     * Describes the layer configuration and variable shapes.
     *
     * <p>Safe to call before build(): the backing fields are read directly, so any variable that
     * has not been created yet is rendered as {@code null} instead of triggering the
     * uninitialized-property failure of the moving-statistics accessors.
     */
    @NotNull
    public String toString() {
        KVariable gammaVar = this.gamma;
        KVariable betaVar = this.beta;
        KVariable movingMeanVar = this.movingMean;
        KVariable movingVarianceVar = this.movingVariance;

        StringBuilder sb = new StringBuilder();
        sb.append("BatchNorm(name = ").append(getName())
                .append(", isTrainable=").append(TrainableLayerKt.isTrainable(this))
                .append(", axis=").append(this.axis)
                .append(", momentum=").append(this.momentum)
                .append(", center=").append(this.center)
                .append(", epsilon=").append(this.epsilon)
                .append(", scale=").append(this.scale)
                .append(", gammaInitializer=").append(this.gammaInitializer)
                .append(", betaInitializer=").append(this.betaInitializer)
                .append(", gammaRegularizer=").append(this.gammaRegularizer)
                .append(", betaRegularizer=").append(this.betaRegularizer)
                .append(", movingMeanInitializer=").append(this.movingMeanInitializer)
                .append(", movingVarianceInitializer=").append(this.movingVarianceInitializer)
                .append(", hasActivation=").append(getHasActivation())
                .append(", gammaShapeArray=").append(gammaVar != null ? gammaVar.getShape() : null)
                .append(", betaShapeArray=").append(betaVar != null ? betaVar.getShape() : null)
                .append(", movingMeanShapeArray=").append(movingMeanVar != null ? movingMeanVar.getShape() : null)
                .append(", movingVarianceShapeArray=").append(movingVarianceVar != null ? movingVarianceVar.getShape() : null)
                .append(')');
        return sb.toString();
    }

    /** This layer reports that it has no activation. */
    @Override // org.jetbrains.kotlinx.dl.api.core.layer.Layer
    public boolean getHasActivation() {
        return false;
    }

    /** Delegates to the default implementation provided by {@code ParametrizedLayer}. */
    @Override // org.jetbrains.kotlinx.dl.api.core.layer.ParametrizedLayer
    public int getParamCount() {
        return ParametrizedLayer.DefaultImpls.getParamCount(this);
    }

    /** Creates a layer with every argument defaulted (mask 4095 sets all twelve default bits). */
    public BatchNorm() {
        this(null, 0.0d, false, 0.0d, false, null, null, null, null, null, null, null, 4095, null);
    }
}
