package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.nio.file.Files;

/**
 * A translator that classifies an array of input texts, returning one {@link Classifications} per
 * input.
 *
 * <p>Each input string is tokenized into an {@link Encoding}. The per-input {@link NDList}s are
 * combined with the supplied {@link Batchifier}. The model output is split back into per-input
 * lists with the same batchifier. Each list is converted with {@code
 * TextClassificationTranslator.toClassifications} using the {@link PretrainedConfig} loaded in
 * {@link #prepare(TranslatorContext)}.
 */
public class TextClassificationBatchTranslator
        implements NoBatchifyTranslator<String[], Classifications[]> {

    private final HuggingFaceTokenizer tokenizer;
    private final Batchifier batchifier;
    // Populated by prepare(); read by processOutput().
    private PretrainedConfig config;

    /**
     * Creates a batch text-classification translator.
     *
     * @param huggingFaceTokenizer the tokenizer used to encode each input text
     * @param batchifier the batchifier used to stack per-input encodings and to split the model
     *     output back into per-input lists
     */
    public TextClassificationBatchTranslator(
            HuggingFaceTokenizer huggingFaceTokenizer, Batchifier batchifier) {
        this.tokenizer = huggingFaceTokenizer;
        this.batchifier = batchifier;
    }

    /**
     * Reads {@code config.json} from the model directory into a {@link PretrainedConfig}.
     *
     * <p>The file is decoded as UTF-8 (the default of {@link Files#newBufferedReader}). The reader
     * is always closed, and any close failure is attached as a suppressed exception.
     *
     * @param translatorContext the context providing access to the model and its path
     * @throws IOException if the config file cannot be opened, read, or closed
     */
    @Override
    public void prepare(TranslatorContext translatorContext) throws IOException {
        try (Reader reader =
                Files.newBufferedReader(
                        translatorContext.getModel().getModelPath().resolve("config.json"))) {
            // NOTE(review): Gson returns null for an empty file; downstream handling of a null
            // config is not visible here — confirm TextClassificationTranslator tolerates it.
            this.config = JsonUtils.GSON.fromJson(reader, PretrainedConfig.class);
        }
    }

    /**
     * Tokenizes every input text and batches the resulting encodings into a single model input.
     *
     * @param translatorContext the context whose {@link NDManager} owns the created arrays
     * @param strArr the texts to classify
     * @return the batched model input
     */
    @Override
    public NDList processInput(TranslatorContext translatorContext, String[] strArr) {
        NDManager manager = translatorContext.getNDManager();
        Encoding[] encodings = tokenizer.batchEncode(strArr);
        NDList[] perInput = new NDList[encodings.length];
        for (int i = 0; i < encodings.length; i++) {
            // NOTE(review): the boolean presumably controls inclusion of token type ids — confirm.
            perInput[i] = encodings[i].toNDList(manager, false);
        }
        return batchifier.batchify(perInput);
    }

    /**
     * Splits the batched model output and converts each item into a {@link Classifications}.
     *
     * <p>Relies on {@link #prepare(TranslatorContext)} having loaded the model config first.
     *
     * @param translatorContext the translator context (unused here)
     * @param nDList the batched model output
     * @return one {@link Classifications} per item produced by the batchifier, in order
     */
    @Override
    public Classifications[] processOutput(TranslatorContext translatorContext, NDList nDList) {
        NDList[] perInput = batchifier.unbatchify(nDList);
        Classifications[] results = new Classifications[perInput.length];
        for (int i = 0; i < perInput.length; i++) {
            results[i] = TextClassificationTranslator.toClassifications(config, perInput[i]);
        }
        return results;
    }
}
