package de.datexis.sector.tagger;

import de.datexis.common.Resource;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.DocumentSentenceIterator;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/sector/tagger/SectorEncoder.class */
/**
 * {@link Encoder} adapter around a {@link SectorTagger}.
 *
 * <p>Encoding is only supported over whole {@link Document}s with {@link Sentence} time steps;
 * the span-, string- and single-sentence encode methods throw {@link IllegalArgumentException}.
 * Model loading/saving, naming and most hyper-parameters are delegated to the wrapped tagger.
 */
public class SectorEncoder extends Encoder {

    /** Tagger that performs the actual encoding; model state and settings live here. */
    protected SectorTagger tagger;

    /** Creates an encoder with id {@code "SECTOR"} backed by a new {@link SectorTagger}. */
    public SectorEncoder() {
        this("SECTOR", new SectorTagger());
    }

    /**
     * Creates an encoder with the given id backed by a new {@link SectorTagger}.
     *
     * @param id identifier passed to the {@link Encoder} super constructor
     */
    public SectorEncoder(String id) {
        this(id, new SectorTagger());
    }

    /**
     * Creates an encoder wrapping an existing tagger.
     *
     * <p>The model filename is taken from {@code sectorTagger.getModel()} and the model is
     * marked available immediately.
     *
     * @param id identifier passed to the {@link Encoder} super constructor
     * @param sectorTagger tagger to delegate to; must not be {@code null}
     */
    public SectorEncoder(String id, SectorTagger sectorTagger) {
        super(id);
        this.log = LoggerFactory.getLogger(SectorEncoder.class);
        this.tagger = sectorTagger;
        setModelFilename(this.tagger.getModel());
        setModelAvailable(true);
    }

    /** Returns the wrapped tagger (excluded from JSON serialization). */
    @JsonIgnore
    public SectorTagger getTagger() {
        return this.tagger;
    }

    /** Replaces the wrapped tagger. */
    public void setTagger(SectorTagger sectorTagger) {
        this.tagger = sectorTagger;
    }

    /** Returns the tagger's embedding layer size, widened to {@code long}. */
    public long getEmbeddingVectorSize() {
        return this.tagger.getEmbeddingLayerSize();
    }

    /** Not supported: SECTOR only encodes over Documents. */
    public INDArray encode(Span span) {
        throw new IllegalArgumentException("SECTOR is only implemented to encode over Documents.");
    }

    /** Not supported: SECTOR only encodes over Documents. */
    public INDArray encode(String str) {
        throw new IllegalArgumentException("SECTOR is only implemented to encode over Documents.");
    }

    /**
     * Encodes a single document; equivalent to {@link #encodeEach(Collection, Class)} with a
     * singleton collection.
     */
    public void encodeEach(Document document, Class<? extends Span> timeStepClass) {
        encodeEach(Collections.singleton(document), timeStepClass);
    }

    /**
     * Runs {@link SectorTagger#tag} over the given documents.
     *
     * @param docs documents to tag
     * @param timeStepClass must be {@link Sentence}
     * @throws IllegalArgumentException if {@code timeStepClass} is not {@link Sentence}
     */
    public void encodeEach(Collection<Document> docs, Class<? extends Span> timeStepClass) {
        if (timeStepClass != Sentence.class) {
            throw new IllegalArgumentException("SECTOR is only implemented to encode Sentences over a Document");
        }
        this.tagger.tag(docs);
    }

    /**
     * Encodes the documents batch-wise and returns the stacked SECTOR embeddings.
     *
     * <p>As a side effect, each of the first {@code maxTimeSteps} sentences of every document
     * receives its embedding time step under {@code SectorEncoder.class} and, when the tagger
     * returned a {@code "target"} array, its target time step under the class of the tagger's
     * target encoder.
     *
     * @param docs documents to encode
     * @param maxTimeSteps maximum number of sentences per document to annotate; embeddings of
     *     batches whose {@code maxDocLength} is smaller are zero-padded along axis 2 up to this length
     * @param timeStepClass must be {@link Sentence}
     * @return the per-batch {@code "embedding"} arrays concatenated along axis 0, or {@code null}
     *     if no batch produced an embedding
     * @throws IllegalArgumentException if {@code timeStepClass} is not {@link Sentence}
     */
    public INDArray encodeMatrix(List<Document> docs, int maxTimeSteps, Class<? extends Span> timeStepClass) {
        if (timeStepClass != Sentence.class) {
            throw new IllegalArgumentException("SECTOR is only implemented to encode Sentences over a Document");
        }
        // The explicit cast is kept on purpose: it may select a specific constructor overload.
        SectorTaggerIterator it = new SectorTaggerIterator(
                AbstractMultiDataSetIterator.Stage.ENCODE, (Collection<Document>) docs, this.tagger,
                this.tagger.getBatchSize(), false, this.tagger.requireSubsampling);
        INDArray result = null;
        while (it.hasNext()) {
            DocumentSentenceIterator.DocumentBatch batch = it.nextDocumentBatch();
            Map<String, INDArray> weights = this.tagger.encodeMatrix(batch);
            INDArray target = weights.get("target");
            INDArray embedding = weights.get("embedding");
            attachSentenceVectors(batch, target, embedding, maxTimeSteps);
            if (embedding == null) {
                // Nothing to stack for this batch; appending/concatenating null would throw an NPE.
                continue;
            }
            if (maxTimeSteps > batch.maxDocLength) {
                // Zero-pad the time axis so every batch contributes the same axis-2 length.
                embedding = Nd4j.append(embedding, maxTimeSteps - batch.maxDocLength, 0.0d, 2);
            }
            // NOTE(review): when batch.maxDocLength > maxTimeSteps the embedding is not truncated,
            // so concatenation with a padded batch may fail on mismatched axis-2 sizes.
            // TODO confirm whether SectorTaggerIterator caps document length.
            result = (result == null) ? embedding : Nd4j.concat(0, new INDArray[]{result, embedding});
        }
        return result;
    }

    /**
     * Stores per-sentence time-step vectors on the documents of {@code batch}.
     *
     * <p>The batch row is the document's position in {@code batch.docs}, and the time step is the
     * sentence's position within the document. At most {@code maxTimeSteps} sentences per
     * document are visited. A {@code null} array is skipped.
     */
    private void attachSentenceVectors(DocumentSentenceIterator.DocumentBatch batch, INDArray target,
            INDArray embedding, int maxTimeSteps) {
        int row = 0;
        for (Object item : batch.docs) {
            Document doc = (Document) item;
            int t = 0;
            for (Sentence sentence : doc.getSentences()) {
                if (t >= maxTimeSteps) {
                    break;
                }
                if (target != null) {
                    sentence.putVector(this.tagger.getTargetEncoder().getClass(), EncodingHelpers.getTimeStep(target, row, t));
                }
                if (embedding != null) {
                    sentence.putVector(SectorEncoder.class, EncodingHelpers.getTimeStep(embedding, row, t));
                }
                t++;
            }
            row++;
        }
    }

    /** Not supported: SECTOR only encodes over Documents. */
    public void encodeEach(Sentence sentence, Class<? extends Span> timeStepClass) {
        throw new IllegalArgumentException("SECTOR is only implemented to encode over Documents.");
    }

    /** Not supported here; train the {@link SectorTagger} instead. */
    public void trainModel(Collection<Document> docs) {
        throw new UnsupportedOperationException("You need to train SectorTagger.");
    }

    /**
     * Loads the tagger model from {@code resource}, then marks the model available and records
     * {@code resource} as this encoder's model.
     */
    public void loadModel(Resource resource) {
        this.tagger.loadModel(resource);
        setModelAvailable(true);
        setModel(resource);
    }

    /**
     * Saves the tagger model via {@link SectorTagger#saveModel} and refreshes this encoder's model
     * filename from {@code tagger.getModel()}.
     */
    public void saveModel(Resource resource, String name) {
        this.tagger.saveModel(resource, name);
        setModelFilename(this.tagger.getModel());
    }

    /** Returns the wrapped tagger's name. */
    public String getName() {
        return this.tagger.getName();
    }

    /** Sets the wrapped tagger's name. */
    public void setName(String name) {
        this.tagger.setName(name);
    }

    /** Returns the tagger's batch size. */
    public int getBatchSize() {
        return this.tagger.getBatchSize();
    }

    /** Sets the tagger's batch size. */
    public void setBatchSize(int batchSize) {
        this.tagger.setBatchSize(batchSize);
    }

    /** Returns the tagger's embedding layer size. */
    public int getEmbeddingLayerSize() {
        return this.tagger.getEmbeddingLayerSize();
    }

    /** Sets the tagger's embedding layer size. */
    public void setEmbeddingLayerSize(int size) {
        this.tagger.setEmbeddingLayerSize(size);
    }

    /**
     * Sets the tagger's {@code requireSubsampling} flag.
     *
     * <p>NOTE(review): exposed as "multiClass" but maps to {@code requireSubsampling}; presumably
     * kept for configuration/serialization compatibility. Confirm before renaming.
     */
    public void setMultiClass(boolean multiClass) {
        this.tagger.setRequireSubsampling(multiClass);
    }

    /** Returns the tagger's {@code requireSubsampling} flag (see {@link #setMultiClass}). */
    public boolean isMultiClass() {
        return this.tagger.isRequireSubsampling();
    }

    /** Sets the tagger's number of training epochs. */
    public void setNumEpochs(int numEpochs) {
        this.tagger.setNumEpochs(numEpochs);
    }

    /** Returns the tagger's number of training epochs. */
    public int getNumEpochs() {
        return this.tagger.getNumEpochs();
    }

    /** Sets the tagger's randomize flag. */
    public void setRandomize(boolean randomize) {
        this.tagger.setRandomize(randomize);
    }

    /** Returns the tagger's randomize flag. */
    public boolean isRandomize() {
        return this.tagger.isRandomize();
    }
}
