package org.deeplearning4j.nn.conf.serde;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.AdaMax;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossL2;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonMappingException;
import org.nd4j.shade.jackson.databind.deser.ResolvableDeserializer;
import org.nd4j.shade.jackson.databind.deser.std.StdDeserializer;
import org.nd4j.shade.jackson.databind.node.ObjectNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.class */
public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> implements ResolvableDeserializer {
    /** Logger shared by this class. */
    private static final Logger log = LoggerFactory.getLogger(BaseNetConfigDeserializer.class);

    /** Deserializer handed to the constructor; {@link #resolve(DeserializationContext)} forwards to it. */
    protected final JsonDeserializer<?> defaultDeserializer;

    /**
     * Lower-cased activation name to activation class lookup. Starts out {@code null} and is
     * filled in on first use by {@link #getMap()}.
     */
    private static Map<String, Class<? extends IActivation>> activationMap;

    /**
     * Creates a deserializer for values of type {@code cls} and remembers the deserializer it
     * was given.
     *
     * @param jsonDeserializer deserializer stored in {@link #defaultDeserializer}
     * @param cls              value class handed to the {@link StdDeserializer} constructor
     */
    public BaseNetConfigDeserializer(JsonDeserializer<?> jsonDeserializer, Class<T> cls) {
        super(cls);
        defaultDeserializer = jsonDeserializer;
    }

    /**
     * Reads a value of type {@code T} from the parser. Declared abstract here, so every concrete
     * subclass supplies the actual parsing logic.
     *
     * @param jsonParser             parser positioned on the content to read
     * @param deserializationContext Jackson context for the current deserialization
     * @return the deserialized value
     * @throws IOException             declared for implementations that fail while reading
     * @throws JsonProcessingException declared for implementations that fail while parsing
     */
    @Override // org.nd4j.shade.jackson.databind.JsonDeserializer
    public abstract T deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException, JsonProcessingException;

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Checks whether at least one {@link BaseLayer} in the array has no updater set although its
     * parameter initializer reports a positive parameter count.
     *
     * @param layerArr layers to inspect; entries that are not {@link BaseLayer}s are skipped
     * @return {@code true} as soon as such a layer is found, {@code false} otherwise
     */
    public boolean requiresIUpdaterFromLegacy(Layer[] layerArr) {
        for (int i = 0; i < layerArr.length; i++) {
            Layer candidate = layerArr[i];
            if (!(candidate instanceof BaseLayer)) {
                continue;
            }
            BaseLayer base = (BaseLayer) candidate;
            if (base.getIUpdater() != null) {
                continue;
            }
            // Only layers that actually own parameters count as "missing an updater".
            if (base.initializer().numParams(base) > 0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks whether none of the layers carries a dropout configuration.
     *
     * @param layerArr layers to inspect
     * @return {@code false} as soon as one layer returns a non-null {@code getIDropout()};
     *         {@code true} when no layer does (including an empty array)
     */
    protected boolean requiresDropoutFromLegacy(Layer[] layerArr) {
        int index = 0;
        while (index < layerArr.length) {
            if (layerArr[index].getIDropout() != null) {
                return false;
            }
            index++;
        }
        return true;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Checks whether at least one {@link BaseLayer} in the array has a {@code null}
     * regularization list.
     *
     * @param layerArr layers to inspect; entries that are not {@link BaseLayer}s are skipped
     * @return {@code true} if such a layer exists, {@code false} otherwise
     */
    public boolean requiresRegularizationFromLegacy(Layer[] layerArr) {
        for (Layer current : layerArr) {
            if (!(current instanceof BaseLayer)) {
                continue;
            }
            if (((BaseLayer) current).getRegularization() == null) {
                return true;
            }
        }
        return false;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Checks whether at least one {@link BaseLayer} in the array has no weight initialization
     * function set.
     *
     * @param layerArr layers to inspect; entries that are not {@link BaseLayer}s are skipped
     * @return {@code true} if such a layer exists, {@code false} otherwise
     */
    public boolean requiresWeightInitFromLegacy(Layer[] layerArr) {
        boolean missing = false;
        // The loop stops at the first layer that lacks a weight init function.
        for (int i = 0; i < layerArr.length && !missing; i++) {
            Layer current = layerArr[i];
            missing = current instanceof BaseLayer && ((BaseLayer) current).getWeightInitFn() == null;
        }
        return missing;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Checks whether at least one {@link BaseLayer} in the array has no activation function set.
     *
     * @param layerArr layers to inspect; entries that are not {@link BaseLayer}s are skipped
     * @return {@code true} if such a layer exists, {@code false} otherwise
     */
    public boolean requiresActivationFromLegacy(Layer[] layerArr) {
        for (int i = 0; i < layerArr.length; i++) {
            Layer current = layerArr[i];
            if (!(current instanceof BaseLayer)) {
                continue;
            }
            BaseLayer base = (BaseLayer) current;
            if (base.getActivationFn() == null) {
                return true;
            }
        }
        return false;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Checks whether at least one {@link BaseOutputLayer} in the array has no loss function set.
     *
     * @param layerArr layers to inspect; entries that are not {@link BaseOutputLayer}s are skipped
     * @return {@code true} if such a layer exists, {@code false} otherwise
     */
    public boolean requiresLegacyLossHandling(Layer[] layerArr) {
        int position = 0;
        while (position < layerArr.length) {
            Layer current = layerArr[position++];
            if (!(current instanceof BaseOutputLayer)) {
                continue;
            }
            if (((BaseOutputLayer) current).getLossFn() == null) {
                return true;
            }
        }
        return false;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Rebuilds an {@link IUpdater} for {@code baseLayer} from the flat updater fields of a
     * legacy-format JSON node and installs it on the layer.
     *
     * <p>The method does nothing when {@code objectNode} is {@code null} or has no
     * {@code "updater"} field. Otherwise the updater name is resolved through
     * {@link Updater#valueOf(String)}, a default-configured {@link IUpdater} is obtained and its
     * hyper-parameters are overwritten from the node:
     * <ul>
     *   <li>SGD: {@code learningRate}</li>
     *   <li>ADAM, ADAMAX, NADAM: {@code learningRate}, {@code adamMeanDecay},
     *       {@code adamVarDecay}, {@code epsilon} (fallback 1e-8)</li>
     *   <li>ADADELTA: {@code rho}, {@code epsilon} (fallback 1e-6)</li>
     *   <li>NESTEROVS: {@code learningRate}, {@code momentum}</li>
     *   <li>ADAGRAD: {@code learningRate}, {@code epsilon} (fallback 1e-6)</li>
     *   <li>RMSPROP: {@code learningRate}, {@code epsilon} (fallback 1e-8), {@code rmsDecay}</li>
     * </ul>
     * Any other updater value keeps its default configuration.
     *
     * <p>Each field is read only by the updater that uses it, so a node that lacks, say,
     * {@code rho} no longer fails for updaters that never consult it. A field that the selected
     * updater does need must still be present; a missing one results in a
     * {@link NullPointerException}.
     *
     * @param baseLayer  layer that receives the reconstructed updater
     * @param objectNode legacy JSON node for the layer; may be {@code null}
     * @throws IllegalArgumentException if the {@code "updater"} text is not an {@link Updater} constant
     */
    public void handleUpdaterBackwardCompatibility(BaseLayer baseLayer, ObjectNode objectNode) {
        if (objectNode == null || !objectNode.has("updater")) {
            return;
        }
        String updaterName = objectNode.get("updater").asText();
        if (updaterName == null) {
            return;
        }
        Updater legacyUpdater = Updater.valueOf(updaterName);
        IUpdater updater = legacyUpdater.getIUpdaterWithDefaultConfig();
        switch (legacyUpdater) {
            case SGD:
                ((Sgd) updater).setLearningRate(objectNode.get("learningRate").asDouble());
                break;
            case ADAM: {
                Adam adam = (Adam) updater;
                adam.setLearningRate(objectNode.get("learningRate").asDouble());
                adam.setBeta1(objectNode.get("adamMeanDecay").asDouble());
                adam.setBeta2(objectNode.get("adamVarDecay").asDouble());
                adam.setEpsilon(legacyEpsilon(objectNode, 1.0E-8d));
                break;
            }
            case ADAMAX: {
                AdaMax adaMax = (AdaMax) updater;
                adaMax.setLearningRate(objectNode.get("learningRate").asDouble());
                adaMax.setBeta1(objectNode.get("adamMeanDecay").asDouble());
                adaMax.setBeta2(objectNode.get("adamVarDecay").asDouble());
                adaMax.setEpsilon(legacyEpsilon(objectNode, 1.0E-8d));
                break;
            }
            case ADADELTA: {
                // AdaDelta is the only updater that consumes "rho"; it takes no learning rate here.
                AdaDelta adaDelta = (AdaDelta) updater;
                adaDelta.setRho(objectNode.get("rho").asDouble());
                adaDelta.setEpsilon(legacyEpsilon(objectNode, 1.0E-6d));
                break;
            }
            case NESTEROVS: {
                Nesterovs nesterovs = (Nesterovs) updater;
                nesterovs.setLearningRate(objectNode.get("learningRate").asDouble());
                nesterovs.setMomentum(objectNode.get("momentum").asDouble());
                break;
            }
            case NADAM: {
                Nadam nadam = (Nadam) updater;
                nadam.setLearningRate(objectNode.get("learningRate").asDouble());
                nadam.setBeta1(objectNode.get("adamMeanDecay").asDouble());
                nadam.setBeta2(objectNode.get("adamVarDecay").asDouble());
                nadam.setEpsilon(legacyEpsilon(objectNode, 1.0E-8d));
                break;
            }
            case ADAGRAD: {
                AdaGrad adaGrad = (AdaGrad) updater;
                adaGrad.setLearningRate(objectNode.get("learningRate").asDouble());
                adaGrad.setEpsilon(legacyEpsilon(objectNode, 1.0E-6d));
                break;
            }
            case RMSPROP: {
                RmsProp rmsProp = (RmsProp) updater;
                rmsProp.setLearningRate(objectNode.get("learningRate").asDouble());
                rmsProp.setEpsilon(legacyEpsilon(objectNode, 1.0E-8d));
                rmsProp.setRmsDecay(objectNode.get("rmsDecay").asDouble());
                break;
            }
            default:
                // Remaining updater kinds keep the default configuration untouched.
                break;
        }
        baseLayer.setIUpdater(updater);
    }

    /**
     * Returns the {@code "epsilon"} value of the node, or {@code fallback} when the field is
     * absent or its value is NaN.
     */
    private static double legacyEpsilon(ObjectNode objectNode, double fallback) {
        double epsilon = objectNode.has("epsilon") ? objectNode.get("epsilon").asDouble() : Double.NaN;
        return Double.isNaN(epsilon) ? fallback : epsilon;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Converts legacy {@code l1} / {@code l2} / {@code l1Bias} / {@code l2Bias} coefficients of a
     * JSON node into regularization entries on {@code baseLayer}.
     *
     * <p>Nothing happens unless the node is non-null and has an {@code "l1"} or {@code "l2"}
     * field. In that case both the weight and the bias regularization lists of the layer are
     * replaced with fresh empty lists, and every coefficient that is present and strictly
     * positive adds one entry: {@link L1Regularization} for the l1 fields, {@link WeightDecay}
     * for the l2 fields.
     *
     * @param baseLayer  layer whose regularization lists are rebuilt
     * @param objectNode legacy JSON node for the layer; may be {@code null}
     */
    public void handleL1L2BackwardCompatibility(BaseLayer baseLayer, ObjectNode objectNode) {
        if (objectNode == null) {
            return;
        }
        if (!objectNode.has("l1") && !objectNode.has("l2")) {
            return;
        }

        baseLayer.setRegularization(new ArrayList());
        baseLayer.setRegularizationBias(new ArrayList());

        // Weight regularization.
        if (objectNode.has("l1")) {
            double l1Coefficient = objectNode.get("l1").doubleValue();
            if (l1Coefficient > 0.0d) {
                baseLayer.getRegularization().add(new L1Regularization(l1Coefficient));
            }
        }
        if (objectNode.has("l2")) {
            double l2Coefficient = objectNode.get("l2").doubleValue();
            if (l2Coefficient > 0.0d) {
                baseLayer.getRegularization().add(new WeightDecay(l2Coefficient, false));
            }
        }

        // Bias regularization.
        if (objectNode.has("l1Bias")) {
            double l1BiasCoefficient = objectNode.get("l1Bias").doubleValue();
            if (l1BiasCoefficient > 0.0d) {
                baseLayer.getRegularizationBias().add(new L1Regularization(l1BiasCoefficient));
            }
        }
        if (objectNode.has("l2Bias")) {
            double l2BiasCoefficient = objectNode.get("l2Bias").doubleValue();
            if (l2BiasCoefficient > 0.0d) {
                baseLayer.getRegularizationBias().add(new WeightDecay(l2BiasCoefficient, false));
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Restores the weight initialization function of {@code baseLayer} from the
     * {@code "weightInit"} field of a legacy-format JSON node.
     *
     * <p>The field text is resolved through {@link WeightInit#valueOf(String)}. When the result
     * is {@link WeightInit#DISTRIBUTION} and the node also carries the distribution field, that
     * field is parsed into a {@link Distribution} and passed along; otherwise {@code null} is
     * passed. Any failure along the way is logged as a warning and the layer is left unchanged.
     *
     * @param baseLayer  layer that receives the weight initialization function
     * @param objectNode legacy JSON node for the layer; may be {@code null}
     */
    public void handleWeightInitBackwardCompatibility(BaseLayer baseLayer, ObjectNode objectNode) {
        if (objectNode == null || !objectNode.has("weightInit")) {
            return;
        }
        try {
            String weightInitName = objectNode.get("weightInit").asText();
            WeightInit weightInit = WeightInit.valueOf(weightInitName);

            Distribution legacyDistribution = null;
            boolean needsDistribution = weightInit == WeightInit.DISTRIBUTION;
            if (needsDistribution && objectNode.has(Nd4j.DISTRIBUTION)) {
                String distributionJson = objectNode.get(Nd4j.DISTRIBUTION).toString();
                legacyDistribution = (Distribution) NeuralNetConfiguration.mapper().readValue(distributionJson, Distribution.class);
            }

            baseLayer.setWeightInitFn(weightInit.getWeightInitFunction(legacyDistribution));
        } catch (Throwable failure) {
            // Best effort: a config that cannot be interpreted must not abort deserialization.
            log.warn("Failed to infer weight initialization from legacy JSON format", failure);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Restores the activation function of {@code baseLayer} from the
     * {@code "activationFunction"} field of a legacy-format JSON node.
     *
     * <p>The method does nothing when the node is {@code null}, when the layer already has an
     * activation function, or when the field is absent. Otherwise the lower-cased field text is
     * looked up in the activation map and a new instance of the mapped class is installed on the
     * layer. An unknown name or a failed instantiation is logged as a warning and leaves the
     * layer's activation function unset instead of throwing or failing silently.
     *
     * @param baseLayer  layer that receives the activation function
     * @param objectNode legacy JSON node for the layer; may be {@code null}
     */
    public void handleActivationBackwardCompatibility(BaseLayer baseLayer, ObjectNode objectNode) {
        if (objectNode == null || baseLayer.getActivationFn() != null || !objectNode.has("activationFunction")) {
            return;
        }
        String legacyName = objectNode.get("activationFunction").asText();
        Class<? extends IActivation> activationClass = getMap().get(legacyName.toLowerCase());
        if (activationClass == null) {
            // Previously this case dereferenced the null lookup result and threw an NPE.
            log.warn("Unknown activation function \"{}\" in legacy JSON format", legacyName);
            return;
        }
        IActivation activation;
        try {
            activation = activationClass.newInstance();
        } catch (InstantiationException | IllegalAccessException e) {
            log.warn("Failed to instantiate activation function \"{}\" from legacy JSON format", legacyName, e);
            return;
        }
        baseLayer.setActivationFn(activation);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Restores the loss function of {@code baseOutputLayer} from the {@code "lossFunction"}
     * field of a legacy-format JSON node.
     *
     * <p>The method does nothing when the node is {@code null}, when the layer already has a
     * loss function, or when the {@code "lossFunction"} field is absent. Recognised names and
     * the loss they map to:
     * <ul>
     *   <li>{@code MCXENT}: {@link LossMCXENT}</li>
     *   <li>{@code MSE}: {@link LossMSE}</li>
     *   <li>{@code NEGATIVELOGLIKELIHOOD}: {@link LossNegativeLogLikelihood}</li>
     *   <li>{@code SQUARED_LOSS}: {@link LossL2}</li>
     *   <li>{@code XENT}: {@link LossBinaryXENT}</li>
     * </ul>
     * Any other name is logged as a warning and the layer is left without a loss function.
     *
     * @param baseOutputLayer output layer that receives the loss function
     * @param objectNode      legacy JSON node for the layer; may be {@code null}
     */
    public void handleLossBackwardCompatibility(BaseOutputLayer baseOutputLayer, ObjectNode objectNode) {
        // Guard on the field that is actually read below ("lossFunction"), not on
        // "activationFunction" as before; the old guard could NPE or skip a restorable loss.
        if (objectNode == null || baseOutputLayer.getLossFn() != null || !objectNode.has("lossFunction")) {
            return;
        }
        String lossName = objectNode.get("lossFunction").asText();
        ILossFunction lossFunction;
        switch (lossName) {
            case "MCXENT":
                lossFunction = new LossMCXENT();
                break;
            case "MSE":
                lossFunction = new LossMSE();
                break;
            case "NEGATIVELOGLIKELIHOOD":
                lossFunction = new LossNegativeLogLikelihood();
                break;
            case "SQUARED_LOSS":
                lossFunction = new LossL2();
                break;
            case "XENT":
                lossFunction = new LossBinaryXENT();
                break;
            default:
                log.warn("Unknown loss function \"{}\" in legacy JSON format", lossName);
                return;
        }
        baseOutputLayer.setLossFn(lossFunction);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Returns the lookup from lower-cased {@link Activation} names to the class of the
     * activation function each constant produces, building it on the first call.
     *
     * <p>The method is {@code synchronized}, so the check-and-build step runs under the class
     * lock.
     *
     * @return the shared activation lookup map
     */
    private static synchronized Map<String, Class<? extends IActivation>> getMap() {
        if (activationMap != null) {
            return activationMap;
        }
        activationMap = new HashMap<>();
        Activation[] allActivations = Activation.values();
        for (int i = 0; i < allActivations.length; i++) {
            Activation current = allActivations[i];
            String key = current.toString().toLowerCase();
            activationMap.put(key, current.getActivationFunction().getClass());
        }
        return activationMap;
    }

    /**
     * Forwards the resolve step to {@link #defaultDeserializer} when that deserializer is itself
     * a {@link ResolvableDeserializer}.
     *
     * <p>A delegate that does not implement {@link ResolvableDeserializer} (or is {@code null})
     * is skipped rather than triggering a {@link ClassCastException} or
     * {@link NullPointerException}.
     *
     * @param deserializationContext context passed through to the delegate
     * @throws JsonMappingException if the delegate's own resolve step fails
     */
    @Override // org.nd4j.shade.jackson.databind.deser.ResolvableDeserializer
    public void resolve(DeserializationContext deserializationContext) throws JsonMappingException {
        if (defaultDeserializer instanceof ResolvableDeserializer) {
            ((ResolvableDeserializer) defaultDeserializer).resolve(deserializationContext);
        }
    }
}
