package org.deeplearning4j.nn.conf.layers;

import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/BasePretrainNetwork.class */
/**
 * Abstract layer configuration that adds a loss-function setting (either a
 * {@link LossFunctions.LossFunction} value or a custom loss-function name) on top of
 * {@link FeedForwardLayer}.
 *
 * <p>The per-parameter lookups ({@link #getL1ByParam(String)}, {@link #getL2ByParam(String)},
 * {@link #getLearningRateByParam(String)}) accept exactly three parameter names:
 * {@code "W"}, {@code "b"} and {@code "bB"}. Any other name is rejected with an
 * {@link IllegalArgumentException}.
 * NOTE(review): presumably these are the weight, hidden-bias and visible-bias keys of the
 * pretrain parameter initializer - confirm against the initializer class.
 */
public abstract class BasePretrainNetwork extends FeedForwardLayer {
    /** Loss function selected for this layer; may be {@code null} when built via the no-arg constructor. */
    protected LossFunctions.LossFunction lossFunction;
    /** Name of a custom loss function; may be {@code null}. */
    protected String customLossFunction;

    /**
     * Builder for {@link BasePretrainNetwork} configurations.
     *
     * @param <T> the concrete builder type, returned by the fluent setters so calls can be chained
     */
    public static abstract class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> {
        /** Loss function to use; defaults to {@code RECONSTRUCTION_CROSSENTROPY}. */
        protected LossFunctions.LossFunction lossFunction = LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY;
        /** Custom loss-function name; defaults to {@code null}. */
        protected String customLossFunction = null;

        /**
         * Sets the loss function.
         *
         * @param lossFunction the loss function to store in the builder
         * @return this builder, typed as the concrete builder type
         */
        @SuppressWarnings("unchecked") // safe as long as T is the concrete builder's own type (self-bounded generic)
        public T lossFunction(LossFunctions.LossFunction lossFunction) {
            this.lossFunction = lossFunction;
            return (T) this;
        }

        /**
         * Sets the custom loss-function name.
         *
         * @param str the custom loss-function name to store in the builder
         * @return this builder, typed as the concrete builder type
         */
        @SuppressWarnings("unchecked") // safe as long as T is the concrete builder's own type (self-bounded generic)
        public T customLossFunction(String str) {
            this.customLossFunction = str;
            return (T) this;
        }
    }

    /**
     * Creates a configuration from a builder, copying its loss-function settings after
     * delegating the remaining settings to the superclass constructor.
     *
     * @param builder the builder to read from
     */
    public BasePretrainNetwork(Builder<?> builder) {
        super(builder);
        this.lossFunction = builder.lossFunction;
        this.customLossFunction = builder.customLossFunction;
    }

    /** Creates a configuration with both loss-function fields left unset ({@code null}). */
    public BasePretrainNetwork() {
    }

    /**
     * Returns the L1 coefficient that applies to the named parameter.
     *
     * @param str parameter name: {@code "W"}, {@code "b"} or {@code "bB"}
     * @return the inherited {@code l1} value for {@code "W"}; {@code 0.0} for {@code "b"} and {@code "bB"}
     * @throws IllegalArgumentException if {@code str} is not one of the three accepted names
     * @throws NullPointerException if {@code str} is {@code null}
     */
    @Override
    public double getL1ByParam(String str) {
        switch (str) {
            case "W":
                return this.l1;
            case "b":
            case "bB":
                return 0.0d;
            default:
                throw new IllegalArgumentException("Unknown parameter name: \"" + str + "\"");
        }
    }

    /**
     * Returns the L2 coefficient that applies to the named parameter.
     *
     * @param str parameter name: {@code "W"}, {@code "b"} or {@code "bB"}
     * @return the inherited {@code l2} value for {@code "W"}; {@code 0.0} for {@code "b"} and {@code "bB"}
     * @throws IllegalArgumentException if {@code str} is not one of the three accepted names
     * @throws NullPointerException if {@code str} is {@code null}
     */
    @Override
    public double getL2ByParam(String str) {
        switch (str) {
            case "W":
                return this.l2;
            case "b":
            case "bB":
                return 0.0d;
            default:
                throw new IllegalArgumentException("Unknown parameter name: \"" + str + "\"");
        }
    }

    /**
     * Returns the learning rate that applies to the named parameter.
     *
     * @param str parameter name: {@code "W"}, {@code "b"} or {@code "bB"}
     * @return the inherited {@code learningRate} for {@code "W"}; for {@code "b"} and {@code "bB"},
     *         {@code biasLearningRate} unless it is NaN, in which case {@code learningRate}
     * @throws IllegalArgumentException if {@code str} is not one of the three accepted names
     * @throws NullPointerException if {@code str} is {@code null}
     */
    @Override
    public double getLearningRateByParam(String str) {
        switch (str) {
            case "W":
                return this.learningRate;
            case "b":
            case "bB":
                // A NaN bias learning rate is treated as "not set": fall back to the general rate.
                return Double.isNaN(this.biasLearningRate) ? this.learningRate : this.biasLearningRate;
            default:
                throw new IllegalArgumentException("Unknown parameter name: \"" + str + "\"");
        }
    }

    /** @return the configured loss function, possibly {@code null} */
    public LossFunctions.LossFunction getLossFunction() {
        return this.lossFunction;
    }

    /** @return the configured custom loss-function name, possibly {@code null} */
    public String getCustomLossFunction() {
        return this.customLossFunction;
    }

    /** @param lossFunction the loss function to store */
    public void setLossFunction(LossFunctions.LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    /** @param str the custom loss-function name to store */
    public void setCustomLossFunction(String str) {
        this.customLossFunction = str;
    }

    /**
     * Returns a string containing the superclass representation followed by both
     * loss-function settings.
     */
    @Override
    public String toString() {
        return "BasePretrainNetwork(super=" + super.toString() + ", lossFunction=" + getLossFunction() + ", customLossFunction=" + getCustomLossFunction() + ")";
    }

    /**
     * Compares this configuration with another object.
     *
     * <p>Two instances are equal when the other object is a {@code BasePretrainNetwork} that
     * accepts this instance via {@link #canEqual(Object)}, the superclass state is equal, and
     * both loss-function settings are equal (two {@code null} values count as equal).
     *
     * @param obj the object to compare with
     * @return {@code true} if the objects are equal as described above
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof BasePretrainNetwork)) {
            return false;
        }
        BasePretrainNetwork other = (BasePretrainNetwork) obj;
        // canEqual is asked on the OTHER object so a subclass can refuse equality with a base instance.
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        LossFunctions.LossFunction thisLoss = getLossFunction();
        LossFunctions.LossFunction otherLoss = other.getLossFunction();
        if (thisLoss == null ? otherLoss != null : !thisLoss.equals(otherLoss)) {
            return false;
        }
        String thisCustom = getCustomLossFunction();
        String otherCustom = other.getCustomLossFunction();
        return thisCustom == null ? otherCustom == null : thisCustom.equals(otherCustom);
    }

    /**
     * Tells whether the given object may be considered equal to an instance of this class.
     *
     * @param obj the candidate object
     * @return {@code true} if {@code obj} is a {@code BasePretrainNetwork}
     */
    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof BasePretrainNetwork;
    }

    /**
     * Computes a hash code from the superclass hash and both loss-function settings,
     * using multiplier 59 and the value 43 in place of a {@code null} field.
     */
    @Override
    public int hashCode() {
        int result = 59 + super.hashCode();
        LossFunctions.LossFunction loss = getLossFunction();
        result = (result * 59) + (loss == null ? 43 : loss.hashCode());
        String custom = getCustomLossFunction();
        return (result * 59) + (custom == null ? 43 : custom.hashCode());
    }
}
