package de.datexis.sector;

import de.datexis.annotator.Annotator;
import de.datexis.annotator.AnnotatorComponent;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.sector.encoder.ClassEncoder;
import de.datexis.sector.encoder.ClassTag;
import de.datexis.sector.encoder.HeadingEncoder;
import de.datexis.sector.encoder.HeadingTag;
import de.datexis.sector.eval.SectorEvaluation;
import de.datexis.sector.model.SectionAnnotation;
import de.datexis.sector.tagger.ScoreImprovementMinEpochsTerminationCondition;
import de.datexis.sector.tagger.SectorEncoder;
import de.datexis.sector.tagger.SectorTagger;
import de.datexis.sector.tagger.SectorTaggerIterator;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.Tagger;
import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dimensionalityreduction.PCA;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/sector/SectorAnnotator.class */
public class SectorAnnotator extends Annotator {
    /** Logger shared by all instances of this annotator (also used by the static helpers). */
    protected static final Logger log = LoggerFactory.getLogger(SectorAnnotator.class);
    /**
     * Absolute path of the directory that holds exported training batches ({@code train_<n>.bin}).
     * Initialised to the empty string by every constructor; set by
     * {@code setPresavedDatasetDirectory} or {@code exportBatchesToFiles} and read by
     * {@code trainModelPresaved}.
     */
    protected String presavedDatasetDirectory;

    /* loaded from: input_file:de/datexis/sector/SectorAnnotator$Builder.class */
    /**
     * Fluent builder that configures a {@link SectorTagger} and wraps it in a {@link SectorAnnotator}.
     *
     * <p>The tagger and the annotator are created when the builder is instantiated. The
     * {@code with*} methods either configure them directly or buffer network/training parameters,
     * and {@link #build()} finally builds the network and returns the annotator.
     */
    public static class Builder {
        // NOTE(review): never reassigned in this class, so pretrain() iterates an empty array.
        protected Encoder[] encoders = new Encoder[0];
        protected ILossFunction lossFunc = LossFunctions.LossFunction.MCXENT.getILossFunction();
        protected Activation activation = Activation.SOFTMAX;
        protected boolean requireSubsampling = false;
        private int examplesPerEpoch = -1;
        private int maxTimeSeriesLength = -1;
        private int ffwLayerSize = 0;
        private int lstmLayerSize = 256;
        private int embeddingLayerSize = 128;
        private double learningRate = 0.01d;
        private double dropOut = 0.5d;
        private int iterations = 1;
        private int batchSize = 16;
        private int numEpochs = 1;
        private boolean enabletrainingUI = false;
        SectorTagger tagger = new SectorTagger();
        SectorAnnotator ann = new SectorAnnotator(this.tagger);

        /**
         * Sets the tagger id; {@link #build()} also records it as the provenance task.
         *
         * @param str tagger id
         * @return this builder
         */
        public Builder withId(String str) {
            this.tagger.setId(str);
            return this;
        }

        /**
         * Records the dataset name and its language in the annotator's provenance.
         *
         * @param str dataset name
         * @param language dataset language; stored lower-cased
         * @return this builder
         */
        public Builder withDataset(String str, WordHelpers.Language language) {
            this.ann.getProvenance().setDataset(str);
            // Locale.ROOT makes the lower-casing independent of the JVM default locale
            // (e.g. under tr_TR "IT".toLowerCase() would yield a dotless i).
            this.ann.getProvenance().setLanguage(language.toString().toLowerCase(Locale.ROOT));
            return this;
        }

        /**
         * Sets the output loss function and activation.
         *
         * @param lossFunction predefined loss function
         * @param activation output activation
         * @param z whether the tagger should subsample its 1-hot targets
         * @return this builder
         */
        public Builder withLossFunction(LossFunctions.LossFunction lossFunction, Activation activation, boolean z) {
            return withLossFunction(lossFunction.getILossFunction(), activation, z);
        }

        /**
         * Sets the output loss function and activation.
         *
         * @param iLossFunction loss function instance
         * @param activation output activation
         * @param z whether the tagger should subsample its 1-hot targets
         * @return this builder
         */
        public Builder withLossFunction(ILossFunction iLossFunction, Activation activation, boolean z) {
            this.lossFunc = iLossFunction;
            this.requireSubsampling = z;
            this.activation = activation;
            return this;
        }

        /**
         * Sets the network layer sizes.
         *
         * @param i feed-forward layer size
         * @param i2 BLSTM layer size
         * @param i3 embedding layer size
         * @return this builder
         */
        public Builder withModelParams(int i, int i2, int i3) {
            this.ffwLayerSize = i;
            this.lstmLayerSize = i2;
            this.embeddingLayerSize = i3;
            return this;
        }

        /**
         * Sets the training parameters (max time series length keeps its current value).
         *
         * @param d learning rate
         * @param d2 dropout
         * @param i examples per epoch
         * @param i2 batch size
         * @param i3 number of epochs
         * @return this builder
         */
        public Builder withTrainingParams(double d, double d2, int i, int i2, int i3) {
            this.learningRate = d;
            this.dropOut = d2;
            this.examplesPerEpoch = i;
            this.batchSize = i2;
            this.numEpochs = i3;
            return this;
        }

        /**
         * Sets the training parameters including the maximum time series length.
         *
         * @param d learning rate
         * @param d2 dropout
         * @param i examples per epoch
         * @param i2 maximum time series length
         * @param i3 batch size
         * @param i4 number of epochs
         * @return this builder
         */
        public Builder withTrainingParams(double d, double d2, int i, int i2, int i3, int i4) {
            this.learningRate = d;
            this.dropOut = d2;
            this.examplesPerEpoch = i;
            this.batchSize = i3;
            this.maxTimeSeriesLength = i2;
            this.numEpochs = i4;
            return this;
        }

        /**
         * Sets the three input encoders and registers them as annotator components.
         *
         * @param str feature description stored in the provenance
         * @return this builder
         */
        public Builder withInputEncoders(String str, Encoder encoder, Encoder encoder2, Encoder encoder3) {
            this.tagger.setInputEncoders(encoder, encoder2, encoder3);
            this.ann.getProvenance().setFeatures(str);
            this.ann.addComponent(encoder);
            this.ann.addComponent(encoder2);
            this.ann.addComponent(encoder3);
            return this;
        }

        /**
         * Sets the target encoder and registers it as an annotator component.
         *
         * @return this builder
         */
        public Builder withTargetEncoder(Encoder encoder) {
            this.tagger.setTargetEncoder(encoder);
            this.ann.addComponent(encoder);
            return this;
        }

        /**
         * Enables or disables the training UI, which is switched on during {@link #build()}.
         *
         * @return this builder
         */
        public Builder enableTrainingUI(boolean z) {
            this.enabletrainingUI = z;
            return this;
        }

        /**
         * Trains every encoder in {@link #encoders} on the documents of the dataset.
         *
         * <p>NOTE(review): {@code encoders} is initialised to an empty array and never reassigned in
         * this class, so this method currently trains nothing. TODO confirm whether the input/target
         * encoders were meant to be trained here.
         *
         * @return this builder
         */
        public Builder pretrain(Dataset dataset) {
            for (Encoder encoder : this.encoders) {
                encoder.trainModel(dataset.streamDocuments());
            }
            return this;
        }

        /**
         * Builds the SECTOR network with the buffered parameters, applies the training settings,
         * names the tagger after the provenance and records the parameters in the training log.
         *
         * @return the configured annotator
         */
        public SectorAnnotator build() {
            this.tagger.buildSECTORModel(this.ffwLayerSize, this.lstmLayerSize, this.embeddingLayerSize, this.iterations, this.learningRate, this.dropOut, this.lossFunc, this.activation);
            if (this.enabletrainingUI) {
                this.tagger.enableTrainingUI();
            }
            this.tagger.setRequireSubsampling(this.requireSubsampling);
            // NOTE(review): meaning of the trailing `true` is defined by SectorTagger; confirm there.
            this.tagger.setTrainingParams(this.examplesPerEpoch, this.maxTimeSeriesLength, this.batchSize, this.numEpochs, true);
            this.ann.getProvenance().setTask(this.tagger.getId());
            this.tagger.setName(this.ann.getProvenance().toString());
            this.tagger.appendTrainLog(printParams());
            return this.ann;
        }

        /** Renders encoders, network and training parameters as a tab-separated report. */
        private String printParams() {
            StringBuilder sb = new StringBuilder();
            sb.append("TRAINING PARAMS: ").append(this.tagger.getName()).append("\n");
            sb.append("\nEncoders:\n");
            for (Encoder encoder : this.tagger.getEncoders()) {
                sb.append(encoder.getId()).append("\t").append(encoder.getClass().getSimpleName()).append("\t").append(encoder.getEmbeddingVectorSize()).append("\n");
            }
            sb.append("\nNetwork Params:\n");
            sb.append("FF").append("\t").append(this.ffwLayerSize).append("\n");
            sb.append("BLSTM").append("\t").append(this.lstmLayerSize).append("\n");
            sb.append("EMB").append("\t").append(this.embeddingLayerSize).append("\n");
            sb.append("\nTraining Params:\n");
            sb.append("examples per epoch").append("\t").append(this.examplesPerEpoch).append("\n");
            sb.append("max time series length").append("\t").append(this.maxTimeSeriesLength).append("\n");
            sb.append("epochs").append("\t").append(this.numEpochs).append("\n");
            sb.append("iterations").append("\t").append(this.iterations).append("\n");
            sb.append("batch size").append("\t").append(this.batchSize).append("\n");
            sb.append("learning rate").append("\t").append(this.learningRate).append("\n");
            sb.append("dropout").append("\t").append(this.dropOut).append("\n");
            sb.append("loss").append("\t").append(this.lossFunc.toString()).append(this.requireSubsampling ? " (1-hot subsampled)" : " (1-hot/n-hot)").append("\n");
            sb.append("\n");
            return sb.toString();
        }
    }

    /* loaded from: input_file:de/datexis/sector/SectorAnnotator$SegmentationMethod.class */
    /**
     * Strategy used to derive predicted sections after encoding.
     *
     * <p>{@code NONE} makes {@code annotate} stop after encoding without segmenting. The other
     * constants are interpreted by {@code detectSections}. Judging by the helper names, they
     * presumably map to gold boundaries (GOLD), newline tokens (NL), nearest target labels (MAX)
     * and embedding-deviation edges (EMD, BEMD, BEMD_FIXED). TODO confirm against detectSections.
     */
    public enum SegmentationMethod {
        NONE, GOLD, NL, MAX, EMD, BEMD, BEMD_FIXED
    }

    /**
     * Points this annotator at a directory of previously exported training batches.
     *
     * @param resource directory resource; its absolute path is stored
     */
    public void setPresavedDatasetDirectory(Resource resource) {
        final String absolutePath = resource.getPath().toAbsolutePath().toString();
        presavedDatasetDirectory = absolutePath;
    }

    /** Creates an annotator without a tagger; no presaved dataset directory is set. */
    public SectorAnnotator() {
        super();
        presavedDatasetDirectory = "";
    }

    /**
     * Creates an annotator backed by the given tagger.
     *
     * @param tagger tagger to delegate to; expected to be a {@link SectorTagger}
     */
    public SectorAnnotator(Tagger tagger) {
        super(tagger);
        presavedDatasetDirectory = "";
    }

    /**
     * Creates an annotator around an arbitrary annotator component.
     *
     * @param annotatorComponent component passed to the {@link Annotator} constructor
     */
    protected SectorAnnotator(AnnotatorComponent annotatorComponent) {
        super(annotatorComponent);
        presavedDatasetDirectory = "";
    }

    /* renamed from: getTagger, reason: merged with bridge method [inline-methods] */
    /**
     * Returns the underlying tagger narrowed to {@link SectorTagger}.
     *
     * <p>NOTE(review): the decompiler renamed this method from {@code getTagger}. The cast throws
     * {@link ClassCastException} if the annotator was built around another tagger type.
     *
     * @return the SECTOR tagger
     */
    public SectorTagger m4getTagger() {
        final SectorTagger sectorTagger = (SectorTagger) super.getTagger();
        return sectorTagger;
    }

    /**
     * Returns the target encoder configured on the underlying SECTOR tagger.
     *
     * @return the tagger's target encoder
     */
    public LookupCacheEncoder getTargetEncoder() {
        final SectorTagger sectorTagger = m4getTagger();
        return sectorTagger.getTargetEncoder();
    }

    /**
     * Encodes and segments the documents using {@link SegmentationMethod#BEMD}.
     *
     * @param collection documents to annotate in place
     */
    public void annotate(Collection<Document> collection) {
        final SegmentationMethod defaultMethod = SegmentationMethod.BEMD;
        annotate(collection, defaultMethod);
    }

    /**
     * Runs the SECTOR network over the documents to attach sentence vectors. Unless the method is
     * {@link SegmentationMethod#NONE}, it then segments them via {@code segment}.
     *
     * @param collection documents to annotate in place
     * @param segmentationMethod segmentation strategy; must not be {@code null}
     */
    public void annotate(Collection<Document> collection, SegmentationMethod segmentationMethod) {
        log.info("Running SECTOR neural net encoding...");
        final SectorTagger sectorTagger = m4getTagger();
        sectorTagger.attachVectors(collection, AbstractMultiDataSetIterator.Stage.ENCODE, getTargetEncoder().getClass());
        if (!segmentationMethod.equals(SegmentationMethod.NONE)) {
            segment(collection, segmentationMethod, true);
        }
    }

    /**
     * Predicts section boundaries for the documents and attaches target-encoder vectors to their
     * GOLD and PRED section annotations.
     *
     * <p>NOTE(review): relies on {@code detectSections}, which in this decompiled version always
     * throws {@link UnsupportedOperationException}.
     *
     * @param collection documents to segment in place
     * @param segmentationMethod segmentation strategy passed to {@code detectSections}
     * @param z has no effect in this implementation (the decompiled body left only an empty
     *     branch); kept for API compatibility
     */
    public void segment(Collection<Document> collection, SegmentationMethod segmentationMethod, boolean z) {
        log.info("Predicting segmentation {}...", segmentationMethod.toString());
        detectSections(collection, segmentationMethod);
        log.info("Attaching Annotations...");
        final LookupCacheEncoder targetEncoder = getTargetEncoder();
        for (Document document : collection) {
            attachVectorsToAnnotations(document, targetEncoder);
        }
        log.info("Segmentation done.");
    }

    /* JADX WARN: Failed to find 'out' block for switch in B:10:0x0081. Please report as an issue. */
    /* JADX WARN: Removed duplicated region for block: B:14:0x0110 A[SYNTHETIC] */
    /* JADX WARN: Removed duplicated region for block: B:28:0x0057 A[SYNTHETIC] */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    /**
     * Creates predicted section annotations for every document using the given segmentation method.
     *
     * <p>NOTE(review): the original body (380 bytecode instructions) could not be decompiled, and
     * this placeholder always throws. As a result, {@code segment}, and {@code annotate} with any
     * method other than NONE, fail until the body is restored from the original sources. Judging
     * by the private helpers of this class, it presumably dispatches on {@code r6} to
     * applySectionsFromGold / applySectionsFromNewlines / applySectionsFromTargetLabels /
     * applySectionsFromEdges (with detectEdges on the deviation signals). TODO confirm.
     *
     * @param r5 documents to segment
     * @param r6 segmentation method
     * @throws UnsupportedOperationException always, in this decompiled version
     */
    protected void detectSections(java.util.Collection<de.datexis.model.Document> r5, de.datexis.sector.SectorAnnotator.SegmentationMethod r6) {
        /*
            Method dump skipped, instructions count: 380
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: de.datexis.sector.SectorAnnotator.detectSections(java.util.Collection, de.datexis.sector.SectorAnnotator$SegmentationMethod):void");
    }

    /**
     * Evaluates the model on a dataset with sentence-class, segment-class and segmentation
     * evaluation all enabled.
     *
     * @param dataset annotated dataset to evaluate
     * @return the overall evaluation score
     */
    public double evaluateModel(Dataset dataset) {
        final boolean enabled = true;
        return evaluateModel(dataset, enabled, enabled, enabled);
    }

    /**
     * Evaluates predicted against gold annotations and appends the statistics to the test log.
     *
     * <p>The evaluation target is chosen by the exact class of the target encoder (HeadingEncoder
     * or ClassEncoder). When {@code z} is set, PRED tags are removed first and GOLD/PRED tags are
     * (re)created.
     *
     * @param dataset annotated dataset to evaluate
     * @param z enable sentence-class evaluation (and tag recreation)
     * @param z2 enable segment-class evaluation
     * @param z3 enable segmentation evaluation
     * @return the overall evaluation score
     * @throws IllegalArgumentException if the target encoder is neither a HeadingEncoder nor a
     *     ClassEncoder
     */
    public double evaluateModel(Dataset dataset, boolean z, boolean z2, boolean z3) {
        final Class<?> targetType = getTargetEncoder().getClass();
        final SectorEvaluation evaluation;
        if (targetType == HeadingEncoder.class) {
            final HeadingEncoder encoder = (HeadingEncoder) getComponent(HeadingEncoder.ID);
            evaluation = new SectorEvaluation(dataset.getName(), Annotation.Source.GOLD, Annotation.Source.PRED, encoder);
            if (z) {
                log.info("Creating tags...");
                removeTags(dataset.getDocuments(), Annotation.Source.PRED);
                createHeadingTags(dataset.getDocuments(), Annotation.Source.GOLD, encoder);
                createHeadingTags(dataset.getDocuments(), Annotation.Source.PRED, encoder);
            }
        } else if (targetType == ClassEncoder.class) {
            final ClassEncoder encoder = (ClassEncoder) getComponent(ClassEncoder.ID);
            evaluation = new SectorEvaluation(dataset.getName(), Annotation.Source.GOLD, Annotation.Source.PRED, encoder);
            if (z) {
                log.info("Creating tags...");
                removeTags(dataset.getDocuments(), Annotation.Source.PRED);
                createClassTags(dataset.getDocuments(), Annotation.Source.GOLD, encoder);
                createClassTags(dataset.getDocuments(), Annotation.Source.PRED, encoder);
            }
        } else {
            throw new IllegalArgumentException("Target encoder has no evaluation: " + targetType.toString());
        }

        evaluation.withSentenceClassEvaluation(z)
                .withSegmentationEvaluation(z3)
                .withSegmentClassEvaluation(z2)
                .calculateScores(dataset);

        final SectorTagger sectorTagger = m4getTagger();
        sectorTagger.appendTestLog(SectorEvaluation.printDatasetStats(dataset));
        sectorTagger.appendTestLog(evaluation.printEvaluationStats());
        sectorTagger.appendTestLog(evaluation.printSingleClassStats());
        return evaluation.getScore();
    }

    /**
     * Exports the training batches of a dataset to {@code <dir>/train_<n>.bin} files.
     * The directory is remembered as the presaved dataset directory.
     *
     * @param resource target directory; must already exist
     * @param dataset dataset whose documents are batched
     * @param i forwarded to {@link SectorTaggerIterator}; presumably the batch size (TODO confirm)
     * @param i2 prefetch queue length of the {@link AsyncMultiDataSetIterator}; -1 selects 256
     * @throws IOException if a batch file cannot be created or written
     */
    public void exportBatchesToFiles(Resource resource, Dataset dataset, int i, int i2) throws IOException {
        if (i2 == -1) {
            i2 = 256;
        }
        final SectorTagger sectorTagger = m4getTagger();
        AsyncMultiDataSetIterator batches = new AsyncMultiDataSetIterator(new SectorTaggerIterator(AbstractMultiDataSetIterator.Stage.TRAIN, dataset.getDocuments(), sectorTagger, dataset.getDocuments().size(), sectorTagger.getMaxTimeSeriesLength(), i, true, false), i2);
        this.presavedDatasetDirectory = resource.getPath().toAbsolutePath().toString();
        int batchIndex = 0;
        while (batches.hasNext()) {
            // try-with-resources: the stream is closed even if save() fails mid-write
            try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(this.presavedDatasetDirectory + "/train_" + batchIndex + ".bin"))) {
                batches.next().save(out);
                batchIndex++;
                log.info("Exported Batch: {}", batchIndex);
            }
        }
    }

    /**
     * Trains the tagger from the batches stored in the presaved dataset directory.
     *
     * @param i forwarded unchanged to {@code SectorTagger.trainModelPresaved}
     */
    public void trainModelPresaved(int i) {
        final String directory = this.presavedDatasetDirectory;
        m4getTagger().trainModelPresaved(directory, i);
    }

    /**
     * Records the dataset in the provenance and trains the tagger on it.
     *
     * @param dataset training dataset
     */
    public void trainModel(Dataset dataset) {
        recordTrainingProvenance(dataset);
        m4getTagger().trainModel(dataset);
    }

    /**
     * Records the dataset in the provenance and trains the tagger on it.
     *
     * @param dataset training dataset
     * @param i forwarded unchanged to {@code SectorTagger.trainModel(Dataset, int)}
     */
    public void trainModel(Dataset dataset, int i) {
        recordTrainingProvenance(dataset);
        m4getTagger().trainModel(dataset, i);
    }

    /** Stores the dataset's name and language in this annotator's provenance. */
    private void recordTrainingProvenance(Dataset dataset) {
        provenance.setDataset(dataset.getName());
        provenance.setLanguage(dataset.getLanguage());
    }

    /**
     * Trains with early stopping, evaluating after every epoch, and logs the training result.
     *
     * @param dataset training dataset
     * @param dataset2 validation dataset
     * @param i first argument of {@link ScoreImprovementMinEpochsTerminationCondition}
     * @param i2 second argument of {@link ScoreImprovementMinEpochsTerminationCondition}
     * @param i3 third argument of {@link ScoreImprovementMinEpochsTerminationCondition}
     */
    public void trainModelEarlyStopping(Dataset dataset, Dataset dataset2, int i, int i2, int i3) {
        EpochTerminationCondition[] terminationConditions = {
            new ScoreImprovementMinEpochsTerminationCondition(i, i2, i3)
        };
        EarlyStoppingConfiguration config = new EarlyStoppingConfiguration.Builder()
                .evaluateEveryNEpochs(1)
                .epochTerminationConditions(terminationConditions)
                .saveLastModel(false)
                .build();
        final SectorTagger sectorTagger = m4getTagger();
        final String result = sectorTagger.trainModel(dataset, dataset2, config).toString();
        sectorTagger.appendTrainLog("Training complete " + result);
    }

    /**
     * Attaches heading tags for {@code source} to every document that has none yet. GOLD tags are
     * derived from section annotations; PRED tags come from HeadingEncoder sentence vectors.
     */
    private void createHeadingTags(Iterable<Document> iterable, Annotation.Source source, HeadingEncoder headingEncoder) {
        final HeadingTag.Factory tagFactory = new HeadingTag.Factory(headingEncoder);
        for (Document doc : iterable) {
            if (doc.isTagAvaliable(source, HeadingTag.class)) {
                continue; // already tagged for this source
            }
            if (source.equals(Annotation.Source.GOLD)) {
                tagFactory.attachFromSectionAnnotations(doc, source);
            } else if (source.equals(Annotation.Source.PRED)) {
                tagFactory.attachFromSentenceVectors(doc, HeadingEncoder.class, source);
            }
        }
    }

    /**
     * Attaches class tags for {@code source} to every document that has none yet. GOLD tags are
     * derived from section annotations; PRED tags come from ClassEncoder sentence vectors.
     */
    private void createClassTags(Iterable<Document> iterable, Annotation.Source source, ClassEncoder classEncoder) {
        final ClassTag.Factory tagFactory = new ClassTag.Factory(classEncoder);
        for (Document doc : iterable) {
            if (doc.isTagAvaliable(source, ClassTag.class)) {
                continue; // already tagged for this source
            }
            if (source.equals(Annotation.Source.GOLD)) {
                tagFactory.attachFromSectionAnnotations(doc, source);
            } else if (source.equals(Annotation.Source.PRED)) {
                tagFactory.attachFromSentenceVectors(doc, ClassEncoder.class, source);
            }
        }
    }

    /**
     * Clears all sentence tags of {@code source}. Also marks heading and class tags as unavailable
     * for that source on every document.
     */
    private static void removeTags(Iterable<Document> iterable, Annotation.Source source) {
        for (Document doc : iterable) {
            for (Sentence sentence : doc.getSentences()) {
                sentence.clearTags(source);
            }
            doc.setTagAvailable(source, HeadingTag.class, false);
            doc.setTagAvailable(source, ClassTag.class, false);
        }
    }

    /**
     * Attaches target-encoder vectors to the section annotations of a document.
     *
     * <p>GOLD sections receive the encoding of their own label (ClassEncoder) or heading
     * (HeadingEncoder). PRED sections receive the mean of the vectors of the sentences they cover.
     * Their label (or top-2 headings joined by "/") and confidence are derived from that mean.
     * Dispatch uses the encoder's exact runtime class; any other class is ignored.
     *
     * @param document document whose section annotations are updated in place
     * @param lookupCacheEncoder target encoder
     */
    protected static void attachVectorsToAnnotations(Document document, LookupCacheEncoder lookupCacheEncoder) {
        final Class<? extends LookupCacheEncoder> encoderType = lookupCacheEncoder.getClass();
        final boolean classTarget = encoderType == ClassEncoder.class;
        final boolean headingTarget = encoderType == HeadingEncoder.class;

        for (SectionAnnotation gold : document.getAnnotations(Annotation.Source.GOLD, SectionAnnotation.class)) {
            if (classTarget) {
                gold.putVector(ClassEncoder.class, lookupCacheEncoder.encode(gold.getSectionLabel()));
            } else if (headingTarget) {
                gold.putVector(HeadingEncoder.class, lookupCacheEncoder.encode(gold.getSectionHeading()));
            }
        }

        for (SectionAnnotation predicted : document.getAnnotations(Annotation.Source.PRED, SectionAnnotation.class)) {
            final INDArray mean = Nd4j.zeros(lookupCacheEncoder.getEmbeddingVectorSize(), 1L);
            int sentenceCount = 0;
            final List<?> coveredSentences = (List<?>) document.streamSentencesInRange(predicted.getBegin(), predicted.getEnd(), false).collect(Collectors.toList());
            for (Object element : coveredSentences) {
                final Sentence sentence = (Sentence) element;
                mean.addi(sentence.getVector(encoderType));
                sentenceCount++;
            }
            if (sentenceCount > 1) {
                mean.divi(Integer.valueOf(sentenceCount));
            }
            if (classTarget) {
                predicted.putVector(ClassEncoder.class, mean);
                predicted.setSectionLabel(lookupCacheEncoder.getNearestNeighbour(mean));
                predicted.setConfidence(lookupCacheEncoder.getMaxConfidence(mean));
            } else if (headingTarget) {
                predicted.putVector(HeadingEncoder.class, mean);
                predicted.setSectionHeading(StringUtils.join(lookupCacheEncoder.getNearestNeighbours(mean, 2), "/"));
                predicted.setConfidence(lookupCacheEncoder.getMaxConfidence(mean));
            }
        }
    }

    /**
     * Adds one PRED section annotation per GOLD section annotation. Only the begin and end offsets
     * are copied.
     *
     * <p>The new annotations are collected first and added afterwards. This avoids modifying the
     * document's annotations while iterating them, in case {@code getAnnotations} returns a live
     * view.
     *
     * @param document document to annotate in place
     */
    private static void applySectionsFromGold(Document document) {
        final List<SectionAnnotation> predicted = new ArrayList<>();
        for (SectionAnnotation gold : document.getAnnotations(Annotation.Source.GOLD, SectionAnnotation.class)) {
            SectionAnnotation copy = new SectionAnnotation(Annotation.Source.PRED);
            copy.setBegin(gold.getBegin());
            copy.setEnd(gold.getEnd());
            predicted.add(copy);
        }
        for (SectionAnnotation annotation : predicted) {
            document.addAnnotation(annotation);
        }
    }

    /**
     * Creates PRED sections that end at every sentence containing a newline token
     * ({@code "*NL*"} or {@code "\n"}). Trailing sentences without such a token form a final
     * section that ends at the document end, and a warning is logged.
     *
     * @param document document to annotate in place
     */
    private static void applySectionsFromNewlines(Document document) {
        SectionAnnotation open = null;
        for (Sentence sentence : document.getSentences()) {
            final boolean containsNewline = sentence.streamTokens()
                    .anyMatch(token -> token.getText().equals("*NL*") || token.getText().equals("\n"));
            if (open == null) {
                open = new SectionAnnotation(Annotation.Source.PRED);
                open.setBegin(sentence.getBegin());
            }
            if (!containsNewline) {
                continue;
            }
            open.setEnd(sentence.getEnd());
            document.addAnnotation(open);
            open = null;
        }
        if (open == null) {
            return;
        }
        log.warn("found last sentence without newline");
        open.setEnd(document.getEnd());
        document.addAnnotation(open);
    }

    /**
     * Segments a document by its predicted target labels.
     *
     * <p>A sentence continues the current section while the section's running label (the nearest
     * neighbour of the mean vector so far) is among the sentence's top-{@code i} nearest labels.
     * Otherwise a new section starts at that sentence.
     *
     * @param document document to annotate in place
     * @param lookupCacheEncoder encoder whose sentence vectors and label lookup are used
     * @param i number of nearest labels a sentence is compared against
     */
    private static void applySectionsFromTargetLabels(Document document, LookupCacheEncoder lookupCacheEncoder, int i) {
        String currentLabel = "";
        INDArray sum = Nd4j.create(new long[]{1, lookupCacheEncoder.getEmbeddingVectorSize()}).transposei();
        int count = 0;
        Annotation section = new SectionAnnotation(Annotation.Source.PRED);
        section.setBegin(document.getBegin());
        for (Sentence sentence : document.getSentences()) {
            final INDArray vec = sentence.getVector(lookupCacheEncoder.getClass());
            final boolean continuesSection = lookupCacheEncoder.getNearestNeighbours(vec, i).contains(currentLabel);
            if (!continuesSection) {
                // close the running section (there is none before the first label was assigned)
                if (!currentLabel.isEmpty()) {
                    document.addAnnotation(section);
                }
                section = new SectionAnnotation(Annotation.Source.PRED);
                section.setBegin(sentence.getBegin());
                count = 0;
                sum = Nd4j.create(new long[]{1, lookupCacheEncoder.getEmbeddingVectorSize()}).transposei();
            }
            sum.addi(vec);
            count++;
            currentLabel = lookupCacheEncoder.getNearestNeighbour(sum.div(Integer.valueOf(count)));
            section.setEnd(sentence.getEnd());
        }
        if (!currentLabel.isEmpty()) {
            document.addAnnotation(section);
        }
    }

    /**
     * Creates PRED sections from an edge vector.
     *
     * <p>Each sentence whose entry in {@code iNDArray} is positive starts a new section. Without
     * edges, or with a single sentence, the whole document becomes one section. Empty documents
     * only log a warning.
     *
     * @param document document to annotate in place
     * @param iNDArray one value per sentence (positive = section start), or {@code null}
     */
    private static void applySectionsFromEdges(Document document, INDArray iNDArray) {
        if (document.countSentences() < 1) {
            log.warn("Empty document");
            return;
        }
        if (iNDArray == null || document.countSentences() < 2) {
            SectionAnnotation whole = new SectionAnnotation(Annotation.Source.PRED);
            whole.setBegin(document.getBegin());
            whole.setEnd(document.getEnd());
            document.addAnnotation(whole);
            return;
        }
        Annotation current = new SectionAnnotation(Annotation.Source.PRED);
        current.setBegin(document.getBegin());
        int sentencesInCurrent = 0;
        int index = 0;
        for (Sentence sentence : document.getSentences()) {
            final boolean startsSection = iNDArray.getDouble(index++) > 0.0d;
            if (startsSection) {
                if (sentencesInCurrent > 0) {
                    document.addAnnotation(current);
                }
                current = new SectionAnnotation(Annotation.Source.PRED);
                current.setBegin(sentence.getBegin());
                sentencesInCurrent = 0;
            }
            sentencesInCurrent++;
            current.setEnd(sentence.getEnd());
        }
        if (sentencesInCurrent > 0) {
            document.addAnnotation(current);
        }
    }

    /**
     * Computes a per-sentence deviation signal from the SectorEncoder embeddings. The embeddings
     * are projected onto 16 PCA components and Gaussian-smoothed before the deviation is taken.
     *
     * @return deviation column, or {@code null} for documents with fewer than two sentences
     */
    private static INDArray detectSectionsFromEmbeddingDeviation(Document document) {
        if (document.countSentences() >= 2) {
            final INDArray embeddings = getEmbeddingMatrix(document);
            final INDArray reduced = pca(embeddings, 16);
            final INDArray smoothed = gaussianSmooth(reduced);
            return deviation(smoothed);
        }
        return null;
    }

    /**
     * Computes a per-sentence deviation signal from the forward ("embeddingFW") and backward
     * ("embeddingBW") sentence vectors.
     *
     * <p>Each direction is projected onto 16 PCA components (normalize flag false). The first two
     * components are zeroed, and the result is Gaussian-smoothed (sigma 1.5) before the combined
     * deviation is taken.
     *
     * @return deviation column, or {@code null} for documents without sentences
     */
    private static INDArray detectSectionsFromBidirectionalEmbeddingDeviation(Document document) {
        if (document.countSentences() < 1) {
            return null;
        }
        final long width = document.getSentence(0).getVector("embeddingFW").length();
        final INDArray forward = Nd4j.zeros(document.countSentences(), width);
        final INDArray backward = Nd4j.zeros(document.countSentences(), width);
        int row = 0;
        for (Sentence sentence : document.getSentences()) {
            forward.getRow(row).assign(sentence.getVector("embeddingFW"));
            backward.getRow(row).assign(sentence.getVector("embeddingBW"));
            row++;
        }
        final INDArray forwardPca = forward.mmul(PCA.pca_factor(forward.dup(), 16, false));
        final INDArray backwardPca = backward.mmul(PCA.pca_factor(backward.dup(), 16, false));
        final INDArray blank = Nd4j.zeros(forward.rows(), 1L);
        for (int column = 0; column < 2; column++) {
            forwardPca.putColumn(column, blank);
            backwardPca.putColumn(column, blank);
        }
        return deviation(gaussianSmooth(forwardPca, 1.5d), gaussianSmooth(backwardPca, 1.5d));
    }

    /**
     * Stacks the vectors stored under {@code str} for every sentence into a matrix with one row
     * per sentence. The width is taken from the first sentence, so the document must not be empty.
     */
    protected static INDArray getLayerMatrix(Document document, String str) {
        final INDArray matrix = Nd4j.zeros(document.countSentences(), document.getSentence(0).getVector(str).length());
        int row = 0;
        for (Sentence sentence : document.getSentences()) {
            matrix.getRow(row).assign(sentence.getVector(str));
            row++;
        }
        return matrix;
    }

    /** Same as {@link #getLayerMatrix(Document, String)}, keyed by the class's canonical name. */
    protected static INDArray getLayerMatrix(Document document, Class cls) {
        final String layerName = cls.getCanonicalName();
        return getLayerMatrix(document, layerName);
    }

    /** Returns the per-sentence SectorEncoder embeddings as a matrix (one row per sentence). */
    protected static INDArray getEmbeddingMatrix(Document document) {
        return getLayerMatrix(document, SectorEncoder.class.getCanonicalName());
    }

    /**
     * Projects the rows of a matrix onto its first {@code i} principal components (normalize
     * flag true when computing the factor).
     */
    protected static INDArray pca(INDArray iNDArray, int i) {
        final INDArray factor = PCA.pca_factor(iNDArray.dup(), i, true);
        return iNDArray.mmul(factor);
    }

    /** Gaussian-smooths the rows of a matrix with standard deviation 2.5. */
    protected static INDArray gaussianSmooth(INDArray iNDArray) {
        final double defaultSigma = 2.5d;
        return gaussianSmooth(iNDArray, defaultSigma);
    }

    /**
     * Replaces every row {@code r} by the sum of all rows weighted with a normal density centred
     * at {@code r} with standard deviation {@code d}. The weights are not renormalised.
     */
    protected static INDArray gaussianSmooth(INDArray iNDArray, double d) {
        final INDArray source = iNDArray.dup('c');
        final INDArray weights = Nd4j.zeros(source.rows(), 1, 'c');
        final INDArray smoothed = Nd4j.zerosLike(iNDArray);
        final long n = weights.length();
        for (int center = 0; center < n; center++) {
            final NormalDistribution kernel = new NormalDistribution(center, d);
            for (int pos = 0; pos < n; pos++) {
                weights.putScalar(pos, kernel.density(pos));
            }
            smoothed.getRow(center).assign(source.mulColumnVector(weights).sum(new int[]{0}));
        }
        return smoothed;
    }

    /**
     * Combines forward and backward deviation into one column (one value per row).
     *
     * <p>At row {@code t} the value is sqrt(cosDist(fw[t], fw[t+1]) * cosDist(bw[t-1], bw[t-2])).
     * A missing term counts as 0, NaN results are mapped to 0, and row 0 stays 0.
     *
     * @param iNDArray forward matrix (one row per sentence)
     * @param iNDArray2 backward matrix with the same number of rows
     * @return deviation column
     */
    protected static INDArray deviation(INDArray iNDArray, INDArray iNDArray2) {
        final INDArray result = Nd4j.zeros(iNDArray.rows(), 1L);
        final long rows = result.rows();
        for (int t = 1; t < rows; t++) {
            final double fw = t < rows - 1 ? Transforms.cosineDistance(iNDArray.getRow(t), iNDArray.getRow(t + 1)) : 0.0d;
            // NOTE(review): t == 2 is excluded although rows t-1 and t-2 both exist there;
            // confirm whether `t >= 2` was intended before changing the segmentation output.
            final double bw = t > 2 ? Transforms.cosineDistance(iNDArray2.getRow(t - 1), iNDArray2.getRow(t - 2)) : 0.0d;
            final double combined = Math.sqrt(fw * bw);
            result.putScalar(t, 0L, Double.isNaN(combined) ? 0.0d : combined);
        }
        return result;
    }

    /**
     * Cosine distance between each row and its predecessor; row 0 stays 0.
     *
     * <p>NaN distances (e.g. from all-zero rows) are mapped to 0, consistent with the two-input
     * overload. Otherwise detectEdges would treat a NaN as a local maximum, because both of its
     * {@code >=} comparisons are false, and report a spurious boundary.
     *
     * @param iNDArray matrix with one row per sentence
     * @return deviation column
     */
    protected static INDArray deviation(INDArray iNDArray) {
        final INDArray result = Nd4j.zeros(iNDArray.rows(), 1L);
        for (int t = 1; t < result.rows(); t++) {
            final double distance = Transforms.cosineDistance(iNDArray.getRow(t), iNDArray.getRow(t - 1));
            result.putScalar(t, 0L, Double.isNaN(distance) ? 0.0d : distance);
        }
        return result;
    }

    /**
     * Marks every strict local maximum of a signal with 1. The first and last rows are never
     * maxima, but row 0 is always marked as a section start.
     *
     * @param iNDArray signal with one value per sentence, or {@code null}
     * @return 0/1 column, or {@code null} if the input is {@code null}
     */
    protected static INDArray detectEdges(INDArray iNDArray) {
        if (iNDArray == null) {
            return null;
        }
        final INDArray edges = Nd4j.zeros(iNDArray.rows(), 1L);
        for (int t = 1; t < edges.rows() - 1; t++) {
            final double prev = iNDArray.getDouble((long) (t - 1));
            final double cur = iNDArray.getDouble((long) t);
            final double next = iNDArray.getDouble((long) (t + 1));
            // written as !(a || b) rather than (prev < cur && next < cur) to keep NaN handling identical
            final boolean isPeak = !(prev >= cur || next >= cur);
            edges.putScalar(t, 0L, isPeak ? 1.0d : 0.0d);
        }
        edges.putScalar(0L, 0L, 1.0d);
        return edges;
    }

    /**
     * Selects a fixed number of section starts from a signal.
     *
     * <p>Strict local maxima are taken first, strongest first. If there are fewer than
     * {@code i - 1} of them, the remaining slots are filled with the rows holding the largest raw
     * values. Row 0 is always marked in addition to at most {@code i - 1} other rows.
     *
     * @param iNDArray signal with one value per sentence, or {@code null}
     * @param i requested number of sections (section starts including row 0)
     * @return 0/1 column marking section starts, or {@code null} if the input is {@code null}
     */
    protected static INDArray detectEdges(INDArray iNDArray, int i) {
        if (iNDArray == null) {
            return null;
        }
        // value of every strict local maximum, 0 elsewhere (row 0 and the last row stay 0)
        INDArray peaks = Nd4j.zeros(iNDArray.rows(), 1L);
        for (int t = 1; t < peaks.rows() - 1; t++) {
            if (iNDArray.getDouble(t - 1) >= iNDArray.getDouble(t) || iNDArray.getDouble(t + 1) >= iNDArray.getDouble(t)) {
                peaks.putScalar(t, 0L, 0.0d);
            } else {
                peaks.putScalar(t, 0L, iNDArray.getDouble(t));
            }
        }
        INDArray edges = Nd4j.zeros(iNDArray.rows(), 1L);
        // element [0] of sortWithIndices is the index array: row numbers by descending value
        INDArray peakOrder = Nd4j.sortWithIndices(Nd4j.toFlattened(new INDArray[]{peaks}).dup(), 1, false)[0];
        INDArray valueOrder = Nd4j.sortWithIndices(Nd4j.toFlattened(new INDArray[]{iNDArray}).dup(), 1, false)[0];

        // First pass: strongest local maxima. Index 0 is skipped without stopping the loop, so the
        // bound must also respect the array length; otherwise a single-row input with i >= 3 reads
        // past the end of peakOrder.
        long firstPassLimit = Math.min((long) i - 1, peakOrder.length());
        for (int rank = 0; rank < firstPassLimit; rank++) {
            int row = peakOrder.getInt(new int[]{rank});
            if (row != 0) {
                if (peaks.getDouble(row) == 0.0d) {
                    break; // no further local maxima
                }
                edges.putScalar(row, 0L, 1.0d);
            }
        }

        // Second pass: top up with the largest raw values until i - 1 rows are marked.
        int cursor = 0;
        while (cursor < iNDArray.rows() && edges.sumNumber().intValue() < i - 1) {
            int row = valueOrder.getInt(new int[]{cursor++});
            if (row != 0 && edges.getDouble(row) != 1.0d) {
                edges.putScalar(row, 0L, 1.0d);
            }
        }
        edges.putScalar(0L, 0L, 1.0d);
        return edges;
    }

    /**
     * Cosine distance between each row and the previous row. Row 0 is compared against an all-zero
     * vector and then overwritten with 1.
     *
     * @param iNDArray matrix with one row per sentence
     * @return delta column
     */
    protected static INDArray deltaMatrix(INDArray iNDArray) {
        final INDArray deltas = Nd4j.zeros(iNDArray.rows(), 1L);
        INDArray previous = Nd4j.zeros(iNDArray.columns());
        for (int r = 0; r < iNDArray.rows(); r++) {
            final INDArray current = iNDArray.getRow(r);
            deltas.putScalar(r, 0L, Transforms.cosineDistance(previous, current));
            previous = current.dup();
        }
        deltas.putScalar(0L, 0L, 1.0d);
        return deltas;
    }
}
