package cc.unitmesh.cf;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtUtil;
import com.sun.jna.Function;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import kotlin.Metadata;
import kotlin.TuplesKt;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.MapsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import kotlin.ranges.IntRange;
import org.jetbrains.annotations.NotNull;

/* compiled from: STSemantic.kt */
@Metadata(mv = {1, 9, 0}, k = 1, xi = 48, d1 = {"��0\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010 \n\u0002\u0010\u0006\n��\n\u0002\u0010\u000e\n\u0002\b\u0003\u0018�� \u000f2\u00020\u0001:\u0001\u000fB\u001d\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007¢\u0006\u0002\u0010\bJ\u0016\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u000b0\n2\u0006\u0010\f\u001a\u00020\rH\u0016J\b\u0010\u000e\u001a\u00020\u0003H\u0016R\u000e\u0010\u0006\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0004\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��¨\u0006\u0010"}, d2 = {"Lcc/unitmesh/cf/STSemantic;", "Lcc/unitmesh/cf/Semantic;", "tokenizer", "Lai/djl/huggingface/tokenizers/HuggingFaceTokenizer;", "session", "Lai/onnxruntime/OrtSession;", "env", "Lai/onnxruntime/OrtEnvironment;", "(Lai/djl/huggingface/tokenizers/HuggingFaceTokenizer;Lai/onnxruntime/OrtSession;Lai/onnxruntime/OrtEnvironment;)V", "embed", "", "", "input", "", "getTokenizer", "Companion", "sentence-transformers"})
@SourceDebugExtension({"SMAP\nSTSemantic.kt\nKotlin\n*S Kotlin\n*F\n+ 1 STSemantic.kt\ncc/unitmesh/cf/STSemantic\n+ 2 _Arrays.kt\nkotlin/collections/ArraysKt___ArraysKt\n*L\n1#1,84:1\n11115#2:85\n11450#2,3:86\n*S KotlinDebug\n*F\n+ 1 STSemantic.kt\ncc/unitmesh/cf/STSemantic\n*L\n57#1:85\n57#1:86,3\n*E\n"})
/* loaded from: input_file:cc/unitmesh/cf/STSemantic.class */
public final class STSemantic implements Semantic {

    @NotNull
    public static final Companion Companion = new Companion(null);

    @NotNull
    private final HuggingFaceTokenizer tokenizer;

    @NotNull
    private final OrtSession session;

    @NotNull
    private final OrtEnvironment env;

    /* compiled from: STSemantic.kt */
    @Metadata(mv = {1, 9, 0}, k = 1, xi = 48, d1 = {"��\u0012\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u0006\u0010\u0003\u001a\u00020\u0004¨\u0006\u0005"}, d2 = {"Lcc/unitmesh/cf/STSemantic$Companion;", "", "()V", "create", "Lcc/unitmesh/cf/STSemantic;", "sentence-transformers"})
    /* loaded from: input_file:cc/unitmesh/cf/STSemantic$Companion.class */
    public static final class Companion {
        private Companion() {
        }

        @NotNull
        public final STSemantic create() {
            ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
            InputStream resourceAsStream = contextClassLoader.getResourceAsStream("model/tokenizer.json");
            Intrinsics.checkNotNull(resourceAsStream);
            InputStream resourceAsStream2 = contextClassLoader.getResourceAsStream("model/model.onnx");
            Intrinsics.checkNotNull(resourceAsStream2);
            HuggingFaceTokenizer newInstance = HuggingFaceTokenizer.newInstance(resourceAsStream, (Map<String, String>) null);
            OrtEnvironment environment = OrtEnvironment.getEnvironment();
            OrtSession createSession = environment.createSession(resourceAsStream2.readAllBytes(), new OrtSession.SessionOptions());
            Intrinsics.checkNotNull(newInstance);
            Intrinsics.checkNotNull(createSession);
            Intrinsics.checkNotNull(environment);
            return new STSemantic(newInstance, createSession, environment);
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    public STSemantic(@NotNull HuggingFaceTokenizer tokenizer, @NotNull OrtSession session, @NotNull OrtEnvironment env) {
        Intrinsics.checkNotNullParameter(tokenizer, "tokenizer");
        Intrinsics.checkNotNullParameter(session, "session");
        Intrinsics.checkNotNullParameter(env, "env");
        this.tokenizer = tokenizer;
        this.session = session;
        this.env = env;
    }

    @Override // cc.unitmesh.cf.Semantic
    @NotNull
    public HuggingFaceTokenizer getTokenizer() {
        return this.tokenizer;
    }

    @Override // cc.unitmesh.cf.Semantic
    @NotNull
    public List<Double> embed(@NotNull String input) {
        Intrinsics.checkNotNullParameter(input, "input");
        Encoding encode = this.tokenizer.encode(input, true);
        long[] ids = encode.getIds();
        long[] attentionMask = encode.getAttentionMask();
        long[] typeIds = encode.getTypeIds();
        if (encode.getIds().length >= 512) {
            Intrinsics.checkNotNull(ids);
            ids = CollectionsKt.toLongArray(ArraysKt.slice(ids, new IntRange(0, 510)));
            Intrinsics.checkNotNull(attentionMask);
            attentionMask = CollectionsKt.toLongArray(ArraysKt.slice(attentionMask, new IntRange(0, 510)));
            Intrinsics.checkNotNull(typeIds);
            typeIds = CollectionsKt.toLongArray(ArraysKt.slice(typeIds, new IntRange(0, 510)));
        }
        OnnxValue onnxValue = this.session.run(MapsKt.mapOf(TuplesKt.to("input_ids", OnnxTensor.createTensor(this.env, OrtUtil.reshape(ids, new long[]{1, ids.length}))), TuplesKt.to("attention_mask", OnnxTensor.createTensor(this.env, OrtUtil.reshape(attentionMask, new long[]{1, attentionMask.length}))), TuplesKt.to("token_type_ids", OnnxTensor.createTensor(this.env, OrtUtil.reshape(typeIds, new long[]{1, typeIds.length}))))).get(0);
        Intrinsics.checkNotNull(onnxValue, "null cannot be cast to non-null type ai.onnxruntime.OnnxTensor");
        float[] array = ((OnnxTensor) onnxValue).getFloatBuffer().array();
        float[] fArr = new float[Function.USE_VARARGS];
        for (int i = 0; i < 384; i++) {
            float f = 0.0f;
            int length = ids.length;
            for (int i2 = 0; i2 < length; i2++) {
                f += array[(i2 * Function.USE_VARARGS) + i];
            }
            fArr[i] = f / ids.length;
        }
        ArrayList arrayList = new ArrayList(fArr.length);
        for (float f2 : fArr) {
            arrayList.add(Double.valueOf(f2));
        }
        return arrayList;
    }
}
