package org.deeplearning4j.nn.layers.recurrent;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.class */
/**
 * Base class for recurrent layers that keep per-layer state between calls.
 *
 * <p>Two independent state maps are held:
 * <ul>
 *   <li>{@link #stateMap} – exposed via {@link #rnnGetPreviousState()} /
 *       {@link #rnnSetPreviousState(Map)};</li>
 *   <li>{@link #tBpttStateMap} – exposed via {@link #rnnGetTBPTTState()} /
 *       {@link #rnnSetTBPTTState(Map)}.</li>
 * </ul>
 * Both are {@link ConcurrentHashMap}s, so individual map operations are thread-safe and null keys
 * or values are rejected. Compound operations such as "replace all state" are NOT atomic with
 * respect to concurrent readers.
 *
 * @param <LayerConfT> the layer configuration type
 */
public abstract class BaseRecurrentLayer<LayerConfT extends Layer> extends BaseLayer<LayerConfT> {

    /** Named state returned by {@link #rnnGetPreviousState()}; presumably filled by subclasses. */
    protected Map<String, INDArray> stateMap = new ConcurrentHashMap<>();

    /** Named state returned by {@link #rnnGetTBPTTState()}; presumably filled by subclasses. */
    protected Map<String, INDArray> tBpttStateMap = new ConcurrentHashMap<>();

    public BaseRecurrentLayer(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    public BaseRecurrentLayer(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
    }

    /**
     * Performs a single step of this recurrent layer on {@code iNDArray}.
     * NOTE(review): whether/how {@link #stateMap} is read and updated is up to the subclass.
     */
    public abstract INDArray rnnTimeStep(INDArray iNDArray);

    /**
     * Returns a copy of the stored state map.
     *
     * @return a new mutable {@link HashMap}. Changing it does not affect this layer, but the
     *     {@link INDArray} values are shared, not copied.
     */
    public Map<String, INDArray> rnnGetPreviousState() {
        return new HashMap<>(this.stateMap);
    }

    /**
     * Replaces the stored state with the entries of {@code map}.
     *
     * @param map new state; must be non-null and contain no null keys or values
     * @throws NullPointerException if {@code map} is null or contains a null key/value; the
     *     existing state is left unchanged in that case
     */
    public void rnnSetPreviousState(Map<String, INDArray> map) {
        replaceContents(this.stateMap, map);
    }

    /** Clears both the stored state and the TBPTT state. */
    public void rnnClearPreviousState() {
        this.stateMap.clear();
        this.tBpttStateMap.clear();
    }

    /**
     * Activates the layer on {@code iNDArray} using the stored state.
     * NOTE(review): the meaning of {@code z} / {@code z2} is not visible here. Upstream DL4J names
     * them {@code training} / {@code storeLastForTBPTT}; confirm against implementations.
     */
    public abstract INDArray rnnActivateUsingStoredState(INDArray iNDArray, boolean z, boolean z2);

    /**
     * Returns a copy of the TBPTT state map.
     *
     * @return a new mutable {@link HashMap}. Changing it does not affect this layer, but the
     *     {@link INDArray} values are shared, not copied.
     */
    public Map<String, INDArray> rnnGetTBPTTState() {
        return new HashMap<>(this.tBpttStateMap);
    }

    /**
     * Replaces the TBPTT state with the entries of {@code map}.
     *
     * @param map new state; must be non-null and contain no null keys or values
     * @throws NullPointerException if {@code map} is null or contains a null key/value; the
     *     existing state is left unchanged in that case
     */
    public void rnnSetTBPTTState(Map<String, INDArray> map) {
        replaceContents(this.tBpttStateMap, map);
    }

    /**
     * Computes the gradient for truncated backpropagation through time.
     * NOTE(review): {@code i} is presumably the TBPTT backward length (upstream DL4J names it
     * {@code tbpttBackLength}); confirm against implementations.
     */
    public abstract Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray iNDArray, int i);

    /**
     * Replaces every entry of {@code target} with those of {@code source}. The input is validated
     * before {@code target} is touched.
     */
    private static void replaceContents(Map<String, INDArray> target, Map<String, INDArray> source) {
        if (source == null) {
            throw new NullPointerException("state map must not be null");
        }
        // Snapshot first: if source aliases target, clear() would otherwise wipe the input as well.
        Map<String, INDArray> snapshot = new HashMap<>(source);
        // ConcurrentHashMap rejects nulls. Check up front so a bad entry cannot leave target
        // cleared and only partially refilled.
        for (Map.Entry<String, INDArray> entry : snapshot.entrySet()) {
            if (entry.getKey() == null || entry.getValue() == null) {
                throw new NullPointerException(
                        "state map must not contain null keys or values (key: " + entry.getKey() + ")");
            }
        }
        target.clear();
        target.putAll(snapshot);
    }
}
