package org.deeplearning4j.nn.layers.recurrent;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.class */
/**
 * Common base for recurrent layer implementations: holds the stored RNN state used between calls
 * ({@link #stateMap}) and the state used for truncated BPTT ({@link #tBpttStateMap}), and provides
 * helpers for the layer's configured {@link RNNFormat}.
 *
 * <p>NOTE(review): this file also single-type-imports
 * {@code org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer}, whose simple name clashes with this
 * top-level class (a compile-time error per JLS 7.5.1). The configuration type is therefore always
 * referenced by its fully qualified name below.
 *
 * @param <LayerConfT> the recurrent layer configuration type
 */
public abstract class BaseRecurrentLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer> extends BaseLayer<LayerConfT> implements RecurrentLayer {
    /** Stored state returned by {@link #rnnGetPreviousState()}; null keys/values are not allowed. */
    protected Map<String, INDArray> stateMap;
    /** State returned by {@link #rnnGetTBPTTState()}; null keys/values are not allowed. */
    protected Map<String, INDArray> tBpttStateMap;
    // Not read in this class; presumably a failure counter for subclass helpers — confirm in subclasses.
    protected int helperCountFail;

    /**
     * Creates the layer with empty state maps.
     *
     * @param neuralNetConfiguration the network configuration passed to {@link BaseLayer}
     * @param dataType the data type passed to {@link BaseLayer}
     */
    public BaseRecurrentLayer(NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        super(neuralNetConfiguration, dataType);
        this.stateMap = new ConcurrentHashMap<>();
        this.tBpttStateMap = new ConcurrentHashMap<>();
        this.helperCountFail = 0;
    }

    /**
     * Returns a shallow copy of the stored state. Changing the returned map does not affect this
     * layer, but the {@link INDArray} values themselves are shared, not duplicated.
     */
    @Override // org.deeplearning4j.nn.api.layers.RecurrentLayer
    public Map<String, INDArray> rnnGetPreviousState() {
        return new HashMap<>(this.stateMap);
    }

    /**
     * Replaces the stored state with the entries of {@code map}.
     *
     * @param map the new state; must be non-null with no null keys or values
     * @throws NullPointerException if {@code map} or any of its keys/values is null, in which case
     *     the existing state is left unchanged
     */
    @Override // org.deeplearning4j.nn.api.layers.RecurrentLayer
    public void rnnSetPreviousState(Map<String, INDArray> map) {
        replaceContents(this.stateMap, map);
    }

    /** Clears both the stored state and the TBPTT state. */
    @Override // org.deeplearning4j.nn.api.layers.RecurrentLayer
    public void rnnClearPreviousState() {
        this.stateMap.clear();
        this.tBpttStateMap.clear();
    }

    /**
     * Returns a shallow copy of the TBPTT state. Changing the returned map does not affect this
     * layer, but the {@link INDArray} values themselves are shared, not duplicated.
     */
    @Override // org.deeplearning4j.nn.api.layers.RecurrentLayer
    public Map<String, INDArray> rnnGetTBPTTState() {
        return new HashMap<>(this.tBpttStateMap);
    }

    /**
     * Replaces the TBPTT state with the entries of {@code map}.
     *
     * @param map the new state; must be non-null with no null keys or values
     * @throws NullPointerException if {@code map} or any of its keys/values is null, in which case
     *     the existing state is left unchanged
     */
    @Override // org.deeplearning4j.nn.api.layers.RecurrentLayer
    public void rnnSetTBPTTState(Map<String, INDArray> map) {
        replaceContents(this.tBpttStateMap, map);
    }

    /**
     * Returns the RNN data format from this layer's configuration.
     *
     * @return the configured {@link RNNFormat}
     */
    public RNNFormat getDataFormat() {
        // Explicit cast kept: the decompiler could not infer layerConf()'s declared return type here.
        return ((org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer) layerConf()).getRnnDataFormat();
    }

    /**
     * Swaps dimensions 1 and 2 of {@code iNDArray} when the configured format is
     * {@link RNNFormat#NWC}, otherwise returns it unchanged.
     *
     * <p>Assumes a rank-3 array (the permutation has three entries). Originally protected according
     * to the decompiler; kept public so existing callers are unaffected.
     *
     * @param iNDArray the array to permute; may be null
     * @return null if the input is null, the permuted array for NWC, or the input itself otherwise
     */
    public INDArray permuteIfNWC(INDArray iNDArray) {
        if (iNDArray == null) {
            return null;
        }
        return getDataFormat() == RNNFormat.NWC ? iNDArray.permute(0, 2, 1) : iNDArray;
    }

    /**
     * Replaces the contents of {@code target} with those of {@code source}, validating first.
     *
     * <p>The source is copied into a {@link ConcurrentHashMap} before {@code target} is touched.
     * Its constructor rejects a null map and null keys/values, so a bad argument fails with
     * {@link NullPointerException} while {@code target} still holds its old entries. The same
     * {@code target} instance is reused because subclasses may hold a reference to it. The
     * clear/putAll pair is not atomic with respect to concurrent readers.
     */
    private static void replaceContents(Map<String, INDArray> target, Map<String, INDArray> source) {
        Map<String, INDArray> validated = new ConcurrentHashMap<>(source);
        target.clear();
        target.putAll(validated);
    }
}
