package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.util.Map;

/* loaded from: input_file:ai/djl/huggingface/translator/QuestionAnsweringTranslator.class */
/**
 * Translator that turns a {@link QAInput} (question + paragraph) into model input tensors and
 * turns the model output back into the answer text.
 *
 * <p>The input is tokenized with the supplied {@link HuggingFaceTokenizer}; the resulting
 * {@link Encoding} is stashed on the {@link TranslatorContext} so that post-processing can map
 * the predicted token positions back to token ids.
 */
public class QuestionAnsweringTranslator implements Translator<QAInput, String> {

    /** Key under which the input {@link Encoding} is attached to the translator context. */
    private static final String ENCODING_ATTACHMENT = "encoding";

    /** Value written over position 0 of both score arrays so that it cannot win the arg-max. */
    private static final int MASKED_SCORE = -100000;

    private final HuggingFaceTokenizer tokenizer;
    private final boolean includeTokenTypes;
    private final Batchifier batchifier;

    /** Builder for {@link QuestionAnsweringTranslator}. */
    public static final class Builder {
        private final HuggingFaceTokenizer tokenizer;
        private boolean includeTokenTypes;
        private Batchifier batchifier = Batchifier.STACK;

        /**
         * Creates a builder bound to the given tokenizer.
         *
         * @param tokenizer the tokenizer the built translator will use
         */
        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        /**
         * Sets whether token type ids are requested when converting the encoding to an NDList.
         *
         * @param includeTokenTypes the flag forwarded to {@code Encoding.toNDList}
         * @return this builder
         */
        public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
            this.includeTokenTypes = includeTokenTypes;
            return this;
        }

        /**
         * Sets the batchifier reported by the built translator.
         *
         * @param batchifier the batchifier to use
         * @return this builder
         */
        public Builder optBatchifier(Batchifier batchifier) {
            this.batchifier = batchifier;
            return this;
        }

        /**
         * Applies settings from an argument map.
         *
         * <p>Reads the {@code "includeTokenTypes"} and {@code "batchifier"} keys; the batchifier
         * name falls back to {@code "stack"} when the key is not supplied.
         *
         * @param arguments the argument map to read from
         */
        public void configure(Map<String, ?> arguments) {
            optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
            String batchifierName = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
            optBatchifier(Batchifier.fromString(batchifierName));
        }

        /**
         * Builds the translator from the current builder state.
         *
         * @return a new {@link QuestionAnsweringTranslator}
         * @throws IOException declared for API compatibility; nothing in this method throws it
         */
        public QuestionAnsweringTranslator build() throws IOException {
            return new QuestionAnsweringTranslator(tokenizer, includeTokenTypes, batchifier);
        }
    }

    /**
     * Creates the translator.
     *
     * @param tokenizer the tokenizer used to encode inputs and decode answers
     * @param includeTokenTypes whether token type ids are requested from the encoding
     * @param batchifier the batchifier returned by {@link #getBatchifier()}
     */
    QuestionAnsweringTranslator(
            HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
        this.tokenizer = tokenizer;
        this.includeTokenTypes = includeTokenTypes;
        this.batchifier = batchifier;
    }

    /** Returns the batchifier this translator was configured with. */
    @Override // ai.djl.translate.Translator
    public Batchifier getBatchifier() {
        return batchifier;
    }

    /**
     * Tokenizes the question/paragraph pair and converts the encoding into model input.
     *
     * @param translatorContext the context; receives the encoding as the "encoding" attachment
     * @param qAInput the question and paragraph to encode
     * @return the encoding converted to an {@link NDList}
     */
    @Override // ai.djl.translate.PreProcessor
    public NDList processInput(TranslatorContext translatorContext, QAInput qAInput) {
        Encoding encoding = tokenizer.encode(qAInput.getQuestion(), qAInput.getParagraph());
        // Keep the encoding around: processOutput needs the token ids to rebuild the answer.
        translatorContext.setAttachment(ENCODING_ATTACHMENT, encoding);
        return encoding.toNDList(translatorContext.getNDManager(), includeTokenTypes);
    }

    /**
     * Converts the model output into the answer string using the encoding saved by
     * {@link #processInput(TranslatorContext, QAInput)}.
     *
     * @param translatorContext the context holding the "encoding" attachment
     * @param nDList the model output
     * @return the decoded answer text
     */
    @Override // ai.djl.translate.PostProcessor
    public String processOutput(TranslatorContext translatorContext, NDList nDList) {
        Encoding encoding = (Encoding) translatorContext.getAttachment(ENCODING_ATTACHMENT);
        return decode(nDList, encoding, tokenizer);
    }

    /**
     * Creates the batch variant of this translator, sharing this translator's tokenizer.
     *
     * <p>Calls {@code enableBatch()} on the shared tokenizer before handing it over.
     *
     * <p>NOTE(review): the decompiler reports this method was renamed from
     * {@code toBatchTranslator} (merged with a bridge method); the name is kept here so the
     * visible signature is unchanged - confirm against the {@code Translator} interface.
     *
     * @param batchifier the batchifier for the batch translator
     * @return a {@link QuestionAnsweringBatchTranslator} using the same tokenizer and settings
     */
    @Override // ai.djl.translate.Translator
    public Translator<QAInput[], String[]> toBatchTranslator2(Batchifier batchifier) {
        tokenizer.enableBatch();
        return new QuestionAnsweringBatchTranslator(tokenizer, includeTokenTypes, batchifier);
    }

    /**
     * Extracts the answer text from the model output.
     *
     * <p>The first two arrays of {@code nDList} are treated as per-position scores for the
     * answer start and answer end. The arg-max of each gives a position; the token ids between
     * the two positions (inclusive) are copied out of {@code encoding} and decoded.
     *
     * @param nDList the model output; elements 0 and 1 are read
     * @param encoding the encoding of the original input, source of the token ids
     * @param huggingFaceTokenizer the tokenizer used to turn token ids back into text
     * @return the decoded, trimmed answer
     */
    public static String decode(
            NDList nDList, Encoding encoding, HuggingFaceTokenizer huggingFaceTokenizer) {
        // Work on copies so the caller's output arrays are not modified by the masking below.
        NDArray startScores = nDList.get(0).duplicate();
        NDArray endScores = nDList.get(1).duplicate();

        // Rule out position 0 as an answer boundary.
        // NOTE(review): presumably position 0 is the leading special token (e.g. [CLS]) - confirm.
        startScores.set(new NDIndex(0), MASKED_SCORE);
        endScores.set(new NDIndex(0), MASKED_SCORE);

        int startIdx = (int) startScores.argMax().getLong();
        int endIdx = (int) endScores.argMax().getLong();
        if (startIdx > endIdx) {
            // The predicted boundaries are reversed: swap them so the span is well-formed.
            // A temporary is required; assigning one to the other first would lose a value.
            int tmp = startIdx;
            startIdx = endIdx;
            endIdx = tmp;
        }

        int length = endIdx - startIdx + 1;
        long[] answerIds = new long[length];
        System.arraycopy(encoding.getIds(), startIdx, answerIds, 0, length);
        return huggingFaceTokenizer.decode(answerIds).trim();
    }

    /**
     * Creates a builder for the given tokenizer with default settings.
     *
     * @param huggingFaceTokenizer the tokenizer to use
     * @return a new builder
     */
    public static Builder builder(HuggingFaceTokenizer huggingFaceTokenizer) {
        return new Builder(huggingFaceTokenizer);
    }

    /**
     * Creates a builder for the given tokenizer and configures it from an argument map.
     *
     * @param huggingFaceTokenizer the tokenizer to use
     * @param map the arguments passed to {@link Builder#configure(Map)}
     * @return a new, configured builder
     */
    public static Builder builder(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, ?> map) {
        Builder configured = builder(huggingFaceTokenizer);
        configured.configure(map);
        return configured;
    }
}
