package de.datexis.cdv.tagger;

import com.google.common.collect.Lists;
import de.datexis.cdv.index.AspectIndex;
import de.datexis.cdv.index.EntityIndex;
import de.datexis.common.Resource;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.encoder.IEncoder;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.sector.eval.ClassificationScoreCalculator;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.DocumentSentenceIterator;
import de.datexis.tagger.Tagger;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/cdv/tagger/CDVTagger.class */
/**
 * Tagger that wraps a DL4J {@link ComputationGraph} and attaches its per-sentence outputs to
 * {@link Sentence}s (as vectors) or {@link Document}s (as matrices).
 *
 * <p>Which network outputs are read depends on which target encoders are set:
 *
 * <ul>
 *   <li>With both {@link #entityEncoder} and {@link #aspectEncoder} set, output 0 is read as the
 *       entity output and output 1 as the aspect output.
 *   <li>With only one of them set, output 0 is read as that target's output.
 * </ul>
 *
 * <p>Vectors are stored under the class of the corresponding encoder. Input batches come from a
 * {@link DocumentSentenceIterator} created by {@link #createIterator}.
 */
public class CDVTagger extends Tagger implements IEncoder {
    protected static final Logger log = LoggerFactory.getLogger(CDVTagger.class);

    /** Id used by the no-arg and {@link Resource} constructors. */
    private static final String DEFAULT_ID = "HTM";

    /** First encoder in {@link #getEncoders()}; set via {@link #setInputEncoders}. */
    protected IEncoder inputEncoder;
    /** Second encoder in {@link #getEncoders()}; set via {@link #setInputEncoders}. */
    protected IEncoder flagEncoder;
    /** Target encoder for the entity output, or {@code null} if there is no entity output. */
    protected IEncoder entityEncoder = null;
    /** Target encoder for the aspect output, or {@code null} if there is no aspect output. */
    protected IEncoder aspectEncoder = null;
    /** NOTE(review): not referenced anywhere in this class; possibly used by subclasses. */
    protected final FeedForwardToRnnPreProcessor ff2rnn = new FeedForwardToRnnPreProcessor();
    /** Per-sentence word limit passed to {@link CDVWordIterator}; {@code -1} by default. */
    protected int maxWordsPerSentence = -1;
    /** Balancing flag forwarded to every iterator created by {@link #createIterator}. */
    protected boolean balancing = true;
    /**
     * Iterator implementation to use. Only {@link CDVWordIterator} is special-cased; any other
     * value results in a {@link CDVSentenceIterator}.
     */
    protected Class<? extends DocumentSentenceIterator> iteratorClass = CDVSentenceIterator.class;

    /** Creates a tagger with id {@code "HTM"}. */
    public CDVTagger() {
        this(DEFAULT_ID);
    }

    /**
     * Creates a tagger with the given id.
     *
     * @param str id passed to {@link Tagger}
     */
    public CDVTagger(String str) {
        super(str);
    }

    /**
     * Creates a tagger from a resource and then forces its id to {@code "HTM"}.
     *
     * @param resource resource passed to {@link Tagger}
     */
    public CDVTagger(Resource resource) {
        super(resource);
        setId(DEFAULT_ID);
    }

    /**
     * Returns the wrapped computation graph.
     *
     * <p>NOTE(review): the decompiler renamed this method from {@code getNN} because it was merged
     * with its bridge method. The name is kept so that existing callers continue to compile.
     *
     * @return the network held in {@code net}
     */
    @JsonIgnore
    public ComputationGraph m18getNN() {
        return this.net;
    }

    /**
     * Replaces the wrapped network.
     *
     * @param computationGraph the network to use from now on
     */
    public void initializeNetwork(ComputationGraph computationGraph) {
        this.net = computationGraph;
    }

    /**
     * Sets the base training parameters and the balancing flag.
     *
     * @param i first value forwarded unchanged to {@link Tagger}'s {@code setTrainingParams}
     * @param i2 second value forwarded unchanged to {@link Tagger}'s {@code setTrainingParams}
     * @param i3 third value forwarded unchanged to {@link Tagger}'s {@code setTrainingParams}
     * @param i4 fourth value forwarded unchanged to {@link Tagger}'s {@code setTrainingParams}
     * @param z flag forwarded unchanged to {@link Tagger}'s {@code setTrainingParams}
     * @param balancing balancing flag passed to all iterators created later
     */
    public void setTrainingParams(int i, int i2, int i3, int i4, boolean z, boolean balancing) {
        super.setTrainingParams(i, i2, i3, i4, z);
        this.balancing = balancing;
    }

    /**
     * Sets the size limits used when creating the training iterator.
     *
     * @param numExamples number of examples for training iterators
     * @param maxTimeSeriesLength time-series limit for training iterators
     * @param maxWordsPerSentence per-sentence word limit for {@link CDVWordIterator}
     */
    public void setTrainingLimits(int numExamples, int maxTimeSeriesLength, int maxWordsPerSentence) {
        this.numExamples = numExamples;
        this.maxTimeSeriesLength = maxTimeSeriesLength;
        this.maxWordsPerSentence = maxWordsPerSentence;
    }

    /**
     * Sets the two input-side encoders.
     *
     * @param inputEncoder encoder stored as {@link #inputEncoder}
     * @param flagEncoder encoder stored as {@link #flagEncoder}
     */
    public void setInputEncoders(IEncoder inputEncoder, IEncoder flagEncoder) {
        this.inputEncoder = inputEncoder;
        this.flagEncoder = flagEncoder;
    }

    /** @param iEncoder target encoder for the entity output, or {@code null} for none */
    public void setEntityEncoder(IEncoder iEncoder) {
        this.entityEncoder = iEncoder;
    }

    /** @param iEncoder target encoder for the aspect output, or {@code null} for none */
    public void setAspectEncoder(IEncoder iEncoder) {
        this.aspectEncoder = iEncoder;
    }

    /** @return the entity target encoder, or {@code null} if none is set */
    @JsonIgnore
    public IEncoder getEntityEncoder() {
        return this.entityEncoder;
    }

    /** @return the aspect target encoder, or {@code null} if none is set */
    @JsonIgnore
    public IEncoder getAspectEncoder() {
        return this.aspectEncoder;
    }

    /**
     * Returns all four encoders in a fixed order: input, flag, entity, aspect.
     *
     * <p>Unset encoders appear as {@code null} entries, so the list always has four elements.
     *
     * @return a new mutable list of the encoders
     * @throws ClassCastException if a set encoder is not an {@link Encoder}
     */
    public List<Encoder> getEncoders() {
        return Lists.newArrayList(
                (Encoder) this.inputEncoder,
                (Encoder) this.flagEncoder,
                (Encoder) this.entityEncoder,
                (Encoder) this.aspectEncoder);
    }

    /**
     * Restores the encoders from a list in the order produced by {@link #getEncoders()}.
     *
     * <p>Accepted shapes:
     *
     * <ul>
     *   <li>4 elements: input, flag, entity, aspect.
     *   <li>3 elements: input, flag, and a single target encoder. The target's role is decided by
     *       its type ({@link EntityIndex} or {@link AspectIndex}), and the other target encoder is
     *       cleared.
     * </ul>
     *
     * <p>The list is fully validated before any field is changed.
     *
     * @param list the encoders
     * @throws IllegalArgumentException if the list does not have 3 or 4 elements, or if the single
     *     target encoder of a 3-element list has an unsupported type
     */
    public void setEncoders(List<Encoder> list) {
        int size = list.size();
        if (size != 3 && size != 4) {
            throw new IllegalArgumentException(
                    "wrong number of encoders given (expected=3 or 4, actual=" + size + ")");
        }
        Encoder third = list.get(2);
        if (size == 3 && !(third instanceof EntityIndex) && !(third instanceof AspectIndex)) {
            throw new IllegalArgumentException("got unknown encoder " + third.getClass().getName());
        }

        this.inputEncoder = list.get(0);
        this.flagEncoder = list.get(1);
        if (size == 4) {
            this.entityEncoder = third;
            this.aspectEncoder = list.get(3);
        } else if (third instanceof EntityIndex) {
            this.entityEncoder = third;
            // Clear any stale target encoder; otherwise output indexing would expect two outputs.
            this.aspectEncoder = null;
        } else {
            this.aspectEncoder = third;
            this.entityEncoder = null;
        }
    }

    /** @return the per-sentence word limit for {@link CDVWordIterator} ({@code -1} by default) */
    public int getMaxWordsPerSentence() {
        return this.maxWordsPerSentence;
    }

    /** @param i per-sentence word limit for {@link CDVWordIterator} */
    public void setMaxWordsPerSentence(int i) {
        this.maxWordsPerSentence = i;
    }

    /** @return the iterator implementation selected for {@link #createIterator} */
    public Class<? extends DocumentSentenceIterator> getIteratorClass() {
        return this.iteratorClass;
    }

    /**
     * Selects the iterator implementation. Only {@link CDVWordIterator} is special-cased; every
     * other class results in a {@link CDVSentenceIterator}.
     *
     * @param cls iterator class
     */
    public void setIteratorClass(Class<? extends DocumentSentenceIterator> cls) {
        this.iteratorClass = cls;
    }

    /**
     * Trains the network for the configured number of epochs.
     *
     * @param dataset training data
     */
    public void trainModel(Dataset dataset) {
        trainModel(dataset, this.numEpochs);
    }

    /** Not supported. @throws UnsupportedOperationException always */
    public void tag(Collection<Document> collection) {
        throw new UnsupportedOperationException("not implemented yet");
    }

    /**
     * Creates a batch iterator over the given documents.
     *
     * <p>For {@code Stage.TRAIN} the configured limits ({@code numExamples},
     * {@code maxTimeSeriesLength}) are applied and the boolean iterator flag is {@code true}. For
     * every other stage the limits are {@code -1} and the flag is {@code false}.
     *
     * @param stage pipeline stage
     * @param collection documents to iterate
     * @return a {@link CDVWordIterator} if {@link #iteratorClass} is exactly that class, otherwise
     *     a {@link CDVSentenceIterator}
     */
    protected DocumentSentenceIterator createIterator(AbstractMultiDataSetIterator.Stage stage, Collection<Document> collection) {
        boolean training = stage.equals(AbstractMultiDataSetIterator.Stage.TRAIN);
        int exampleLimit = training ? this.numExamples : -1;
        int lengthLimit = training ? this.maxTimeSeriesLength : -1;
        if (this.iteratorClass.equals(CDVWordIterator.class)) {
            return new CDVWordIterator(stage, collection, this, exampleLimit, lengthLimit,
                    this.maxWordsPerSentence, this.batchSize, training, this.balancing);
        }
        return new CDVSentenceIterator(stage, collection, this, exampleLimit, lengthLimit,
                this.batchSize, training, this.balancing);
    }

    /**
     * Trains the network for a fixed number of epochs, firing epoch listeners manually and
     * triggering an ND4J GC after every epoch.
     *
     * <p>NOTE(review): periodic ND4J GC is disabled here and is not re-enabled afterwards.
     *
     * @param dataset training data
     * @param numEpochs number of epochs to run
     */
    protected synchronized void trainModel(Dataset dataset, int numEpochs) {
        DocumentSentenceIterator iterator = createIterator(AbstractMultiDataSetIterator.Stage.TRAIN, dataset.getDocuments());
        int numBatches = this.numExamples / this.batchSize;
        this.timer.start();
        appendTrainLog("Training " + getName() + " with " + this.numExamples + " examples in " + numBatches + " batches for " + numEpochs + " epochs.");
        int seenExamples = 0;
        Nd4j.getMemoryManager().togglePeriodicGc(false);
        for (int epoch = 1; epoch <= numEpochs; epoch++) {
            appendTrainLog("Starting epoch " + epoch + " of " + numEpochs + "\t" + seenExamples);
            triggerEpochListeners(true, epoch - 1);
            m18getNN().fit(iterator);
            seenExamples += this.numExamples;
            this.timer.setSplit("epoch");
            appendTrainLog("Completed epoch " + epoch + " of " + numEpochs + "\t" + seenExamples, this.timer.getLong("epoch"));
            triggerEpochListeners(false, epoch - 1);
            // Reset only between epochs; the iterator is not reused after the last one.
            if (epoch < numEpochs) {
                iterator.reset();
            }
            Nd4j.getMemoryManager().invokeGc();
        }
        this.timer.stop();
        appendTrainLog("Training complete", this.timer.getLong());
        setModelAvailable(true);
    }

    /**
     * Trains with DL4J early stopping.
     *
     * <p>Candidate models are scored on {@code validationDataset} by a
     * {@link ClassificationScoreCalculator}. The best model found replaces the wrapped network.
     *
     * <p>NOTE(review): the score calculator always receives {@link #entityEncoder}, which is
     * {@code null} for aspect-only models; confirm that this is intended.
     *
     * @param trainDataset training data
     * @param validationDataset data used for early-stopping scoring (TEST-stage iterator)
     * @param earlyStoppingConfiguration configuration; its score calculator is overwritten
     * @return the early-stopping result
     */
    public EarlyStoppingResult<ComputationGraph> trainModel(Dataset trainDataset, Dataset validationDataset, EarlyStoppingConfiguration earlyStoppingConfiguration) {
        DocumentSentenceIterator trainIterator = createIterator(AbstractMultiDataSetIterator.Stage.TRAIN, trainDataset.getDocuments());
        DocumentSentenceIterator validationIterator = createIterator(AbstractMultiDataSetIterator.Stage.TEST, validationDataset.getDocuments());
        int numBatches = (int) (trainIterator.getNumExamples() / this.batchSize);
        this.timer.start();
        appendTrainLog("Training " + getName() + " with " + trainIterator.getNumExamples() + " examples in " + numBatches + " batches using early stopping.");
        earlyStoppingConfiguration.setScoreCalculator(new ClassificationScoreCalculator(this, this.entityEncoder, validationIterator));

        EarlyStoppingListener<ComputationGraph> listener = new EarlyStoppingListener<ComputationGraph>() {
            public void onStart(EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph graph) {
                Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
            }

            public void onEpoch(int epochNum, double score, EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph graph) {
                // The logged value is 1 - score reported by the trainer.
                log.info("Finished epoch {} with score {}", epochNum, 1.0d - score);
                Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
                // NOTE(review): destroys the current thread's workspace after every epoch; confirm
                // that this is still needed with the ND4J version in use.
                Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread().destroyWorkspace();
            }

            public void onCompletion(EarlyStoppingResult<ComputationGraph> earlyStoppingResult) {
                log.info("Finished training with result {}", earlyStoppingResult.toString());
            }
        };

        EarlyStoppingResult<ComputationGraph> result =
                new EarlyStoppingGraphTrainer(earlyStoppingConfiguration, m18getNN(), trainIterator, listener).fit();
        this.timer.stop();
        appendTrainLog("Training complete", this.timer.getLong());
        this.net = result.getBestModel();
        setModelAvailable(true);
        return result;
    }

    /**
     * Runs the network over the dataset in TEST stage and attaches sentence vectors.
     *
     * @param dataset data to annotate
     */
    public void testModel(Dataset dataset) {
        this.timer.start();
        attachCDVSentenceVectors(dataset.getDocuments(), AbstractMultiDataSetIterator.Stage.TEST);
        this.timer.stop();
        appendTestLog("Testing complete", this.timer.getLong());
    }

    /**
     * Sets the graph's epoch counter and notifies all registered training listeners.
     *
     * @param epochStart {@code true} to call {@code onEpochStart}, {@code false} for
     *     {@code onEpochEnd}
     * @param epoch zero-based epoch number written to the graph configuration
     */
    protected void triggerEpochListeners(boolean epochStart, int epoch) {
        ComputationGraph graph = m18getNN();
        Collection<TrainingListener> listeners = graph.getListeners();
        graph.getConfiguration().setEpochCount(epoch);
        if (listeners == null) {
            return;
        }
        for (TrainingListener listener : listeners) {
            if (epochStart) {
                listener.onEpochStart(graph);
            } else {
                listener.onEpochEnd(graph);
            }
        }
    }

    /**
     * Attaches per-sentence output vectors to every sentence of the given documents.
     *
     * @param collection documents to annotate
     * @param stage iterator stage
     */
    public void attachCDVSentenceVectors(Collection<Document> collection, AbstractMultiDataSetIterator.Stage stage) {
        DocumentSentenceIterator iterator = createIterator(stage, collection);
        while (iterator.hasNext()) {
            attachCDVSentenceVectors(iterator.nextDocumentBatch());
        }
    }

    /**
     * Runs one batch through the network and stores, for each sentence, the time step of the
     * entity and/or aspect output that corresponds to it.
     *
     * @param documentBatch the batch to annotate
     */
    protected void attachCDVSentenceVectors(DocumentSentenceIterator.DocumentBatch documentBatch) {
        INDArray[] targets = predictTargets(documentBatch);
        INDArray entityOutput = targets[0];
        INDArray aspectOutput = targets[1];
        int docIndex = 0;
        for (Document document : documentBatch.docs) {
            int sentenceIndex = 0;
            for (Sentence sentence : document.getSentences()) {
                if (entityOutput != null) {
                    sentence.putVector(getEntityEncoder().getClass(), EncodingHelpers.getTimeStep(entityOutput, docIndex, sentenceIndex));
                }
                if (aspectOutput != null) {
                    sentence.putVector(getAspectEncoder().getClass(), EncodingHelpers.getTimeStep(aspectOutput, docIndex, sentenceIndex));
                }
                sentenceIndex++;
            }
            docIndex++;
        }
    }

    /**
     * Attaches a per-document matrix of unit-normalised sentence outputs (ENCODE stage).
     *
     * @param collection documents to annotate
     */
    public void attachCDVDocumentMatrix(Collection<Document> collection) {
        DocumentSentenceIterator iterator = createIterator(AbstractMultiDataSetIterator.Stage.ENCODE, collection);
        while (iterator.hasNext()) {
            attachCDVDocumentMatrix(iterator.nextDocumentBatch());
        }
    }

    /**
     * Attaches baseline matrices built directly from {@link #entityEncoder} (ENCODE stage).
     *
     * @param collection documents to annotate
     */
    public void attachMatrixBaseline(Collection<Document> collection) {
        DocumentSentenceIterator iterator = createIterator(AbstractMultiDataSetIterator.Stage.ENCODE, collection);
        while (iterator.hasNext()) {
            attachMatrixBaseline(iterator.nextDocumentBatch());
        }
    }

    /**
     * Runs one batch through the network and stores, for each non-empty document, the columns
     * belonging to its sentences. Each column is scaled to unit length.
     *
     * @param documentBatch the batch to annotate
     */
    protected void attachCDVDocumentMatrix(DocumentSentenceIterator.DocumentBatch documentBatch) {
        INDArray[] targets = predictTargets(documentBatch);
        INDArray entityOutput = targets[0];
        INDArray aspectOutput = targets[1];
        int docIndex = 0;
        for (Document document : documentBatch.docs) {
            int numSentences = document.countSentences();
            if (numSentences > 0) {
                if (entityOutput != null) {
                    document.putVector(getEntityEncoder().getClass(), normalizedDocumentMatrix(entityOutput, docIndex, numSentences));
                }
                if (aspectOutput != null) {
                    document.putVector(getAspectEncoder().getClass(), normalizedDocumentMatrix(aspectOutput, docIndex, numSentences));
                }
            }
            docIndex++;
        }
    }

    /**
     * Runs the network on one batch and maps its outputs to {@code {entity, aspect}}.
     *
     * <p>An element is {@code null} when the corresponding target encoder is not set. The
     * {@code setLabels}/{@code output} pair runs while synchronized on the graph.
     *
     * @param documentBatch the batch to evaluate
     * @return a two-element array: entity output, aspect output
     */
    private INDArray[] predictTargets(DocumentSentenceIterator.DocumentBatch documentBatch) {
        ComputationGraph graph = m18getNN();
        INDArray[] output;
        synchronized (graph) {
            graph.setLabels(documentBatch.dataset.getLabels());
            output = graph.output(false, documentBatch.dataset.getFeatures(),
                    documentBatch.dataset.getFeaturesMaskArrays(), documentBatch.dataset.getLabelsMaskArrays());
        }
        boolean hasEntity = getEntityEncoder() != null;
        boolean hasAspect = getAspectEncoder() != null;
        INDArray entityOutput = null;
        INDArray aspectOutput = null;
        if (hasEntity && hasAspect) {
            entityOutput = output[0];
            aspectOutput = output[1];
        } else if (hasEntity) {
            entityOutput = output[0];
        } else if (hasAspect) {
            aspectOutput = output[0];
        }
        return new INDArray[] {entityOutput, aspectOutput};
    }

    /**
     * Extracts example {@code docIndex} from a batch output, keeping the first
     * {@code numSentences} steps of the last axis. Each column of the result (one per step) is
     * then scaled to unit length in place.
     *
     * <p>NOTE(review): {@code get} with point/all/interval indices presumably returns a view, so
     * {@code batchOutput} is modified as well; confirm against the ND4J version in use.
     *
     * @param batchOutput batch output array
     * @param docIndex index of the example in the batch
     * @param numSentences number of steps to keep
     * @return the normalised matrix
     */
    private static INDArray normalizedDocumentMatrix(INDArray batchOutput, int docIndex, int numSentences) {
        INDArray matrix = batchOutput.get(NDArrayIndex.point(docIndex), NDArrayIndex.all(), NDArrayIndex.interval(0, numSentences));
        for (int col = 0; col < matrix.size(1); col++) {
            matrix.getColumn(col).assign(Transforms.unitVec(matrix.getColumn(col)));
        }
        return matrix;
    }

    /**
     * Stores, for each non-empty document, the time-step matrix produced by
     * {@link EncodingHelpers#encodeTimeStepMatrix} with {@link #entityEncoder}, truncated to the
     * document's sentences.
     *
     * @param documentBatch the batch to annotate
     * @deprecated baseline only; requires {@link #entityEncoder} to be set
     */
    @Deprecated
    protected void attachMatrixBaseline(DocumentSentenceIterator.DocumentBatch documentBatch) {
        INDArray batchMatrix = EncodingHelpers.encodeTimeStepMatrix(documentBatch.docs, this.entityEncoder, documentBatch.maxDocLength, Sentence.class);
        int docIndex = 0;
        for (Document document : documentBatch.docs) {
            int numSentences = document.countSentences();
            if (numSentences > 0) {
                document.putVector(getEntityEncoder().getClass(),
                        batchMatrix.get(NDArrayIndex.point(docIndex), NDArrayIndex.all(), NDArrayIndex.interval(0, numSentences)));
            }
            docIndex++;
        }
    }

    /**
     * Clears the per-layer and per-vertex state, and the mask arrays, of a computation graph.
     *
     * @param computationGraph graph to clear
     * @deprecated not called from this class
     */
    @Deprecated
    protected static void clearLayerStates(ComputationGraph computationGraph) {
        for (Layer layer : computationGraph.getLayers()) {
            layer.clear();
            layer.clearNoiseWeightParams();
        }
        for (GraphVertex vertex : computationGraph.getVertices()) {
            vertex.clearVertex();
        }
        computationGraph.clear();
        computationGraph.clearLayerMaskArrays();
    }

    /**
     * Attaches an in-memory {@link StatsListener} (reporting every iteration) to the network and
     * registers its storage with the DL4J UI server.
     */
    public void enableTrainingUI() {
        InMemoryStatsStorage statsStorage = new InMemoryStatsStorage();
        m18getNN().addListeners(new StatsListener(statsStorage, 1));
        UIServer.getInstance().attach(statsStorage);
        UIServer.getInstance().enableRemoteListener(statsStorage, true);
    }

    /**
     * Writes the network (without updater state) to {@code <resource>/<str>.zip} and records that
     * file as this tagger's model. I/O errors are logged, not thrown.
     *
     * @param resource target directory resource
     * @param str file name without extension
     */
    public void saveModel(Resource resource, String str) {
        Resource modelFile = resource.resolve(str + ".zip");
        // try-with-resources closes the stream even if writeModel fails.
        try (OutputStream out = modelFile.getOutputStream()) {
            ModelSerializer.writeModel(this.net, out, false);
            setModel(modelFile);
        } catch (IOException e) {
            log.error(e.toString(), e);
        }
    }

    /**
     * Restores the network (without updater state) from {@code resource} and marks the model as
     * available. I/O errors are logged, not thrown.
     *
     * @param resource model file written by {@link #saveModel}
     */
    public void loadModel(Resource resource) {
        try (InputStream in = resource.getInputStream()) {
            this.net = ModelSerializer.restoreComputationGraph(in, false);
            setModel(resource);
            setModelAvailable(true);
            log.info("loaded Computation Graph from {}", resource.getFileName());
        } catch (IOException e) {
            log.error(e.toString(), e);
        }
    }

    /** @return {@code embeddingLayerSize} inherited from {@link Tagger} */
    public long getEmbeddingVectorSize() {
        return this.embeddingLayerSize;
    }

    /** Not supported. @throws UnsupportedOperationException always */
    public INDArray encode(String str) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** Not supported. @throws UnsupportedOperationException always */
    public INDArray encode(Span span) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** Not supported. @throws UnsupportedOperationException always */
    public INDArray encode(Iterable<? extends Span> iterable) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** @return always {@code null}; no graph configuration is exposed */
    public ComputationGraphConfiguration getGraphConfiguration() {
        return null;
    }

    /** Intentionally ignores the given configuration. */
    public void setGraphConfiguration(JsonNode jsonNode) {
        // no-op
    }
}
