package de.datexis.encoder;

import com.google.common.collect.Lists;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import java.util.Collections;
import java.util.List;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:de/datexis/encoder/EncodingHelpers.class */
/**
 * Static helpers for allocating, filling and reading ND4J arrays that hold encoded spans.
 *
 * <p>Time-step matrices handled here are rank-3 {@code FLOAT} arrays indexed as
 * {@code [batch index, vector element, time step]}. Batch matrices are rank-2 {@code FLOAT}
 * arrays indexed as {@code [batch index, vector element]}.
 */
public class EncodingHelpers {

    /**
     * Allocates a zero-filled rank-3 {@code FLOAT} array.
     *
     * @param j  size of dimension 0 (number of batch entries)
     * @param j2 size of dimension 1 (length of one encoded vector)
     * @param j3 size of dimension 2 (number of time steps)
     * @return a new array of shape {@code [j, j2, j3]} containing only zeros
     */
    public static INDArray createTimeStepMatrix(long j, long j2, long j3) {
        return Nd4j.zeros(DataType.FLOAT, new long[]{j, j2, j3});
    }

    /**
     * Writes one vector into a time-step matrix, in place.
     *
     * @param iNDArray  the rank-3 matrix to modify
     * @param j         index along dimension 0 (batch entry)
     * @param j2        index along dimension 2 of the matrix (time step)
     * @param iNDArray2 the values assigned to the selected vector
     */
    public static void putTimeStep(INDArray iNDArray, long j, long j2, INDArray iNDArray2) {
        // slice(j, 0) selects the [vector element, time step] matrix of one batch entry;
        // slice(j2, 1) on that matrix then selects the vector of a single time step.
        INDArray batchEntry = iNDArray.slice(j, 0);
        INDArray timeStep = batchEntry.slice(j2, 1);
        timeStep.assign(iNDArray2);
    }

    /**
     * Reads one vector out of a time-step matrix.
     *
     * @param iNDArray the rank-3 matrix to read from
     * @param j        index along dimension 0 (batch entry)
     * @param j2       index along dimension 2 of the matrix (time step)
     * @return the selected vector reshaped to {@code [iNDArray.size(1), 1]}
     */
    public static INDArray getTimeStep(INDArray iNDArray, long j, long j2) {
        INDArray batchEntry = iNDArray.slice(j, 0);
        INDArray timeStep = batchEntry.slice(j2, 1);
        return timeStep.reshape(iNDArray.size(1), 1L);
    }

    /**
     * Encodes the children of every span in {@code list} into one time-step matrix.
     *
     * <p>Each span contributes one batch entry. Its children of type {@code cls} are encoded one
     * per time step. Children beyond {@code i} are ignored, and time steps without a child stay
     * zero. A span/{@code cls} combination that is not supported (see below) yields an all-zero
     * batch entry.
     *
     * <p>Supported combinations: {@code Document} with {@code Token.class} or
     * {@code Sentence.class}; {@code Sentence} with {@code Token.class}; and {@code Sentence}
     * with {@code Sentence.class}, where the sentence itself is the single time step.
     *
     * @param list     the spans to encode, one per batch entry
     * @param iEncoder the encoder applied to every child span; its
     *                 {@code getEmbeddingVectorSize()} determines dimension 1 of the result
     * @param i        the number of time steps in the result (size of dimension 2)
     * @param cls      the span type that forms one time step
     * @return a new array of shape {@code [list.size(), embedding vector size, i]}
     */
    public static INDArray encodeTimeStepMatrix(List<? extends Span> list, IEncoder iEncoder, int i, Class<? extends Span> cls) {
        INDArray encoded = createTimeStepMatrix(list.size(), iEncoder.getEmbeddingVectorSize(), i);
        for (int batchIndex = 0; batchIndex < list.size(); batchIndex++) {
            List<? extends Span> steps = selectTimeSteps(list.get(batchIndex), cls);
            // Never write past the allocated number of time steps.
            int stepCount = Math.min(steps.size(), i);
            for (int stepIndex = 0; stepIndex < stepCount; stepIndex++) {
                Span step = steps.get(stepIndex);
                putTimeStep(encoded, batchIndex, stepIndex, iEncoder.encode(step));
            }
        }
        return encoded;
    }

    /**
     * Encodes every span in {@code list} into one row of a batch matrix.
     *
     * @param list     the spans to encode, one per row
     * @param iEncoder the encoder applied to every span; its {@code getEmbeddingVectorSize()}
     *                 determines the row length
     * @return a new array of shape {@code [list.size(), embedding vector size]}
     */
    public static INDArray encodeBatchMatrix(List<? extends Span> list, IEncoder iEncoder) {
        INDArray encoded = Nd4j.zeros(DataType.FLOAT, new long[]{list.size(), iEncoder.getEmbeddingVectorSize()});
        for (int row = 0; row < list.size(); row++) {
            Span span = list.get(row);
            encoded.slice(row).assign(iEncoder.encode(span));
        }
        return encoded;
    }

    /** Returns the children of {@code span} that act as time steps for {@code cls}, or an empty list. */
    private static List<? extends Span> selectTimeSteps(Span span, Class<? extends Span> cls) {
        if (span instanceof Document) {
            Document document = (Document) span;
            if (cls == Token.class) {
                return Lists.newArrayList(document.getTokens());
            }
            if (cls == Sentence.class) {
                return Lists.newArrayList(document.getSentences());
            }
        } else if (span instanceof Sentence) {
            Sentence sentence = (Sentence) span;
            if (cls == Token.class) {
                return Lists.newArrayList(sentence.getTokens());
            }
            if (cls == Sentence.class) {
                // A sentence encoded at sentence granularity is its own single time step.
                return Collections.singletonList(sentence);
            }
        }
        return Collections.emptyList();
    }
}
