package de.datexis.encoder.impl;

import de.datexis.encoder.StaticEncoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import java.util.ArrayList;
import java.util.Iterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.LoggerFactory;

/**
 * Static encoder that marks each {@link Token} with its position inside its {@link Sentence}
 * and {@link Document}.
 *
 * <p>When used through {@link #encodeEach(Document, Class)}, every token receives a 4x1 column
 * vector of 0/1 flags, in this row order:
 * <ol>
 *   <li>first token of the document</li>
 *   <li>first token of its sentence</li>
 *   <li>last token of its sentence</li>
 *   <li>last token of the document</li>
 * </ol>
 * These flags depend on sequence context, so the single-span {@code encode} methods are not
 * supported and always throw.
 */
public class PositionEncoder extends StaticEncoder {

    /** Creates an encoder, passing {@code "POS"} to {@link StaticEncoder} (presumably its id). */
    public PositionEncoder() {
        this("POS");
    }

    /**
     * Creates an encoder with a custom constructor argument for {@link StaticEncoder}.
     *
     * @param str value forwarded to {@code StaticEncoder(String)}
     */
    public PositionEncoder(String str) {
        super(str);
        this.log = LoggerFactory.getLogger(PositionEncoder.class);
    }

    /** @return the human-readable name {@code "Positional Encoder"} */
    @Override
    public String getName() {
        return "Positional Encoder";
    }

    /**
     * Returns the length of the vectors produced by {@link #wordAsVector}.
     *
     * <p>Computed from an actual (all-false) vector so that subclasses overriding
     * {@code wordAsVector} report a consistent size.
     *
     * @return number of entries in a position vector (4 for this implementation)
     */
    @Override
    @JsonIgnore
    public long getEmbeddingVectorSize() {
        return wordAsVector("", false, false, false, false).length();
    }

    /**
     * Checks a vector size restored from a saved encoder against this implementation.
     *
     * <p>The size is fixed by the implementation, so nothing is stored. The value is only
     * validated.
     *
     * @param i the saved vector size
     * @throws IllegalArgumentException if {@code i} differs from {@link #getEmbeddingVectorSize()}
     */
    public void setVectorSize(int i) {
        if (i != getEmbeddingVectorSize()) {
            throw new IllegalArgumentException("Vector size of saved Encoder (" + getEmbeddingVectorSize() + ") differs from implementation (" + i + ")");
        }
    }

    /**
     * Not supported: position flags need sequence context.
     *
     * @throws IllegalArgumentException always; use {@link #encodeEach(Document, Class)} instead
     */
    @Override
    public INDArray encode(Span span) {
        throw new IllegalArgumentException("PositionEncoder is sequential, you need to call encodeEach()");
    }

    /**
     * Not supported: position flags need sequence context.
     *
     * @throws IllegalArgumentException always; use {@link #encodeEach(Document, Class)} instead
     */
    @Override
    public INDArray encode(String str) {
        throw new IllegalArgumentException("PositionEncoder is sequential, you need to call encodeEach()");
    }

    /**
     * Builds the position vector for a token. The token's text is forwarded to
     * {@link #wordAsVector} but does not influence the result there.
     *
     * @param token the token being encoded
     * @param firstInDocument whether the token is the first of its document
     * @param firstInSentence whether the token is the first of its sentence
     * @param lastInSentence whether the token is the last of its sentence
     * @param lastInDocument whether the token is the last of its document
     * @return a 4x1 column vector of 0/1 flags
     */
    public INDArray tokenAsVector(Token token, boolean firstInDocument, boolean firstInSentence,
                                  boolean lastInSentence, boolean lastInDocument) {
        return wordAsVector(token.getText(), firstInDocument, firstInSentence, lastInSentence, lastInDocument);
    }

    /**
     * Builds a column vector with one row per flag, holding 1.0 where the flag is set and 0.0
     * otherwise.
     *
     * @param word ignored by this implementation; kept for API compatibility
     * @param firstInDocument row 0
     * @param firstInSentence row 1
     * @param lastInSentence row 2
     * @param lastInDocument row 3
     * @return a 4x1 column vector of 0/1 flags
     */
    public INDArray wordAsVector(String word, boolean firstInDocument, boolean firstInSentence,
                                 boolean lastInSentence, boolean lastInDocument) {
        final boolean[] flags = {firstInDocument, firstInSentence, lastInSentence, lastInDocument};
        final INDArray vector = Nd4j.zeros(flags.length, 1L);
        for (int row = 0; row < flags.length; row++) {
            if (flags[row]) {
                // entries start at 0.0, so only set flags need writing
                vector.put(row, 0, 1.0);
            }
        }
        return vector;
    }

    /**
     * Attaches a position vector (see {@link #tokenAsVector}) to every token of the document,
     * stored under the {@code PositionEncoder.class} key.
     *
     * <p>Document-level flags are based on token positions, not sentence positions. Leading or
     * trailing sentences without tokens therefore do not suppress the first/last-of-document
     * flags.
     *
     * @param document the document whose tokens are encoded
     * @param cls must be {@code Token.class}
     * @throws IllegalArgumentException if {@code cls} is not {@code Token.class}
     */
    @Override
    public void encodeEach(Document document, Class<? extends Span> cls) {
        if (cls != Token.class) {
            throw new IllegalArgumentException("PositionEncoder is only implemented to encode Tokens over Documents.");
        }
        // Tokens not yet visited; reaching 0 marks the document's final token.
        // NOTE(review): assumes getSentences() can be iterated twice (e.g. backed by a collection).
        int remaining = countTokens(document);
        boolean firstInDocument = true;
        for (Iterator<Sentence> sentences = document.getSentences().iterator(); sentences.hasNext(); ) {
            boolean firstInSentence = true;
            for (Iterator<Token> tokens = sentences.next().getTokens().iterator(); tokens.hasNext(); ) {
                final Token token = tokens.next();
                remaining--;
                final boolean lastInSentence = !tokens.hasNext();
                final boolean lastInDocument = remaining == 0;
                token.putVector(PositionEncoder.class,
                        tokenAsVector(token, firstInDocument, firstInSentence, lastInSentence, lastInDocument));
                firstInDocument = false;
                firstInSentence = false;
            }
        }
    }

    /** Counts the tokens across all sentences of {@code document}. */
    private static int countTokens(Document document) {
        int count = 0;
        for (Iterator<Sentence> sentences = document.getSentences().iterator(); sentences.hasNext(); ) {
            for (Iterator<Token> tokens = sentences.next().getTokens().iterator(); tokens.hasNext(); tokens.next()) {
                count++;
            }
        }
        return count;
    }

    /**
     * Not supported: document-level flags need the whole document.
     *
     * @throws IllegalArgumentException always; use {@link #encodeEach(Document, Class)} instead
     */
    @Override
    public void encodeEach(Sentence sentence, Class<? extends Span> cls) {
        throw new IllegalArgumentException("PositionEncoder is only implemented to encode Tokens over Documents.");
    }
}
