package org.deeplearning4j.nn.conf.layers.samediff;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.class */
/**
 * Base configuration for a graph vertex whose computation is defined with SameDiff.
 *
 * <p>Subclasses declare parameters/inputs in {@link #defineParametersAndInputs}, define the
 * computation in {@link #defineVertex} and supply initial values in {@link #initializeParameters}.
 * {@link #instantiate} creates the {@link SameDiffGraphVertex} backing this configuration.
 *
 * <p>Training settings (regularization, updaters, gradient normalization) left unset on the vertex
 * are filled from the global builder by {@link #applyGlobalConfig}.
 */
public abstract class SameDiffVertex extends GraphVertex implements TrainingConfig {
    /** Multiplier used by {@link #hashCode()}; same constant as before so hash values are unchanged. */
    private static final int HASH_PRIME = 59;
    /** Hash contribution of a {@code null} field in {@link #hashCode()}. */
    private static final int NULL_HASH = 43;

    /** Parameter/input declaration; created lazily by {@link #getVertexParams()}. */
    private SDVertexParams vertexParams;
    /** Vertex name; {@code null} until {@link #instantiate} or {@link #setName} is called. */
    private String name;
    /** Regularization for weight parameters; filled from global config when null or empty. */
    protected List<Regularization> regularization;
    /** Regularization for bias parameters; filled from global config when null or empty. */
    protected List<Regularization> regularizationBias;
    /** Updater for weight parameters (and for biases when {@link #biasUpdater} is null). */
    protected IUpdater updater;
    /** Optional dedicated updater for bias parameters. */
    protected IUpdater biasUpdater;
    protected GradientNormalization gradientNormalization;
    /** {@code NaN} marks "not set"; {@link #applyGlobalConfig} replaces it with the global value. */
    protected double gradientNormalizationThreshold = Double.NaN;
    protected DataType dataType;

    /**
     * Implemented by subclasses to define this vertex in {@code sameDiff}. The returned
     * {@link SDVariable} is presumably the vertex output.
     *
     * <p>NOTE(review): the roles of the three maps (presumably inputs, parameters and masks keyed
     * by name) are not visible in this class — confirm against {@link SameDiffGraphVertex}.
     */
    public abstract SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> map, Map<String, SDVariable> map2, Map<String, SDVariable> map3);

    /**
     * Implemented by subclasses to declare parameters and inputs on {@code sDVertexParams}.
     * Invoked by {@link #getVertexParams()} with a freshly created instance whenever no params
     * object is currently set.
     */
    public abstract void defineParametersAndInputs(SDVertexParams sDVertexParams);

    /**
     * Implemented by subclasses to initialize parameter values.
     *
     * <p>NOTE(review): presumably the map is keyed by parameter name and the arrays are filled in
     * place — confirm with callers.
     */
    public abstract void initializeParameters(Map<String, INDArray> map);

    /**
     * Returns the vertex parameter/input declaration, creating it on first use by calling
     * {@link #defineParametersAndInputs}. Not synchronized.
     */
    public SDVertexParams getVertexParams() {
        if (this.vertexParams == null) {
            this.vertexParams = new SDVertexParams();
            defineParametersAndInputs(this.vertexParams);
        }
        return this.vertexParams;
    }

    /**
     * Cloning is not supported.
     *
     * <p>The method name is decompiler-generated (originally {@code clone}); because of
     * {@code @Override} it must stay identical to the superclass declaration.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public GraphVertex mo5691clone() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Returns the total number of parameters: the sum, over all declared parameter shapes, of the
     * product of their dimensions.
     *
     * @param z ignored by this implementation
     * @return total parameter count
     */
    @Override
    public long numParams(boolean z) {
        long total = 0;
        for (long[] shape : getVertexParams().getParamShapes().values()) {
            total += ArrayUtil.prodLong(shape);
        }
        // No narrowing cast: an (int) cast here would silently truncate counts above
        // Integer.MAX_VALUE even though the method returns long.
        return total;
    }

    /** Returns 1: at least one input is required. */
    @Override
    public int minVertexInputs() {
        return 1;
    }

    /** Returns -1, presumably meaning "no upper bound on inputs" — confirm against GraphVertex. */
    @Override
    public int maxVertexInputs() {
        return -1;
    }

    /**
     * Records {@code str} as this vertex's name and creates the backing {@link SameDiffGraphVertex},
     * forwarding all arguments unchanged.
     *
     * <p>NOTE(review): meanings of {@code i}, {@code iNDArray} and {@code z} (presumably vertex
     * index, parameter view and an "initialize params" flag) are defined by SameDiffGraphVertex.
     */
    @Override
    public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph computationGraph, String str, int i, INDArray iNDArray, boolean z, DataType dataType) {
        this.name = str;
        return new SameDiffGraphVertex(this, computationGraph, str, i, iNDArray, z, dataType);
    }

    /**
     * Output type inference is not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public InputType getOutputType(int i, InputType... inputTypeArr) throws InvalidInputTypeException {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Mask propagation is not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] iNDArrayArr, MaskState maskState, int i) {
        throw new UnsupportedOperationException("Not yet supported");
    }

    /** Hook for input validation; this base implementation accepts everything. */
    public void validateInput(INDArray[] iNDArrayArr) {
    }

    /** Memory reporting is not implemented; always returns {@code null}. */
    @Override
    public MemoryReport getMemoryReport(InputType... inputTypeArr) {
        return null;
    }

    /** Returns {@code 'c'} for every parameter, regardless of {@code str}. */
    public char paramReshapeOrder(String str) {
        return 'c';
    }

    /**
     * Fills every training setting that is not set on this vertex from the global builder, then
     * calls {@link #applyGlobalConfigToLayer}.
     *
     * <p>Regularization lists count as unset when null or empty; the normalization threshold counts
     * as unset when {@code NaN}; all other settings count as unset when null.
     */
    public void applyGlobalConfig(NeuralNetConfiguration.Builder builder) {
        if (this.regularization == null || this.regularization.isEmpty()) {
            this.regularization = builder.getRegularization();
        }
        if (this.regularizationBias == null || this.regularizationBias.isEmpty()) {
            this.regularizationBias = builder.getRegularizationBias();
        }
        if (this.updater == null) {
            this.updater = builder.getIUpdater();
        }
        if (this.biasUpdater == null) {
            this.biasUpdater = builder.getBiasUpdater();
        }
        if (this.gradientNormalization == null) {
            this.gradientNormalization = builder.getGradientNormalization();
        }
        if (Double.isNaN(this.gradientNormalizationThreshold)) {
            this.gradientNormalizationThreshold = builder.getGradientNormalizationThreshold();
        }
        applyGlobalConfigToLayer(builder);
    }

    /** Subclass hook invoked at the end of {@link #applyGlobalConfig}; no-op by default. */
    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder builder) {
    }

    /** Returns the vertex name (same value as {@link #getName()}). */
    @Override
    public String getLayerName() {
        return this.name;
    }

    /**
     * Returns the regularization list that applies to parameter {@code str}.
     *
     * @return {@code null} when neither weight nor bias regularization is configured; otherwise
     *     the weight or bias list depending on the parameter's kind
     * @throws IllegalStateException if {@code str} is neither a weight nor a bias parameter
     */
    @Override
    public List<Regularization> getRegularizationByParam(String str) {
        if ((this.regularization == null || this.regularization.isEmpty()) && (this.regularizationBias == null || this.regularizationBias.isEmpty())) {
            return null;
        }
        if (getVertexParams().isWeightParam(str)) {
            return this.regularization;
        }
        if (getVertexParams().isBiasParam(str)) {
            return this.regularizationBias;
        }
        throw new IllegalStateException("Unknown parameter name: " + str + " - not in weights (" + getVertexParams().getWeightParameterKeys() + ") or biases (" + getVertexParams().getBiasParameterKeys() + ")");
    }

    /** Always {@code false}: no parameter of this vertex is a pretrain parameter. */
    @Override
    public boolean isPretrainParam(String str) {
        return false;
    }

    /**
     * Returns the updater for parameter {@code str}. Bias parameters use {@link #biasUpdater},
     * falling back to {@link #updater} when no bias updater is set.
     *
     * @throws IllegalStateException if {@code str} is neither a weight nor a bias parameter
     */
    @Override
    public IUpdater getUpdaterByParam(String str) {
        if (getVertexParams().isWeightParam(str)) {
            return this.updater;
        }
        if (getVertexParams().isBiasParam(str)) {
            return this.biasUpdater == null ? this.updater : this.biasUpdater;
        }
        throw new IllegalStateException("Unknown parameter name: " + str + " - not in weights (" + getVertexParams().getWeightParameterKeys() + ") or biases (" + getVertexParams().getBiasParameterKeys() + ")");
    }

    @Override
    public GradientNormalization getGradientNormalization() {
        return this.gradientNormalization;
    }

    /** Returns the threshold; {@code NaN} means it has not been set. */
    @Override
    public double getGradientNormalizationThreshold() {
        return this.gradientNormalizationThreshold;
    }

    @Override
    public void setDataType(DataType dataType) {
        this.dataType = dataType;
    }

    public String getName() {
        return this.name;
    }

    public List<Regularization> getRegularization() {
        return this.regularization;
    }

    public List<Regularization> getRegularizationBias() {
        return this.regularizationBias;
    }

    public IUpdater getUpdater() {
        return this.updater;
    }

    public IUpdater getBiasUpdater() {
        return this.biasUpdater;
    }

    public DataType getDataType() {
        return this.dataType;
    }

    public void setVertexParams(SDVertexParams sDVertexParams) {
        this.vertexParams = sDVertexParams;
    }

    public void setName(String str) {
        this.name = str;
    }

    public void setRegularization(List<Regularization> list) {
        this.regularization = list;
    }

    public void setRegularizationBias(List<Regularization> list) {
        this.regularizationBias = list;
    }

    public void setUpdater(IUpdater iUpdater) {
        this.updater = iUpdater;
    }

    public void setBiasUpdater(IUpdater iUpdater) {
        this.biasUpdater = iUpdater;
    }

    public void setGradientNormalization(GradientNormalization gradientNormalization) {
        this.gradientNormalization = gradientNormalization;
    }

    public void setGradientNormalizationThreshold(double d) {
        this.gradientNormalizationThreshold = d;
    }

    /**
     * Field-by-field equality over all configuration fields (superclass state is not compared).
     *
     * <p>Note: reads go through {@link #getVertexParams()}, so comparing may trigger lazy
     * creation of the vertex params on either instance.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof SameDiffVertex)) {
            return false;
        }
        SameDiffVertex other = (SameDiffVertex) obj;
        return other.canEqual(this)
                && Objects.equals(getVertexParams(), other.getVertexParams())
                && Objects.equals(getName(), other.getName())
                && Objects.equals(getRegularization(), other.getRegularization())
                && Objects.equals(getRegularizationBias(), other.getRegularizationBias())
                && Objects.equals(getUpdater(), other.getUpdater())
                && Objects.equals(getBiasUpdater(), other.getBiasUpdater())
                && Objects.equals(getGradientNormalization(), other.getGradientNormalization())
                && Double.compare(getGradientNormalizationThreshold(), other.getGradientNormalizationThreshold()) == 0
                && Objects.equals(getDataType(), other.getDataType());
    }

    /** Symmetry hook for {@link #equals}: subclasses overriding equals should override this too. */
    protected boolean canEqual(Object obj) {
        return obj instanceof SameDiffVertex;
    }

    /**
     * Hash over the same fields as {@link #equals}, using the original 59/43 scheme so hash values
     * are unchanged. May trigger lazy creation of the vertex params.
     */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * HASH_PRIME + hashOf(getVertexParams());
        result = result * HASH_PRIME + hashOf(getName());
        result = result * HASH_PRIME + hashOf(getRegularization());
        result = result * HASH_PRIME + hashOf(getRegularizationBias());
        result = result * HASH_PRIME + hashOf(getUpdater());
        result = result * HASH_PRIME + hashOf(getBiasUpdater());
        result = result * HASH_PRIME + hashOf(getGradientNormalization());
        // Double.hashCode folds the 64-bit pattern exactly like (int) (bits ^ (bits >>> 32)).
        result = result * HASH_PRIME + Double.hashCode(getGradientNormalizationThreshold());
        result = result * HASH_PRIME + hashOf(getDataType());
        return result;
    }

    /** Hash of {@code value}, or {@link #NULL_HASH} for {@code null}. */
    private static int hashOf(Object value) {
        return value == null ? NULL_HASH : value.hashCode();
    }

    @Override
    public String toString() {
        return "SameDiffVertex(vertexParams=" + getVertexParams() + ", name=" + getName() + ", regularization=" + getRegularization() + ", regularizationBias=" + getRegularizationBias() + ", updater=" + getUpdater() + ", biasUpdater=" + getBiasUpdater() + ", gradientNormalization=" + getGradientNormalization() + ", gradientNormalizationThreshold=" + getGradientNormalizationThreshold() + ", dataType=" + getDataType() + ")";
    }
}
