package de.kherud.llama;

import com.sun.jna.Native;
import com.sun.jna.Pointer;
import de.kherud.llama.Parameters;
import de.kherud.llama.foreign.LlamaLibrary;
import de.kherud.llama.foreign.NativeSize;
import de.kherud.llama.foreign.llama_timings;
import de.kherud.llama.foreign.llama_token_data;
import de.kherud.llama.foreign.llama_token_data_array;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.BiConsumer;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* loaded from: input_file:de/kherud/llama/LlamaModel.class */
/**
 * Java wrapper around a llama.cpp model and its inference context, accessed through the JNA
 * bindings in {@link LlamaLibrary}.
 *
 * <p>An instance owns native memory (the model and the context); release it with {@link #close()}.
 * It also keeps mutable generation state (context tokens, evaluation position, partially decoded
 * UTF-8 bytes) without any synchronization, so it should not be used from several threads at once
 * without external locking.
 */
public class LlamaModel implements AutoCloseable {
    // User-supplied log sink; also receives this class's own messages via log(...).
    private static BiConsumer<LogLevel, String> logCallback;
    // Strong reference to the JNA callback passed to llama_log_set. A JNA callback is an ordinary
    // Java object: if nothing on the Java side references it, it can be garbage collected while
    // native code still holds its function pointer, crashing the JVM on the next native log call.
    private static Object nativeLogCallback;
    private final Parameters params;
    private final LlamaLibrary.llama_model model;
    private final LlamaLibrary.llama_context ctx;
    // Native logits of the context; nVocab floats are read from it for every sample.
    private final Pointer logitsPointer;
    // Conversation tokens: [0, nContext) are valid, [0, nPast) have already been evaluated.
    private final SliceableIntBuffer contextBuffer;
    // Scratch output buffer for llama_tokenize.
    private final SliceableIntBuffer tokenBuffer;
    // Scratch output buffer for llama_token_to_piece. [0, nBuffered) holds bytes of a UTF-8
    // sequence that is not complete yet. Replaced by a larger buffer on demand.
    private SliceableByteBuffer tokenPieceBuffer;
    // One candidate entry per vocabulary token, reused across samples.
    private final llama_token_data.ByReference[] candidateData;
    private final llama_token_data_array candidates;
    private final int nVocab;
    private final int tokenBos;
    private final int tokenEos;
    private final int tokenNl;
    private int nPast;
    private int nContext;
    private int nBuffered;
    // Set by close() so the native model/context are never freed twice.
    private boolean closed;

    /**
     * Lazily generates tokens for one prompt. Each {@link #next()} evaluates pending context,
     * samples one token, and appends it to the model's context.
     */
    private class LlamaIterator implements Iterator<Output> {
        // All text produced by this iterator so far; searched for anti-prompts.
        private final StringBuilder builder = new StringBuilder();
        private boolean hasNext = true;
        // Remaining token budget; only enforced when params.nPredict >= 0.
        private int nRemain;

        public LlamaIterator(String prompt) {
            this.nRemain = LlamaModel.this.params.nPredict;
            setup(prompt);
        }

        @Override
        public boolean hasNext() {
            return this.hasNext;
        }

        /**
         * Samples the next token and appends it to the context.
         *
         * @return the sampled token
         * @throws NoSuchElementException if generation has already stopped
         */
        @Override
        public Output next() {
            if (!this.hasNext) {
                throw new NoSuchElementException();
            }
            LlamaModel.this.evaluate();
            Output sample = LlamaModel.this.sample();
            this.builder.append(sample.text);
            // Make room for exactly one more token before storing the sampled one.
            LlamaModel.this.truncateContext(1);
            LlamaModel.this.contextBuffer.put(LlamaModel.this.nContext, sample.token);
            LlamaModel.this.nContext++;
            this.nRemain--;
            checkHasNext();
            return sample;
        }

        /** Stops on an exhausted budget, an end-of-sequence token, or a matched anti-prompt. */
        private void checkHasNext() {
            boolean budgetExhausted = LlamaModel.this.params.nPredict >= 0 && this.nRemain <= 0;
            boolean endOfSequence =
                    LlamaModel.this.contextBuffer.get(LlamaModel.this.nContext - 1) == LlamaModel.this.tokenEos;
            this.hasNext = !(budgetExhausted || endOfSequence || isAntiPrompt());
        }

        /** Tokenizes the prompt (with BOS) and appends it to the context. */
        private void setup(String prompt) {
            // A fresh context gets a leading space, matching how the prompt is tokenized by llama.cpp's examples.
            if (LlamaModel.this.nContext == 0 && !prompt.startsWith(" ")) {
                prompt = " " + prompt;
            }
            LlamaModel.this.addContext(LlamaModel.this.tokenize(prompt, true));
        }

        /** Returns whether any configured non-empty anti-prompt occurs in the generated text. */
        private boolean isAntiPrompt() {
            Iterator<String> antiPrompts = LlamaModel.this.params.antiprompt.iterator();
            while (antiPrompts.hasNext()) {
                String antiPrompt = antiPrompts.next();
                // A match at index 0 counts; an empty anti-prompt would match any text, so skip it.
                if (!antiPrompt.isEmpty() && this.builder.lastIndexOf(antiPrompt) >= 0) {
                    return true;
                }
            }
            return false;
        }
    }

    /** A sampled token together with its probability and decoded text. */
    public final class Output {
        public final int token;
        public final float probability;
        // Decoded text; empty while the token's bytes do not yet complete a UTF-8 character.
        public final String text;

        private Output(int token, float probability) {
            this.token = token;
            this.probability = probability;
            this.text = LlamaModel.this.decodeToken(token);
        }

        @Override
        public String toString() {
            return this.text;
        }

        /** Returns the vocabulary type of this token as reported by the native library. */
        public TokenType getType() {
            return TokenType.fromCode(LlamaLibrary.llama_token_get_type(LlamaModel.this.ctx, this.token));
        }

        /** Returns the vocabulary score of this token as reported by the native library. */
        public float getScore() {
            return LlamaLibrary.llama_token_get_score(LlamaModel.this.ctx, this.token);
        }
    }

    /**
     * Loads a model with default parameters.
     *
     * @param str path of the model file
     * @throws RuntimeException if the model or context cannot be created
     */
    public LlamaModel(String str) {
        this(str, new Parameters.Builder().build());
    }

    /**
     * Loads a model, creates its context and optionally applies a LoRA adapter.
     *
     * @param str path of the model file
     * @param parameters model, context and sampling parameters
     * @throws RuntimeException if the model, the context or the LoRA adapter cannot be loaded;
     *     any native memory allocated up to that point is freed first
     */
    public LlamaModel(String str, Parameters parameters) {
        this.nPast = 0;
        this.nContext = 0;
        this.nBuffered = 0;
        this.params = parameters;
        this.model = LlamaLibrary.llama_load_model_from_file(str, parameters.ctx);
        if (this.model == null) {
            throw new RuntimeException("error: unable to load model");
        }
        this.ctx = LlamaLibrary.llama_new_context_with_model(this.model, parameters.ctx);
        if (this.ctx == null) {
            LlamaLibrary.llama_free_model(this.model);
            throw new RuntimeException("error: unable to create context");
        }
        // The base model path is optional for llama_model_apply_lora_from_file (NULL means "use the
        // loaded model"), so only the adapter itself is required.
        if (parameters.loraAdapter != null
                && LlamaLibrary.llama_model_apply_lora_from_file(
                        this.model, parameters.loraAdapter, parameters.loraBase, parameters.nThreads) != 0) {
            LlamaLibrary.llama_free(this.ctx);
            LlamaLibrary.llama_free_model(this.model);
            throw new RuntimeException("error: unable to apply lora");
        }
        this.contextBuffer = new SliceableIntBuffer(IntBuffer.allocate(parameters.ctx.n_ctx));
        this.tokenBuffer = new SliceableIntBuffer(IntBuffer.allocate(parameters.ctx.n_ctx));
        this.tokenPieceBuffer = new SliceableByteBuffer(ByteBuffer.allocate(64));
        this.nVocab = getVocabularySize();
        this.logitsPointer = LlamaLibrary.llama_get_logits(this.ctx).getPointer();
        this.candidateData = (llama_token_data.ByReference[]) new llama_token_data.ByReference().toArray(this.nVocab);
        this.candidates = new llama_token_data_array();
        this.tokenBos = LlamaLibrary.llama_token_bos(this.ctx);
        this.tokenEos = LlamaLibrary.llama_token_eos(this.ctx);
        this.tokenNl = LlamaLibrary.llama_token_nl(this.ctx);
    }

    /**
     * Returns an iterable that generates a continuation of the prompt token by token. Every call to
     * {@code iterator()} appends the prompt to the current context again.
     *
     * @param str the prompt
     */
    public Iterable<Output> generate(final String str) {
        return new Iterable<Output>() {
            @Override
            @NotNull
            public Iterator<Output> iterator() {
                return new LlamaIterator(str);
            }
        };
    }

    /**
     * Generates a complete continuation of the prompt.
     *
     * @param str the prompt
     * @return the generated text
     */
    public String complete(String str) {
        StringBuilder completion = new StringBuilder();
        LlamaIterator iterator = new LlamaIterator(str);
        while (iterator.hasNext()) {
            completion.append(iterator.next());
        }
        // Generation is over: emit bytes still held back waiting for the rest of a UTF-8 sequence
        // instead of leaking them into the next generation.
        completion.append(flushPendingPiece());
        return completion.toString();
    }

    /**
     * Appends the text (without BOS) to the context, evaluates it and returns the embedding.
     *
     * @param str the text to embed
     * @return {@link #getEmbeddingSize()} floats read from the context's embedding output
     * @throws IllegalStateException if embedding mode is not enabled in the parameters
     */
    public float[] getEmbedding(String str) {
        if (this.params.ctx.embedding == 0) {
            throw new IllegalStateException("embedding mode not activated (see parameters)");
        }
        addContext(tokenize(str, false));
        evaluate();
        return LlamaLibrary.llama_get_embeddings(this.ctx).getPointer().getFloatArray(0L, getEmbeddingSize());
    }

    /** Clears the context, the evaluation position, pending decode bytes and the native timings. */
    public void reset() {
        LlamaLibrary.llama_reset_timings(this.ctx);
        this.contextBuffer.clear();
        this.nContext = 0;
        this.nPast = 0;
        this.nBuffered = 0;
    }

    /**
     * Tokenizes text without adding a BOS token.
     *
     * @param str the text to tokenize
     * @return the token ids
     */
    public int[] encode(String str) {
        SliceableIntBuffer tokens = tokenize(str, false);
        int[] result = new int[tokens.capacity()];
        // tokenize always slices from index 0 of the backing array.
        System.arraycopy(tokens.delegate.array(), 0, result, 0, tokens.capacity());
        return result;
    }

    /**
     * Converts token ids back to text.
     *
     * @param iArr the token ids
     * @return the decoded text, including any trailing bytes of an incomplete UTF-8 sequence
     */
    public String decode(int[] iArr) {
        StringBuilder text = new StringBuilder();
        for (int token : iArr) {
            text.append(decodeToken(token));
        }
        // Do not leave partial-character bytes behind for the next decode or generation.
        text.append(flushPendingPiece());
        return text.toString();
    }

    /**
     * Installs (or, with {@code null}, removes) the sink for native llama.cpp log messages and for
     * this class's own messages.
     *
     * @param biConsumer receives the log level and message, or {@code null} to disable logging
     */
    public static void setLogger(@Nullable BiConsumer<LogLevel, String> biConsumer) {
        logCallback = biConsumer;
        if (biConsumer == null) {
            LlamaLibrary.llama_log_set(null, null);
            nativeLogCallback = null;
        } else {
            LlamaLibrary.llama_log_set(
                    retainNativeCallback(
                            (level, message, userData) -> biConsumer.accept(LogLevel.fromCode(level), message)),
                    null);
        }
    }

    /** Stores the callback in a static field so it stays reachable while native code may call it. */
    private static <T> T retainNativeCallback(T callback) {
        nativeLogCallback = callback;
        return callback;
    }

    /** Returns the context size reported by the native context. */
    public int getContextSize() {
        return LlamaLibrary.llama_n_ctx(this.ctx);
    }

    /** Returns the embedding dimension reported by the native context. */
    public int getEmbeddingSize() {
        return LlamaLibrary.llama_n_embd(this.ctx);
    }

    /** Returns the number of tokens in the vocabulary. */
    public int getVocabularySize() {
        return LlamaLibrary.llama_n_vocab(this.ctx);
    }

    /** Returns the vocabulary type reported by the native context. */
    public VocabularyType getVocabularyType() {
        return VocabularyType.fromCode(LlamaLibrary.llama_vocab_type(this.ctx));
    }

    /** Returns the model size as reported by {@code llama_model_size}. */
    public long getMemorySize() {
        return LlamaLibrary.llama_model_size(this.model);
    }

    /** Returns the number of model parameters as reported by {@code llama_model_n_params}. */
    public long getAmountParameters() {
        return LlamaLibrary.llama_model_n_params(this.model);
    }

    /** Returns the native timing statistics of the context. */
    public llama_timings getTimings() {
        return LlamaLibrary.llama_get_timings(this.ctx);
    }

    /**
     * Frees the native context and model. The context references the model, so it is freed first.
     * Calling this more than once has no further effect.
     */
    @Override
    public void close() {
        if (this.closed) {
            return;
        }
        this.closed = true;
        LlamaLibrary.llama_free(this.ctx);
        LlamaLibrary.llama_free_model(this.model);
    }

    /** Returns the native model description (at most 511 bytes of UTF-8). */
    @Override
    public String toString() {
        byte[] description = new byte[512];
        int reported = LlamaLibrary.llama_model_desc(this.model, description, new NativeSize(description.length));
        // NOTE(review): the reported length is presumably snprintf-style (may exceed the buffer,
        // excludes the NUL); clamp it and stop at the first NUL so neither case leaks into the text.
        int limit = Math.max(0, Math.min(reported, description.length));
        int length = 0;
        while (length < limit && description[length] != 0) {
            length++;
        }
        return new String(description, 0, length, StandardCharsets.UTF_8);
    }

    /**
     * Tokenizes text into {@link #tokenBuffer}.
     *
     * @param text the text to tokenize
     * @param addBos whether to prepend the BOS token
     * @return a slice of the scratch token buffer starting at index 0; overwritten by the next call
     * @throws RuntimeException if the text needs more tokens than the context size
     */
    private SliceableIntBuffer tokenize(String text, boolean addBos) {
        // llama_tokenize takes the length of the native (byte) string, not the number of Java chars.
        // Measure it with the same encoding JNA uses to marshal the String argument (minus the NUL).
        int byteLength = Native.toByteArray(text).length - 1;
        int nTokens = LlamaLibrary.llama_tokenize(
                this.ctx, text, byteLength, this.tokenBuffer.delegate, this.params.ctx.n_ctx,
                addBos ? (byte) 1 : (byte) 0);
        if (nTokens < 0) {
            // A negative result is the token count that would have been needed.
            throw new RuntimeException("tokenization failed: text requires " + (-nTokens)
                    + " tokens but at most " + this.params.ctx.n_ctx + " fit into the context");
        }
        return this.tokenBuffer.slice(0, nTokens);
    }

    /**
     * Evaluates all context tokens not evaluated yet, in chunks of at most {@code n_batch}.
     *
     * @throws RuntimeException if the native evaluation fails
     */
    private void evaluate() {
        while (this.nPast < this.nContext) {
            int nEval = Math.min(this.nContext - this.nPast, this.params.ctx.n_batch);
            if (LlamaLibrary.llama_eval(this.ctx, this.contextBuffer.slice(this.nPast, nEval).delegate, nEval, this.nPast, this.params.nThreads) != 0) {
                log(LogLevel.ERROR, String.format("evaluation failed (%d to evaluate, %d past, %d threads)", Integer.valueOf(nEval), Integer.valueOf(this.nPast), Integer.valueOf(this.params.nThreads)));
                throw new RuntimeException("token evaluation failed");
            }
            this.nPast += nEval;
        }
    }

    /**
     * Samples the next token from the current logits: applies logit bias, penalties, the optional
     * grammar, then greedy / mirostat / top-k style sampling depending on the parameters.
     */
    private Output sample() {
        float[] logits = this.logitsPointer.getFloatArray(0L, this.nVocab);
        this.params.logitBias.forEach((tokenId, bias) -> {
            int index = tokenId.intValue();
            logits[index] = logits[index] + bias.floatValue();
        });
        // Captured after the bias (as llama.cpp's reference sampling loop does), so restoring it
        // below only undoes the penalties and keeps a user-supplied newline bias.
        float nlLogit = logits[this.tokenNl];
        for (int id = 0; id < this.nVocab; id++) {
            this.candidateData[id].setId(id);
            this.candidateData[id].setLogit(logits[id]);
        }
        this.candidates.setData(this.candidateData[0]);
        this.candidates.setSize(new NativeSize(this.nVocab));
        this.candidates.setSorted((byte) 0);
        samplePenalty();
        if (!this.params.penalizeNl) {
            this.candidateData[this.tokenNl].setLogit(nlLogit);
        }
        if (this.params.grammar != null) {
            LlamaLibrary.llama_sample_grammar(this.ctx, this.candidates, this.params.grammar.foreign);
        }
        int token;
        if (this.params.temperature == 0.0f) {
            token = sampleGreedy();
        } else if (this.params.mirostat == Parameters.MiroStat.V1) {
            token = sampleMirostatV1();
        } else if (this.params.mirostat == Parameters.MiroStat.V2) {
            token = sampleMirostatV2();
        } else {
            token = sampleTopK();
        }
        if (this.params.grammar != null) {
            LlamaLibrary.llama_grammar_accept_token(this.ctx, this.params.grammar.foreign, token);
        }
        // NOTE(review): llama.cpp's sampling functions may sort and shrink the candidate array in
        // place, and JNA only refreshes Java-side struct fields on read(), so candidateData[token].p
        // may not be this token's probability — TODO confirm and look the token up by id instead.
        return new Output(token, this.candidateData[token].p);
    }

    /** Applies repetition, frequency and presence penalties over the last context tokens. */
    private void samplePenalty() {
        int lastN = this.params.repeatLastN < 0 ? this.params.ctx.n_ctx : this.params.repeatLastN;
        int nPenalized = Math.min(Math.min(this.nContext, lastN), this.params.ctx.n_ctx);
        NativeSize penaltySize = new NativeSize(nPenalized);
        // The penalties must look at the most recent *context* tokens, not the tokenize scratch buffer.
        SliceableIntBuffer recentTokens = this.contextBuffer.slice(this.nContext - nPenalized, nPenalized);
        LlamaLibrary.llama_sample_repetition_penalty(this.ctx, this.candidates, recentTokens.delegate, penaltySize, this.params.repeatPenalty);
        LlamaLibrary.llama_sample_frequency_and_presence_penalties(this.ctx, this.candidates, recentTokens.delegate, penaltySize, this.params.frequencyPenalty, this.params.presencePenalty);
    }

    /** Picks the most likely token; computes probabilities only when nProbs is requested. */
    private int sampleGreedy() {
        int token = LlamaLibrary.llama_sample_token_greedy(this.ctx, this.candidates);
        if (this.params.nProbs > 0) {
            LlamaLibrary.llama_sample_softmax(this.ctx, this.candidates);
        }
        return token;
    }

    /** Applies temperature, then samples with mirostat (version 1). */
    private int sampleMirostatV1() {
        LlamaLibrary.llama_sample_temperature(this.ctx, this.candidates, this.params.temperature);
        return LlamaLibrary.llama_sample_token_mirostat(this.ctx, this.candidates, this.params.mirostatTau, this.params.mirostatEta, this.params.mirostatM, this.params.mirostatMu);
    }

    /** Applies temperature, then samples with mirostat (version 2). */
    private int sampleMirostatV2() {
        LlamaLibrary.llama_sample_temperature(this.ctx, this.candidates, this.params.temperature);
        return LlamaLibrary.llama_sample_token_mirostat_v2(this.ctx, this.candidates, this.params.mirostatTau, this.params.mirostatEta, this.params.mirostatMu);
    }

    /** Applies top-k, tail-free, typical, top-p and temperature, then samples a token. */
    private int sampleTopK() {
        NativeSize minKeep = new NativeSize(Math.max(1, this.params.nProbs));
        int topK = this.params.topK <= 0 ? this.nVocab : this.params.topK;
        LlamaLibrary.llama_sample_top_k(this.ctx, this.candidates, topK, minKeep);
        LlamaLibrary.llama_sample_tail_free(this.ctx, this.candidates, this.params.tfsZ, minKeep);
        LlamaLibrary.llama_sample_typical(this.ctx, this.candidates, this.params.typicalP, minKeep);
        LlamaLibrary.llama_sample_top_p(this.ctx, this.candidates, this.params.topP, minKeep);
        LlamaLibrary.llama_sample_temperature(this.ctx, this.candidates, this.params.temperature);
        return LlamaLibrary.llama_sample_token(this.ctx, this.candidates);
    }

    /** Appends tokens (a slice starting at index 0 of its backing array) to the context. */
    private void addContext(SliceableIntBuffer tokens) {
        truncateContext(tokens.capacity());
        System.arraycopy(tokens.delegate.array(), 0, this.contextBuffer.delegate.array(), this.nContext, tokens.capacity());
        this.nContext += tokens.capacity();
    }

    /**
     * Drops older context tokens when adding {@code nToAdd} more would overflow the context. Keeps
     * the most recent tokens so that roughly half the window is in use afterwards, and forces
     * re-evaluation of everything kept.
     */
    private void truncateContext(int nToAdd) {
        int nCtx = this.params.ctx.n_ctx;
        if (this.nContext + nToAdd <= nCtx) {
            return;
        }
        // Clamp at 0: when nToAdd exceeds half the window the keep count would go negative and
        // System.arraycopy would throw.
        int nKeep = Math.max(0, nCtx / 2 - nToAdd);
        log(LogLevel.INFO, "truncating context from " + this.nContext + " to " + nKeep + " tokens (+" + nToAdd + " to add)");
        int[] tokens = this.contextBuffer.delegate.array();
        System.arraycopy(tokens, this.nContext - nKeep, tokens, 0, nKeep);
        this.nPast = 0;
        this.nContext = nKeep;
    }

    /**
     * Converts a token to text. Bytes that end in an incomplete UTF-8 sequence are held back in
     * {@link #tokenPieceBuffer} and prepended to the next token's piece, so multi-byte characters
     * split across tokens are never emitted half-decoded.
     *
     * @return the decoded text, or {@code ""} while a character is still incomplete
     */
    private String decodeToken(int token) {
        int total = this.nBuffered + writePiece(token);
        byte[] bytes = this.tokenPieceBuffer.delegate.array();
        if (endsWithIncompleteUtf8(bytes, total)) {
            this.nBuffered = total;
            return "";
        }
        this.nBuffered = 0;
        return new String(bytes, 0, total, StandardCharsets.UTF_8);
    }

    /**
     * Writes the token's piece after the pending bytes, growing the buffer until it fits.
     *
     * @return the number of bytes written
     */
    private int writePiece(int token) {
        while (true) {
            int free = this.tokenPieceBuffer.capacity() - this.nBuffered;
            int written = LlamaLibrary.llama_token_to_piece(this.ctx, token, this.tokenPieceBuffer.slice(this.nBuffered, free).delegate, free);
            if (written >= 0 && written <= free) {
                return written;
            }
            // A negative result is the required piece size; at least double to bound the retries.
            int required = this.nBuffered + Math.max(0, -written);
            ByteBuffer grown = ByteBuffer.allocate(Math.max(this.tokenPieceBuffer.capacity() * 2, required));
            System.arraycopy(this.tokenPieceBuffer.delegate.array(), 0, grown.array(), 0, this.nBuffered);
            this.tokenPieceBuffer = new SliceableByteBuffer(grown);
        }
    }

    /** Decodes and clears any bytes still held back by {@link #decodeToken(int)}. */
    private String flushPendingPiece() {
        if (this.nBuffered == 0) {
            return "";
        }
        String pending = new String(this.tokenPieceBuffer.delegate.array(), 0, this.nBuffered, StandardCharsets.UTF_8);
        this.nBuffered = 0;
        return pending;
    }

    /**
     * Returns whether {@code bytes[0, length)} ends in the middle of a UTF-8 encoded character.
     * Malformed tails (no lead byte, invalid lead byte, surplus continuation bytes) return
     * {@code false} so they are emitted rather than buffered forever.
     */
    private static boolean endsWithIncompleteUtf8(byte[] bytes, int length) {
        int index = length - 1;
        int continuationBytes = 0;
        // Walk back over up to three continuation bytes (10xxxxxx) to reach the lead byte.
        while (index >= 0 && continuationBytes < 3 && (bytes[index] & 0xC0) == 0x80) {
            index--;
            continuationBytes++;
        }
        if (index < 0) {
            return false;
        }
        int lead = bytes[index] & 0xFF;
        int expected;
        if (lead < 0x80) {
            expected = 1;
        } else if ((lead & 0xE0) == 0xC0) {
            expected = 2;
        } else if ((lead & 0xF0) == 0xE0) {
            expected = 3;
        } else if ((lead & 0xF8) == 0xF0) {
            expected = 4;
        } else {
            return false;
        }
        return continuationBytes + 1 < expected;
    }

    /** Forwards a message to the installed logger, if any. */
    private void log(LogLevel logLevel, String str) {
        if (logCallback != null) {
            logCallback.accept(logLevel, str);
        }
    }

    // Initializes the llama.cpp backend once per class load. NOTE(review): the argument is
    // presumably the NUMA flag of llama_backend_init (0 = disabled) — confirm against llama.h.
    static {
        LlamaLibrary.llama_backend_init((byte) 0);
    }
}
