package org.clulab.scala_transformers.encoder;

import java.util.Objects;
import org.clulab.scala_transformers.tokenizer.LongTokenization;
import org.clulab.scala_transformers.tokenizer.LongTokenization$;
import org.clulab.scala_transformers.tokenizer.Tokenizer;
import org.clulab.shaded.org.ejml.data.FMatrixRMaj;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Some$;
import scala.Tuple2;
import scala.collection.ArrayOps$;
import scala.collection.StringOps$;
import scala.collection.immutable.Seq;
import scala.reflect.ClassTag$;
import scala.runtime.ObjectRef;

/* compiled from: TokenClassifier.scala */
/* loaded from: input_file:org/clulab/scala_transformers/encoder/TokenClassifier.class */
public class TokenClassifier {
    private final Encoder encoder;
    private final int maxTokens;
    private final LinearLayer[] tasks;
    private final Tokenizer tokenizer;

    public static TokenClassifier fromFiles(String str) {
        return TokenClassifier$.MODULE$.fromFiles(str);
    }

    public static TokenClassifier fromResources(String str) {
        return TokenClassifier$.MODULE$.fromResources(str);
    }

    public static Tuple2<String, Object>[][] mapTokenLabelsAndScoresToWords(Tuple2<String, Object>[][] tuple2Arr, long[] jArr) {
        return TokenClassifier$.MODULE$.mapTokenLabelsAndScoresToWords(tuple2Arr, jArr);
    }

    public static String[] mapTokenLabelsToWords(String[] strArr, long[] jArr) {
        return TokenClassifier$.MODULE$.mapTokenLabelsToWords(strArr, jArr);
    }

    public static boolean mkSingleTokenMask(long j, int i, long[] jArr) {
        return TokenClassifier$.MODULE$.mkSingleTokenMask(j, i, jArr);
    }

    public static boolean[] mkTokenMask(long[] jArr) {
        return TokenClassifier$.MODULE$.mkTokenMask(jArr);
    }

    public TokenClassifier(Encoder encoder, int i, LinearLayer[] linearLayerArr, Tokenizer tokenizer) {
        this.encoder = encoder;
        this.maxTokens = i;
        this.tasks = linearLayerArr;
        this.tokenizer = tokenizer;
    }

    public Encoder encoder() {
        return this.encoder;
    }

    public int maxTokens() {
        return this.maxTokens;
    }

    public LinearLayer[] tasks() {
        return this.tasks;
    }

    public Tokenizer tokenizer() {
        return this.tokenizer;
    }

    /* JADX WARN: Multi-variable type inference failed */
    public Tuple2<String, Object>[][][] predictWithScores(Seq<String> seq, String str) {
        LongTokenization apply = LongTokenization$.MODULE$.apply(tokenizer().tokenize((String[]) seq.toArray(ClassTag$.MODULE$.apply(String.class))));
        long[] jArr = apply.tokenIds();
        long[] wordIds = apply.wordIds();
        String[] strArr = apply.tokens();
        if (jArr.length > maxTokens()) {
            throw new EncoderMaxTokensRuntimeException(new StringBuilder(108).append("Encoder error: the following text contains more tokens than the maximum number accepted by this encoder (").append(maxTokens()).append("): ").append(Predef$.MODULE$.wrapRefArray(strArr).mkString(", ")).toString());
        }
        FMatrixRMaj forward = encoder().forward(jArr);
        Tuple2<String, Object>[][][] tuple2Arr = (Tuple2[][][]) new Tuple2[tasks().length];
        ObjectRef create = ObjectRef.create(None$.MODULE$);
        ArrayOps$.MODULE$.indices$extension(Predef$.MODULE$.refArrayOps(tasks())).foreach(i -> {
            if (tasks()[i].dual()) {
                return;
            }
            Tuple2<String, Object>[][] predictWithScores = tasks()[i].predictWithScores(forward, (Option<int[][]>) None$.MODULE$, (Option<boolean[]>) None$.MODULE$);
            tuple2Arr[i] = TokenClassifier$.MODULE$.mapTokenLabelsAndScoresToWords(predictWithScores, apply.wordIds());
            String name = tasks()[i].name();
            if (name == null) {
                if (str != null) {
                    return;
                }
            } else if (!name.equals(str)) {
                return;
            }
            create.elem = Some$.MODULE$.apply(ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(predictWithScores), tuple2Arr2 -> {
                return (int[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(tuple2Arr2), tuple2 -> {
                    return StringOps$.MODULE$.toInt$extension(Predef$.MODULE$.augmentString((String) tuple2._1()));
                }, ClassTag$.MODULE$.apply(Integer.TYPE));
            }, ClassTag$.MODULE$.apply(Integer.TYPE).wrap()));
        });
        if (((Option) create.elem).isDefined()) {
            Some apply2 = Some$.MODULE$.apply(TokenClassifier$.MODULE$.mkTokenMask(wordIds));
            ArrayOps$.MODULE$.indices$extension(Predef$.MODULE$.refArrayOps(tasks())).foreach(i2 -> {
                if (tasks()[i2].dual()) {
                    tuple2Arr[i2] = TokenClassifier$.MODULE$.mapTokenLabelsAndScoresToWords(tasks()[i2].predictWithScores(forward, (Option<int[][]>) create.elem, (Option<boolean[]>) apply2), apply.wordIds());
                }
            });
        }
        return tuple2Arr;
    }

    public String predictWithScores$default$2() {
        return "Deps Head";
    }

    /* JADX WARN: Multi-variable type inference failed */
    public String[][] predict(Seq<String> seq, String str) {
        LongTokenization apply = LongTokenization$.MODULE$.apply(tokenizer().tokenize((String[]) seq.toArray(ClassTag$.MODULE$.apply(String.class))));
        long[] jArr = apply.tokenIds();
        long[] wordIds = apply.wordIds();
        FMatrixRMaj forward = encoder().forward(jArr);
        String[][] strArr = (String[][]) new String[tasks().length];
        ObjectRef create = ObjectRef.create(None$.MODULE$);
        ArrayOps$.MODULE$.indices$extension(Predef$.MODULE$.refArrayOps(tasks())).foreach(i -> {
            if (tasks()[i].dual()) {
                return;
            }
            String[] predict = tasks()[i].predict(forward, (Option<int[]>) None$.MODULE$, (Option<boolean[]>) None$.MODULE$);
            strArr[i] = TokenClassifier$.MODULE$.mapTokenLabelsToWords(predict, apply.wordIds());
            String name = tasks()[i].name();
            if (name == null) {
                if (str != null) {
                    return;
                }
            } else if (!name.equals(str)) {
                return;
            }
            create.elem = Some$.MODULE$.apply(ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(predict), str2 -> {
                return StringOps$.MODULE$.toInt$extension(Predef$.MODULE$.augmentString(str2));
            }, ClassTag$.MODULE$.apply(Integer.TYPE)));
        });
        if (((Option) create.elem).isDefined()) {
            Some apply2 = Some$.MODULE$.apply(TokenClassifier$.MODULE$.mkTokenMask(wordIds));
            ArrayOps$.MODULE$.indices$extension(Predef$.MODULE$.refArrayOps(tasks())).foreach(i2 -> {
                if (tasks()[i2].dual()) {
                    strArr[i2] = TokenClassifier$.MODULE$.mapTokenLabelsToWords(tasks()[i2].predict(forward, (Option<int[]>) create.elem, (Option<boolean[]>) apply2), apply.wordIds());
                }
            });
        }
        return strArr;
    }

    public String predict$default$2() {
        return "Deps Head";
    }
}
