package de.datexis.sector.tagger;

import com.google.common.collect.Lists;
import de.datexis.common.Resource;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.evaluation.ModelEvaluation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.sector.eval.ClassificationScoreCalculator;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.DocumentSentenceIterator;
import de.datexis.tagger.Tagger;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.datasets.iterator.file.FileMultiDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.graph.SubsetVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.schedule.ExponentialSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/sector/tagger/SectorTagger.class */
/**
 * SECTOR tagger: a bidirectional-LSTM {@link ComputationGraph} over three recurrent inputs
 * ({@code bag}, {@code emb}, {@code flag}) with separate forward ({@code targetFW}) and backward
 * ({@code targetBW}) output layers.
 *
 * <p>During inference each document of a batch is one batch row and each of its sentences is one
 * timestep (see {@link #attachVectors(DocumentSentenceIterator.DocumentBatch, Class)}).
 */
public class SectorTagger extends Tagger {
    protected static final Logger log = LoggerFactory.getLogger(SectorTagger.class);

    // Input encoders for the "bag", "emb" and "flag" graph inputs (set via setInputEncoders/setEncoders).
    // Explicit "= null" is kept on purpose: like the original constructor bodies, it resets the fields
    // after the super constructor has run.
    protected Encoder bagEncoder = null;
    protected Encoder embEncoder = null;
    protected Encoder flagEncoder = null;
    // Encoder whose embedding size defines the output layer width and under whose class predictions are stored.
    protected Encoder targetEncoder = null;
    protected int workers = 4;
    protected boolean requireSubsampling;
    protected ModelEvaluation eval = new ModelEvaluation("null");
    // Reshapes feed-forward embedding activations back into recurrent (per-timestep) layout.
    protected final FeedForwardToRnnPreProcessor ff2rnn = new FeedForwardToRnnPreProcessor();

    /** Creates a tagger with id {@code "SECTOR"}. */
    public SectorTagger() {
        super("SECTOR");
    }

    /**
     * Creates a tagger with the given id.
     *
     * @param id tagger id passed to {@link Tagger}
     */
    public SectorTagger(String id) {
        super(id);
    }

    /**
     * Creates a tagger from a resource and then sets its id to {@code "SECTOR"}.
     *
     * @param resource resource passed to {@link Tagger}
     */
    public SectorTagger(Resource resource) {
        super(resource);
        setId("SECTOR");
    }

    /**
     * Returns the underlying network as a {@link ComputationGraph}.
     *
     * <p>NOTE(review): the name looks like a decompiler rename of {@code getNN} (merged with its bridge
     * method); it is kept so existing callers keep compiling.
     */
    @JsonIgnore
    public ComputationGraph m28getNN() {
        return this.net;
    }

    public boolean isRequireSubsampling() {
        return this.requireSubsampling;
    }

    public void setRequireSubsampling(boolean requireSubsampling) {
        this.requireSubsampling = requireSubsampling;
    }

    /**
     * Sets the encoders for the three graph inputs.
     *
     * @param bagEncoder encoder for the {@code bag} input
     * @param embEncoder encoder for the {@code emb} input
     * @param flagEncoder encoder for the {@code flag} input
     */
    public void setInputEncoders(Encoder bagEncoder, Encoder embEncoder, Encoder flagEncoder) {
        this.bagEncoder = bagEncoder;
        this.embEncoder = embEncoder;
        this.flagEncoder = flagEncoder;
    }

    public void setTargetEncoder(Encoder targetEncoder) {
        this.targetEncoder = targetEncoder;
    }

    /**
     * Sets the number of workers.
     *
     * @param workers worker count stored in {@code workers}
     * @return this tagger, for chaining
     */
    public SectorTagger setWorkspaceParams(int workers) {
        this.workers = workers;
        return this;
    }

    /** Returns a new mutable list {@code [bag, emb, flag, target]} of this tagger's encoders. */
    @JsonIgnore
    public List<Encoder> getEncoders() {
        return Lists.newArrayList(this.bagEncoder, this.embEncoder, this.flagEncoder, this.targetEncoder);
    }

    @JsonIgnore
    public Encoder getTargetEncoder() {
        return this.targetEncoder;
    }

    /**
     * Sets all encoders from a list in the order returned by {@link #getEncoders()}.
     *
     * @param encoders exactly four encoders: bag, emb, flag, target
     * @throws IllegalArgumentException if the list does not contain exactly four elements
     */
    public void setEncoders(List<Encoder> encoders) {
        if (encoders.size() != 4) {
            throw new IllegalArgumentException("wrong number of encoders given (expected=4, actual=" + encoders.size() + ")");
        }
        this.bagEncoder = encoders.get(0);
        this.embEncoder = encoders.get(1);
        this.flagEncoder = encoders.get(2);
        this.targetEncoder = encoders.get(3);
    }

    /**
     * Builds and initializes the SECTOR network and stores it in {@code net}.
     *
     * <p>Layout:
     * <ul>
     * <li>If {@code ffwLayerSize > 0}, the {@code bag} input passes through two ELU dense layers
     * (FF1, FF2). The result is then merged with {@code emb} and {@code flag} into {@code sentence}.
     * Otherwise the raw inputs are merged directly.</li>
     * <li>A bidirectional LSTM (CONCAT mode) follows. Its output is split into forward ({@code FW})
     * and backward ({@code BW}) halves of {@code lstmLayerSize} units each.</li>
     * <li>Each half gets an output layer ({@code targetFW}/{@code targetBW}). When
     * {@code embeddingLayerSize > 0}, a tanh embedding layer ({@code embeddingFW}/{@code embeddingBW})
     * is inserted before it.</li>
     * </ul>
     *
     * @param ffwLayerSize size of the bag dense layers; {@code <= 0} skips them
     * @param lstmLayerSize size of each LSTM direction
     * @param embeddingLayerSize size of the per-direction embedding layers; {@code <= 0} skips them
     * @param iterations currently unused by this method
     * @param learningRate initial Adam learning rate, decayed per epoch by an exponential schedule (gamma 0.85)
     * @param dropout value passed to the LSTM builder's {@code dropOut}
     * @param lossFunction loss function of both output layers
     * @param activation activation of both output layers
     * @return this tagger, for chaining
     */
    public SectorTagger buildSECTORModel(int ffwLayerSize, int lstmLayerSize, int embeddingLayerSize, int iterations,
                                         double learningRate, double dropout, ILossFunction lossFunction, Activation activation) {
        log.info("initializing graph with layer sizes bag={}, lstm={}, emb={} and {} loss",
                ffwLayerSize, lstmLayerSize, embeddingLayerSize, lossFunction.name());
        this.embeddingLayerSize = embeddingLayerSize;

        ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new Adam(new ExponentialSchedule(ScheduleType.EPOCH, learningRate, 0.85d)))
                .weightInit(WeightInit.XAVIER)
                .l2(1.0E-5d)
                .gradientNormalization(GradientNormalization.ClipL2PerLayer)
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .cacheMode(CacheMode.HOST)
                .graphBuilder()
                .addInputs("bag")
                .addInputs("emb")
                .addInputs("flag");

        // Width of the merged "sentence" vertex, i.e. the LSTM input size.
        long lstmInputSize;
        if (ffwLayerSize > 0) {
            lstmInputSize = ffwLayerSize + this.embEncoder.getEmbeddingVectorSize() + this.flagEncoder.getEmbeddingVectorSize();
            graph.addLayer("FF1", new DenseLayer.Builder()
                            .nIn(this.bagEncoder.getEmbeddingVectorSize()).nOut(ffwLayerSize)
                            .activation(Activation.ELU).weightInit(WeightInit.RELU).build(), "bag")
                    .addLayer("FF2", new DenseLayer.Builder()
                            .nIn(ffwLayerSize).nOut(ffwLayerSize)
                            .activation(Activation.ELU).weightInit(WeightInit.RELU).build(), "FF1")
                    // dense layers emit feed-forward activations; convert back to recurrent layout before merging
                    .addVertex("surf", new PreprocessorVertex(new FeedForwardToRnnPreProcessor()), "FF2")
                    .addVertex("sentence", new MergeVertex(), "surf", "emb", "flag");
        } else {
            lstmInputSize = this.bagEncoder.getEmbeddingVectorSize() + this.embEncoder.getEmbeddingVectorSize() + this.flagEncoder.getEmbeddingVectorSize();
            graph.addVertex("sentence", new MergeVertex(), "bag", "emb", "flag");
        }

        graph.addLayer("BLSTM", new Bidirectional(Bidirectional.Mode.CONCAT, new LSTM.Builder()
                .nIn(lstmInputSize).nOut(lstmLayerSize)
                .activation(Activation.TANH).gateActivationFunction(Activation.SIGMOID)
                .dropOut(dropout).build()), "sentence");
        // CONCAT output is [forward | backward]; split it so each direction gets its own head
        graph.addVertex("FW", new SubsetVertex(0, lstmLayerSize - 1), "BLSTM");
        graph.addVertex("BW", new SubsetVertex(lstmLayerSize, (2 * lstmLayerSize) - 1), "BLSTM");

        if (this.embeddingLayerSize > 0) {
            graph.addLayer("embeddingFW", new DenseLayer.Builder()
                            .nIn(lstmLayerSize).nOut(embeddingLayerSize).activation(Activation.TANH).build(), "FW")
                    .addLayer("embeddingBW", new DenseLayer.Builder()
                            .nIn(lstmLayerSize).nOut(embeddingLayerSize).activation(Activation.TANH).build(), "BW");
            graph.addLayer("targetFW", new RnnOutputLayer.Builder(lossFunction)
                            .nIn(embeddingLayerSize).nOut(this.targetEncoder.getEmbeddingVectorSize())
                            .activation(activation).weightInit(WeightInit.SIGMOID_UNIFORM).build(), "embeddingFW")
                    .addLayer("targetBW", new RnnOutputLayer.Builder(lossFunction)
                            .nIn(embeddingLayerSize).nOut(this.targetEncoder.getEmbeddingVectorSize())
                            .activation(activation).weightInit(WeightInit.SIGMOID_UNIFORM).build(), "embeddingBW");
        } else {
            graph.addLayer("targetFW", new RnnOutputLayer.Builder(lossFunction)
                            .nIn(lstmLayerSize).nOut(this.targetEncoder.getEmbeddingVectorSize())
                            .activation(activation).weightInit(WeightInit.SIGMOID_UNIFORM).build(), "FW")
                    .addLayer("targetBW", new RnnOutputLayer.Builder(lossFunction)
                            .nIn(lstmLayerSize).nOut(this.targetEncoder.getEmbeddingVectorSize())
                            .activation(activation).weightInit(WeightInit.SIGMOID_UNIFORM).build(), "BW");
        }

        // NOTE(review): all three inputs are declared with the same inputVectorSize even though their
        // encoders may differ in size; nIn is set explicitly on the layers above — confirm this is intended.
        graph.setOutputs("targetFW", "targetBW")
                .setInputTypes(InputType.recurrent(this.inputVectorSize), InputType.recurrent(this.inputVectorSize), InputType.recurrent(this.inputVectorSize))
                .backpropType(BackpropType.Standard);

        ComputationGraph network = new ComputationGraph(graph.build());
        network.init();
        this.net = network;
        this.net.setListeners(new PerformanceListener(16, true));
        return this;
    }

    /**
     * Trains on pre-saved MultiDataSets stored as files under {@code path}.
     *
     * @param path directory read by {@link FileMultiDataSetIterator}
     * @param numEpochs number of epochs to train
     */
    public void trainModelPresaved(String path, int numEpochs) {
        FileMultiDataSetIterator iterator = new FileMultiDataSetIterator(new File(path), this.batchSize);
        this.timer.start();
        // disable periodic GC during training (re-enabled below), matching trainModel(Dataset, int)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
        for (int epoch = 1; epoch <= numEpochs; epoch++) {
            // fit(iterator) runs exactly one epoch; the loop supplies the epoch count
            // (fit(iterator, numEpochs) inside this loop would train numEpochs^2 epochs)
            m28getNN().fit(iterator);
            iterator.reset();
            this.timer.setSplit("epoch");
            appendTrainLog("Completed epoch " + epoch + " of " + numEpochs, this.timer.getLong("epoch"));
            Nd4j.getMemoryManager().invokeGc();
        }
        this.timer.stop();
        appendTrainLog("Training complete", this.timer.getLong());
        Nd4j.getMemoryManager().togglePeriodicGc(true);
        setModelAvailable(true);
    }

    /** Trains on the given dataset for {@code numEpochs} epochs. */
    public void trainModel(Dataset dataset) {
        trainModel(dataset, this.numEpochs);
    }

    /**
     * Trains on the given dataset for a fixed number of epochs, triggering the network's epoch
     * listeners manually around each epoch.
     *
     * @param dataset training data
     * @param numEpochs number of epochs to train
     */
    public void trainModel(Dataset dataset, int numEpochs) {
        SectorTaggerIterator trainIterator = new SectorTaggerIterator(AbstractMultiDataSetIterator.Stage.TRAIN, dataset.getDocuments(), this, this.numExamples, this.maxTimeSeriesLength, this.batchSize, true, this.requireSubsampling);
        int numBatches = this.numExamples / this.batchSize;
        this.timer.start();
        appendTrainLog("Training " + getName() + " with " + this.numExamples + " examples in " + numBatches + " batches for " + numEpochs + " epochs.");
        Nd4j.getMemoryManager().togglePeriodicGc(false);
        for (int epoch = 1; epoch <= numEpochs; epoch++) {
            appendTrainLog("Starting epoch " + epoch + " of " + numEpochs);
            triggerEpochListeners(true, epoch - 1);
            m28getNN().fit(trainIterator);
            this.timer.setSplit("epoch");
            appendTrainLog("Completed epoch " + epoch + " of " + numEpochs, this.timer.getLong("epoch"));
            triggerEpochListeners(false, epoch - 1);
            if (epoch < numEpochs) {
                trainIterator.reset();
            }
            Nd4j.getMemoryManager().invokeGc();
        }
        this.timer.stop();
        appendTrainLog("Training complete", this.timer.getLong());
        Nd4j.getMemoryManager().togglePeriodicGc(true);
        setModelAvailable(true);
    }

    /**
     * Trains with early stopping, scoring on {@code validation} with a {@link ClassificationScoreCalculator}.
     * The best model found replaces {@code net}.
     *
     * @param train training data
     * @param validation data used by the score calculator
     * @param earlyStoppingConfiguration configuration; its score calculator is overwritten
     * @return the early-stopping result
     */
    public EarlyStoppingResult<ComputationGraph> trainModel(Dataset train, Dataset validation, EarlyStoppingConfiguration earlyStoppingConfiguration) {
        SectorTaggerIterator trainIterator = new SectorTaggerIterator(AbstractMultiDataSetIterator.Stage.TRAIN, train.getDocuments(), this, this.numExamples, this.maxTimeSeriesLength, this.batchSize, true, this.requireSubsampling);
        SectorTaggerIterator validationIterator = new SectorTaggerIterator(AbstractMultiDataSetIterator.Stage.TEST, validation.getDocuments(), this, -1, this.maxTimeSeriesLength, this.batchSize, false, this.requireSubsampling);
        int numBatches = (int) (trainIterator.getNumExamples() / this.batchSize);
        this.timer.start();
        appendTrainLog("Training " + getName() + " with " + trainIterator.getNumExamples() + " examples in " + numBatches + " batches using early stopping.");
        earlyStoppingConfiguration.setScoreCalculator(new ClassificationScoreCalculator((Tagger) this, this.targetEncoder, (MultiDataSetIterator) validationIterator));

        EarlyStoppingListener<ComputationGraph> listener = new EarlyStoppingListener<ComputationGraph>() {
            @Override
            public void onStart(EarlyStoppingConfiguration<ComputationGraph> config, ComputationGraph graph) {
                // nothing to do
            }

            @Override
            public void onEpoch(int epochNum, double score, EarlyStoppingConfiguration<ComputationGraph> config, ComputationGraph graph) {
                Nd4j.getMemoryManager().invokeGc();
            }

            @Override
            public void onCompletion(EarlyStoppingResult<ComputationGraph> result) {
                log.info("Finished training with result {}", result.toString());
            }
        };
        EarlyStoppingGraphTrainer trainer = new EarlyStoppingGraphTrainer(earlyStoppingConfiguration, m28getNN(), trainIterator, listener);

        Nd4j.getMemoryManager().togglePeriodicGc(false);
        EarlyStoppingResult<ComputationGraph> result = trainer.fit();
        Nd4j.getMemoryManager().togglePeriodicGc(true);
        this.timer.stop();
        appendTrainLog("Training complete", this.timer.getLong());
        this.net = result.getBestModel();
        setModelAvailable(true);
        return result;
    }

    /** Runs inference on the dataset's documents and attaches the resulting vectors to their sentences. */
    public void testModel(Dataset dataset) {
        this.timer.start();
        attachVectors(dataset.getDocuments(), AbstractMultiDataSetIterator.Stage.TEST, this.targetEncoder.getClass());
        this.timer.stop();
        appendTestLog("Testing complete", this.timer.getLong());
    }

    /** Not supported by this tagger. */
    public void tag(Collection<Document> documents) {
        throw new UnsupportedOperationException("not implemented");
    }

    /**
     * Runs the network on one batch and returns all activations keyed by vertex name.
     *
     * <p>Embedding activations are reshaped to recurrent layout with {@link #ff2rnn}. If the
     * bidirectional embedding layers are present, an additional {@code "embedding"} entry is added
     * that holds the element-wise mean of the forward and backward embeddings.
     */
    public Map<String, INDArray> encodeMatrix(DocumentSentenceIterator.DocumentBatch documentBatch) {
        Map<String, INDArray> activations = feedForward(m28getNN(), documentBatch.dataset);
        if (activations.containsKey("embedding")) {
            activations.put("embedding", this.ff2rnn.preProcess(activations.get("embedding"), documentBatch.size, LayerWorkspaceMgr.noWorkspaces()));
        } else if (activations.containsKey("embeddingFW")) {
            INDArray forward = this.ff2rnn.preProcess(activations.get("embeddingFW"), documentBatch.size, LayerWorkspaceMgr.noWorkspaces());
            INDArray backward = this.ff2rnn.preProcess(activations.get("embeddingBW"), documentBatch.size, LayerWorkspaceMgr.noWorkspaces());
            activations.put("embeddingFW", forward);
            activations.put("embeddingBW", backward);
            activations.put("embedding", forward.add(backward).divi(2));
        }
        return activations;
    }

    /**
     * Runs an inference forward pass with the data set's feature and label masks applied.
     *
     * <p>If the graph has no {@code "target"} vertex but has {@code "targetFW"}, a {@code "target"}
     * entry is added that holds the element-wise mean of {@code targetFW} and {@code targetBW}.
     *
     * @return activations of all vertices, keyed by name
     */
    public static Map<String, INDArray> feedForward(ComputationGraph computationGraph, MultiDataSet multiDataSet) {
        INDArray[] features = multiDataSet.getFeatures();
        computationGraph.setLayerMaskArrays(multiDataSet.getFeaturesMaskArrays(), multiDataSet.getLabelsMaskArrays());
        Map<String, INDArray> activations = computationGraph.feedForward(features, false, true);
        if (!activations.containsKey("target") && activations.containsKey("targetFW")) {
            activations.put("target", activations.get("targetFW").add(activations.get("targetBW")).divi(2));
        }
        return activations;
    }

    /**
     * Sets the graph's epoch count and notifies all of the network's training listeners.
     *
     * @param start {@code true} to call {@code onEpochStart}, {@code false} for {@code onEpochEnd}
     * @param epoch zero-based epoch index stored in the graph configuration
     */
    protected void triggerEpochListeners(boolean start, int epoch) {
        ComputationGraph network = m28getNN();
        Collection<TrainingListener> listeners = network.getListeners();
        network.getConfiguration().setEpochCount(epoch);
        if (listeners == null || listeners.isEmpty()) {
            return;
        }
        for (TrainingListener listener : listeners) {
            if (start) {
                listener.onEpochStart(network);
            } else {
                listener.onEpochEnd(network);
            }
        }
    }

    /**
     * Runs inference over the documents batch by batch and attaches vectors to every sentence.
     *
     * @param documents documents to encode
     * @param stage iterator stage
     * @param encoderClass passed through to the per-batch method (see the note there)
     */
    public void attachVectors(Collection<Document> documents, AbstractMultiDataSetIterator.Stage stage, Class<? extends Encoder> encoderClass) {
        SectorTaggerIterator iterator = new SectorTaggerIterator(stage, documents, this, this.batchSize, false, this.requireSubsampling);
        while (iterator.hasNext()) {
            attachVectors(iterator.nextDocumentBatch(), encoderClass);
        }
    }

    /**
     * Encodes one batch and stores per-sentence vectors. Each document in the batch is one row; each
     * of its sentences is one timestep. The following vectors are stored:
     * <ul>
     * <li>the {@code "target"} activation, under the target encoder's class;</li>
     * <li>the {@code "embedding"} activation (if present), under {@link SectorEncoder};</li>
     * <li>the forward and backward embeddings (if present), under the keys {@code "embeddingFW"} and
     * {@code "embeddingBW"}.</li>
     * </ul>
     *
     * <p>NOTE(review): {@code encoderClass} is not used; predictions are stored under
     * {@code targetEncoder.getClass()}. Confirm whether callers expect the argument to be honoured.
     */
    protected void attachVectors(DocumentSentenceIterator.DocumentBatch documentBatch, Class<? extends Encoder> encoderClass) {
        Map<String, INDArray> activations = encodeMatrix(documentBatch);
        INDArray target = activations.get("target");
        INDArray embedding = activations.containsKey("embedding") ? activations.get("embedding") : null;
        INDArray embeddingFW = null;
        INDArray embeddingBW = null;
        if (activations.containsKey("embeddingFW")) {
            embeddingFW = activations.get("embeddingFW");
            embeddingBW = activations.get("embeddingBW");
        }
        int row = 0;
        for (Document document : documentBatch.docs) {
            int step = 0;
            for (Sentence sentence : document.getSentences()) {
                sentence.putVector(this.targetEncoder.getClass(), EncodingHelpers.getTimeStep(target, row, step));
                if (embedding != null) {
                    sentence.putVector(SectorEncoder.class, EncodingHelpers.getTimeStep(embedding, row, step));
                }
                if (embeddingFW != null) {
                    sentence.putVector("embeddingFW", EncodingHelpers.getTimeStep(embeddingFW, row, step));
                    sentence.putVector("embeddingBW", EncodingHelpers.getTimeStep(embeddingBW, row, step));
                }
                step++;
            }
            row++;
        }
    }

    /** Clears per-layer state, noise weights, vertex state and mask arrays of the given graph. */
    protected static void clearLayerStates(ComputationGraph computationGraph) {
        for (Layer layer : computationGraph.getLayers()) {
            layer.clear();
            layer.clearNoiseWeightParams();
        }
        for (GraphVertex vertex : computationGraph.getVertices()) {
            vertex.clearVertex();
        }
        computationGraph.clear();
        computationGraph.clearLayerMaskArrays();
    }

    /** Attaches an in-memory {@link StatsListener} to the network and registers its storage with the DL4J UI server. */
    public void enableTrainingUI() {
        InMemoryStatsStorage statsStorage = new InMemoryStatsStorage();
        this.net.addListeners(new StatsListener(statsStorage, 1));
        UIServer.getInstance().attach(statsStorage);
        UIServer.getInstance().enableRemoteListener(statsStorage, true);
    }

    /**
     * Writes the network (including updater state) to {@code <name>.zip} in the given directory and
     * records that resource as this tagger's model. I/O errors are logged, not thrown.
     *
     * @param dir target directory resource
     * @param name file name without the {@code .zip} extension
     */
    public void saveModel(Resource dir, String name) {
        Resource target = dir.resolve(name + ".zip");
        // try-with-resources closes the stream even if writeModel fails
        try (OutputStream out = target.getOutputStream()) {
            ModelSerializer.writeModel(this.net, out, true);
            setModel(target);
        } catch (IOException e) {
            log.error(e.toString(), e);
        }
    }

    /**
     * Restores the network (without updater state) from the given resource and marks the model as
     * available. I/O errors are logged, not thrown.
     *
     * @param resource resource containing a serialized ComputationGraph
     */
    public void loadModel(Resource resource) {
        // try-with-resources closes the stream even if deserialization fails
        try (InputStream in = resource.getInputStream()) {
            this.net = ModelSerializer.restoreComputationGraph(in, false);
            setModel(resource);
            setModelAvailable(true);
            log.info("loaded Computation Graph from " + resource.getFileName());
        } catch (IOException e) {
            log.error(e.toString(), e);
        }
    }

    /**
     * Always returns {@code null}.
     *
     * <p>NOTE(review): presumably this keeps the graph configuration out of JSON serialization, since
     * the network itself is persisted by {@link #saveModel}. Confirm.
     */
    public ComputationGraphConfiguration getGraphConfiguration() {
        return null;
    }

    /** Intentionally a no-op; the JSON node is ignored (see {@link #getGraphConfiguration()}). */
    public void setGraphConfiguration(JsonNode jsonNode) {
    }
}
