package org.jetbrains.kotlinx.dl.api.core.layer.activation;

import java.util.Arrays;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer;
import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros;
import org.jetbrains.kotlinx.dl.api.core.layer.KVariable;
import org.jetbrains.kotlinx.dl.api.core.layer.KVariableKt;
import org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayer;
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv1D;
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer;
import org.jetbrains.kotlinx.dl.api.core.shape.ShapeFunctionsKt;
import org.tensorflow.Operand;
import org.tensorflow.Shape;
import org.tensorflow.op.Ops;

/* compiled from: PReLU.kt */
@Metadata(mv = {Conv1D.EXTRA_DIM, 8, 0}, k = Conv1D.EXTRA_DIM, xi = 48, d1 = {"��R\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0015\n��\n\u0002\u0010\u000e\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\t\n\u0002\u0010\u000b\n\u0002\b\u0006\n\u0002\u0010 \n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��2\u00020\u00012\u00020\u0002B1\u0012\b\b\u0002\u0010\u0003\u001a\u00020\u0004\u0012\n\b\u0002\u0010\u0005\u001a\u0004\u0018\u00010\u0006\u0012\n\b\u0002\u0010\u0007\u001a\u0004\u0018\u00010\b\u0012\b\b\u0002\u0010\t\u001a\u00020\n¢\u0006\u0002\u0010\u000bJ\b\u0010!\u001a\u00020\nH\u0002J$\u0010\"\u001a\b\u0012\u0004\u0012\u00020$0#2\u0006\u0010%\u001a\u00020&2\f\u0010'\u001a\b\u0012\u0004\u0012\u00020$0#H\u0016J\b\u0010(\u001a\u00020\nH\u0016R\u001a\u0010\f\u001a\u00020\rX\u0080.¢\u0006\u000e\n��\u001a\u0004\b\u000e\u0010\u000f\"\u0004\b\u0010\u0010\u0011R\u0011\u0010\u0003\u001a\u00020\u0004¢\u0006\b\n��\u001a\u0004\b\u0012\u0010\u0013R\u0013\u0010\u0005\u001a\u0004\u0018\u00010\u0006¢\u0006\b\n��\u001a\u0004\b\u0014\u0010\u0015R\u001a\u0010\u0016\u001a\u00020\u0017X\u0096\u000e¢\u0006\u000e\n��\u001a\u0004\b\u0016\u0010\u0018\"\u0004\b\u0019\u0010\u001aR\u0013\u0010\u0007\u001a\u0004\u0018\u00010\b¢\u0006\b\n��\u001a\u0004\b\u001b\u0010\u001cR\u001a\u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\r0\u001e8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u001f\u0010 ¨\u0006)"}, d2 = {"Lorg/jetbrains/kotlinx/dl/api/core/layer/activation/PReLU;", "Lorg/jetbrains/kotlinx/dl/api/core/layer/activation/AbstractActivationLayer;", "Lorg/jetbrains/kotlinx/dl/api/core/layer/TrainableLayer;", "alphaInitializer", "Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "alphaRegularizer", "Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;", "sharedAxes", "", "name", "", 
"(Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;[ILjava/lang/String;)V", "alpha", "Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;", "getAlpha$tensorflow", "()Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;", "setAlpha$tensorflow", "(Lorg/jetbrains/kotlinx/dl/api/core/layer/KVariable;)V", "getAlphaInitializer", "()Lorg/jetbrains/kotlinx/dl/api/core/initializer/Initializer;", "getAlphaRegularizer", "()Lorg/jetbrains/kotlinx/dl/api/core/regularizer/Regularizer;", "isTrainable", "", "()Z", "setTrainable", "(Z)V", "getSharedAxes", "()[I", "variables", "", "getVariables", "()Ljava/util/List;", "alphaVariableName", "forward", "Lorg/tensorflow/Operand;", "", "tf", "Lorg/tensorflow/op/Ops;", "input", "toString", "tensorflow"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/api/core/layer/activation/PReLU.class */
/*
 * Parametric ReLU activation layer. forward() computes relu(x) - alpha * relu(-x), which is x for
 * x >= 0 and alpha * x for x < 0, where alpha is a trainable variable created in forward().
 *
 * Names such as getAlpha$tensorflow appear to be the mangled JVM names of Kotlin module-internal
 * accessors; keep them as-is for binary compatibility with compiled Kotlin callers.
 */
public final class PReLU extends AbstractActivationLayer implements TrainableLayer {

    /** Initializer used when creating alpha; the synthetic defaulting constructor uses {@link Zeros}. */
    @NotNull
    private final Initializer alphaInitializer;

    /** Optional regularizer passed through to the alpha variable; may be null. */
    @Nullable
    private final Regularizer alphaRegularizer;

    /**
     * Optional 1-based input axes along which alpha is shared (alpha's size is forced to 1 on those
     * axes). Axis 0, the leading input axis that {@link #forward} excludes from alpha, cannot be shared.
     */
    @Nullable
    private final int[] sharedAxes;

    /**
     * Backing field for the alpha variable; null until {@link #forward} runs. Kept public as in the
     * original compiled layout (presumably a Kotlin lateinit backing field) — do not narrow without
     * checking callers.
     */
    public KVariable alpha;

    /** Whether alpha is trainable; true unless changed via {@link #setTrainable(boolean)}. */
    private boolean isTrainable;

    /**
     * Creates a PReLU layer.
     *
     * @param initializer initializer for alpha; must not be null
     * @param regularizer optional regularizer for alpha
     * @param iArr optional 1-based shared axes (see {@link #getSharedAxes()}); stored without copying
     * @param str layer name; must not be null
     */
    public PReLU(@NotNull Initializer initializer, @Nullable Regularizer regularizer, @Nullable int[] iArr, @NotNull String str) {
        // Java requires super(...) to come first, so the decompiler moved it above the
        // Kotlin-emitted null checks; a null name therefore reaches the superclass before the check fails.
        super(str);
        Intrinsics.checkNotNullParameter(initializer, "alphaInitializer");
        Intrinsics.checkNotNullParameter(str, "name");
        this.alphaInitializer = initializer;
        this.alphaRegularizer = regularizer;
        this.sharedAxes = iArr;
        this.isTrainable = true;
    }

    /**
     * Kotlin default-argument bridge. Each set bit in {@code i} selects the default for one parameter:
     * bit 0 -> {@code new Zeros()}, bit 1 -> null regularizer, bit 2 -> null shared axes,
     * bit 3 -> empty name.
     */
    public /* synthetic */ PReLU(Initializer initializer, Regularizer regularizer, int[] iArr, String str, int i, DefaultConstructorMarker defaultConstructorMarker) {
        this((i & 1) != 0 ? new Zeros() : initializer, (i & 2) != 0 ? null : regularizer, (i & 4) != 0 ? null : iArr, (i & 8) != 0 ? "" : str);
    }

    /** @return the initializer used for alpha */
    @NotNull
    public final Initializer getAlphaInitializer() {
        return this.alphaInitializer;
    }

    /** @return the regularizer applied to alpha, or null if none */
    @Nullable
    public final Regularizer getAlphaRegularizer() {
        return this.alphaRegularizer;
    }

    /**
     * @return the 1-based shared axes, or null if alpha is not shared; this is the internal array,
     *     not a copy
     */
    @Nullable
    public final int[] getSharedAxes() {
        return this.sharedAxes;
    }

    /**
     * Returns the alpha variable created by the last {@link #forward} call.
     *
     * @throws kotlin.UninitializedPropertyAccessException if {@link #forward} has not run yet
     */
    @NotNull
    public final KVariable getAlpha$tensorflow() {
        KVariable kVariable = this.alpha;
        if (kVariable != null) {
            return kVariable;
        }
        Intrinsics.throwUninitializedPropertyAccessException("alpha");
        // Unreachable: the call above always throws; javac just cannot prove it.
        return null;
    }

    /** Replaces the alpha variable; {@code kVariable} must not be null. */
    public final void setAlpha$tensorflow(@NotNull KVariable kVariable) {
        Intrinsics.checkNotNullParameter(kVariable, "<set-?>");
        this.alpha = kVariable;
    }

    /** @return {@code "<name>_alpha"} when the layer has a non-empty name, otherwise {@code "alpha"} */
    private final String alphaVariableName() {
        return getName().length() > 0 ? getName() + "_alpha" : "alpha";
    }

    /**
     * @return a single-element list holding alpha
     * @throws kotlin.UninitializedPropertyAccessException if {@link #forward} has not run yet
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.layer.ParametrizedLayer
    @NotNull
    public List<KVariable> getVariables() {
        return CollectionsKt.listOf(getAlpha$tensorflow());
    }

    @Override // org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayer
    public boolean isTrainable() {
        return this.isTrainable;
    }

    @Override // org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayer
    public void setTrainable(boolean z) {
        this.isTrainable = z;
    }

    /**
     * Creates the alpha variable for {@code operand}'s shape and applies PReLU.
     *
     * <p>Alpha's shape is the input shape without its leading axis, with every shared axis collapsed
     * to size 1. Each call creates a new alpha variable and stores it in {@link #alpha}.
     *
     * @param ops TensorFlow op builder
     * @param operand the input; must have a known rank of at least 2
     * @return {@code relu(x) - alpha * relu(-x)}
     * @throws IllegalArgumentException if the input rank is unknown or below 2, or a shared axis is
     *     outside {@code [1, rank - 1]}
     */
    @Override // org.jetbrains.kotlinx.dl.api.core.layer.activation.AbstractActivationLayer
    @NotNull
    public Operand<Float> forward(@NotNull Ops ops, @NotNull Operand<Float> operand) {
        Intrinsics.checkNotNullParameter(ops, "tf");
        Intrinsics.checkNotNullParameter(operand, "input");
        Shape shape = operand.asOutput().shape();
        Intrinsics.checkNotNullExpressionValue(shape, "inputShape");

        // numDimensions() is negative for an unknown rank; a rank-1 input would leave alpha with no axes.
        int rank = shape.numDimensions();
        if (rank < 2) {
            throw new IllegalArgumentException(
                    "PReLU expects an input of known rank >= 2, but got shape " + shape);
        }

        // Alpha covers every input axis except the leading one.
        long[] inputDims = ShapeFunctionsKt.toLongArray(shape);
        long[] alphaDims = Arrays.copyOfRange(inputDims, 1, inputDims.length);
        if (this.sharedAxes != null) {
            for (int axis : this.sharedAxes) {
                if (axis < 1 || axis > alphaDims.length) {
                    throw new IllegalArgumentException(
                            "PReLU shared axis " + axis + " is out of range [1, " + alphaDims.length
                                    + "] for input shape " + shape);
                }
                alphaDims[axis - 1] = 1;
            }
        }

        // The size of the last input axis is passed as both the fan-in and fan-out argument.
        int fan = (int) shape.size(rank - 1);
        Shape alphaShape = Shape.make(alphaDims[0], Arrays.copyOfRange(alphaDims, 1, alphaDims.length));
        Intrinsics.checkNotNullExpressionValue(alphaShape, "alphaShape");
        setAlpha$tensorflow(KVariableKt.createVariable(ops, alphaVariableName(), alphaShape, fan, fan, this.alphaInitializer, this.alphaRegularizer));

        // For x < 0: -alpha * relu(-x) = -alpha * (-x) = alpha * x; for x >= 0 that term is 0.
        Operand<Float> positive = ops.nn.relu(operand);
        Operand<Float> negative = ops.math.mul(ops.math.neg(getAlpha$tensorflow().getVariable()), ops.nn.relu(ops.math.neg(operand)));
        Operand<Float> add = ops.math.add(positive, negative);
        Intrinsics.checkNotNullExpressionValue(add, "tf.math.add(positive, negative)");
        return add;
    }

    /** @return a description of the layer's name, trainability, initializer, regularizer and shared axes */
    @NotNull
    public String toString() {
        // Arrays.toString(null) yields "null", matching what StringBuilder appends for a null String.
        return new StringBuilder()
                .append("PReLU(name = ").append(getName())
                .append(", isTrainable=").append(isTrainable())
                .append(", alphaInitializer=").append(this.alphaInitializer)
                .append(", alphaRegularizer=").append(this.alphaRegularizer)
                .append(", sharedAxes=").append(Arrays.toString(this.sharedAxes))
                .append(')')
                .toString();
    }

    /** Delegates to the {@link TrainableLayer} default implementation. */
    @Override // org.jetbrains.kotlinx.dl.api.core.layer.ParametrizedLayer
    public int getParamCount() {
        return TrainableLayer.DefaultImpls.getParamCount(this);
    }

    /** Creates a layer with all defaults: {@link Zeros} initializer, no regularizer, no shared axes, empty name. */
    public PReLU() {
        this(null, null, null, null, 15, null);
    }
}
