package de.kherud.llama;

import com.sun.jna.Memory;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import de.kherud.llama.foreign.LlamaLibrary;
import de.kherud.llama.foreign.NativeSize;
import java.io.File;
import java.io.IOException;
import java.lang.ref.Reference;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:de/kherud/llama/LlamaGrammar.class */
public class LlamaGrammar {
    private final ParseState state;

    /**
     * Native grammar handle returned by {@code llama_grammar_init}.
     *
     * <p>NOTE(review): nothing in this class releases this handle; confirm whether it is freed
     * elsewhere.
     */
    final LlamaLibrary.llama_grammar foreign;

    /**
     * Minimal immutable pair. The parser uses it both for (element type, element value) and for
     * (decoded value, next byte position).
     */
    public static final class Pair<A, B> {
        final A a;
        final B b;

        Pair(A a, B b) {
            this.a = a;
            this.b = b;
        }
    }

    /**
     * Recursive-descent parser for BNF-style grammar text ({@code name ::= alternatives}).
     *
     * <p>Each rule is stored in {@link #rules} at the index of its symbol id, as a flat list of
     * (type, value) pairs terminated by an {@code END} element. All positions, including those in
     * error messages, are byte offsets into the UTF-8 encoding of the grammar text.
     */
    static final class ParseState {
        // Element type codes. The printing logic and error messages below identify these with the
        // LLAMA_GRETYPE_* element types of the native grammar API.
        private static final int END = 0;            // end of a rule definition
        private static final int ALT = 1;            // start of another alternate ("|")
        private static final int RULE_REF = 2;       // reference to a rule, value = symbol id
        private static final int CHAR = 3;           // character, or first char of a class
        private static final int CHAR_NOT = 4;       // first char of a negated class "[^"
        private static final int CHAR_RNG_UPPER = 5; // upper bound of range "a-z"
        private static final int CHAR_ALT = 6;       // further character inside a class

        /** Bytes per native element: two 32-bit ints, type followed by value. */
        private static final int ELEMENT_SIZE = 8;

        /** Byte length of a UTF-8 sequence by the high nibble of its first byte. */
        private static final int[] UTF8_LENGTH_BY_HIGH_NIBBLE =
                {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};

        final Map<String, Integer> symbolIds = new HashMap<>();
        final List<List<Pair<Integer, Integer>>> rules = new ArrayList<>();
        private String string;

        /**
         * Parses the grammar text.
         *
         * @throws RuntimeException on a syntax error or a reference to an undefined rule
         */
        ParseState(String str) {
            parse(str.getBytes(StandardCharsets.UTF_8));
            checkRuleReferences();
        }

        /** Parses rules until the input is exhausted. */
        private void parse(byte[] src) {
            int pos = parseSpace(src, 0, true);
            while (pos < src.length) {
                pos = parseRule(src, pos);
            }
        }

        /**
         * Verifies that every rule reference names a rule that was actually defined. Otherwise an
         * undefined symbol would leave a {@code null} hole in {@link #rules}, or an index past its
         * end, and that would be handed to native code.
         */
        private void checkRuleReferences() {
            for (List<Pair<Integer, Integer>> rule : rules) {
                if (rule == null) {
                    continue; // holes themselves are rejected in create()
                }
                for (Pair<Integer, Integer> element : rule) {
                    if (element.a.intValue() == RULE_REF) {
                        int target = element.b.intValue();
                        if (target >= rules.size() || rules.get(target) == null) {
                            throw new RuntimeException(
                                    "Undefined rule identifier '" + symbolName(target) + "'");
                        }
                    }
                }
            }
        }

        /** Reverse lookup of a symbol id; falls back to the numeric id if no name matches. */
        private String symbolName(int id) {
            for (Map.Entry<String, Integer> entry : symbolIds.entrySet()) {
                if (entry.getValue().intValue() == id) {
                    return entry.getKey();
                }
            }
            return String.valueOf(id);
        }

        /**
         * Copies the parsed rules into native memory and initializes the native grammar.
         *
         * @return the handle returned by {@code llama_grammar_init}
         * @throws RuntimeException if no {@code root} rule exists or a rule slot is undefined
         */
        private LlamaLibrary.llama_grammar create() {
            Integer rootId = symbolIds.get("root");
            if (rootId == null) {
                throw new RuntimeException("Grammar does not define a 'root' rule");
            }
            int ruleCount = rules.size();
            List<Memory> ruleBlocks = new ArrayList<>(ruleCount);
            Memory rulePointers = new Memory((long) Native.POINTER_SIZE * ruleCount);
            for (int r = 0; r < ruleCount; r++) {
                List<Pair<Integer, Integer>> rule = rules.get(r);
                if (rule == null || rule.isEmpty()) {
                    // Skipping a slot would shift every later rule and break RULE_REF indices.
                    throw new RuntimeException("Rule " + r + " is not defined");
                }
                Memory block = new Memory((long) rule.size() * ELEMENT_SIZE);
                for (int e = 0; e < rule.size(); e++) {
                    Pair<Integer, Integer> element = rule.get(e);
                    block.setInt((long) e * ELEMENT_SIZE, element.a.intValue());
                    block.setInt((long) e * ELEMENT_SIZE + 4, element.b.intValue());
                }
                ruleBlocks.add(block);
                rulePointers.setPointer((long) r * Native.POINTER_SIZE, block);
            }
            try {
                return LlamaLibrary.llama_grammar_init(
                        rulePointers, new NativeSize(ruleCount), new NativeSize(rootId.intValue()));
            } finally {
                // JNA frees a Memory's native buffer when the object is garbage collected, and the
                // per-rule blocks are referenced only from this list, so keep them reachable until
                // the native call returns.
                // NOTE(review): assumes llama_grammar_init copies the rules; if it retains these
                // pointers, the blocks must live as long as `foreign` instead.
                Reference.reachabilityFence(ruleBlocks);
            }
        }

        /** Returns the unsigned byte at {@code pos}, or -1 past the end of input. */
        private static int peek(byte[] src, int pos) {
            return pos < src.length ? src[pos] & 0xFF : -1;
        }

        /**
         * Skips spaces, tabs and {@code #} comments (and CR/LF when {@code newlineOk}).
         *
         * @return the position of the first byte not skipped
         */
        private int parseSpace(byte[] src, int pos, boolean newlineOk) {
            while (pos < src.length) {
                byte c = src[pos];
                if (c == '#') {
                    // Comment runs to end of line; the line break is handled by the next iteration.
                    while (pos < src.length && src[pos] != '\r' && src[pos] != '\n') {
                        pos++;
                    }
                } else if (c == ' ' || c == '\t' || (newlineOk && (c == '\r' || c == '\n'))) {
                    pos++;
                } else {
                    break;
                }
            }
            return pos;
        }

        /**
         * Parses one {@code name ::= alternates} rule plus its terminating line break.
         *
         * @return the position after the rule and any following whitespace
         */
        private int parseRule(byte[] src, int start) {
            int nameEnd = parseName(src, start);
            int pos = parseSpace(src, nameEnd, false);
            int nameLength = nameEnd - start;
            int ruleId = getSymbolId(src, start, nameLength);
            String name = new String(src, start, nameLength, StandardCharsets.UTF_8);
            if (pos + 2 >= src.length || src[pos] != ':' || src[pos + 1] != ':' || src[pos + 2] != '=') {
                throw new RuntimeException("Expecting ::= at position " + pos);
            }
            pos = parseAlternates(src, parseSpace(src, pos + 3, true), name, ruleId, false);
            if (pos < src.length) {
                if (src[pos] == '\r') {
                    pos += (pos + 1 < src.length && src[pos + 1] == '\n') ? 2 : 1;
                } else if (src[pos] == '\n') {
                    pos++;
                } else {
                    throw new RuntimeException("Expecting newline or end at position " + pos);
                }
            }
            return parseSpace(src, pos, true);
        }

        /** Returns the end of the (non-empty) run of word characters starting at {@code start}. */
        private int parseName(byte[] src, int start) {
            int pos = start;
            while (pos < src.length && isWordChar(src[pos])) {
                pos++;
            }
            if (pos == start) {
                throw new RuntimeException("Expecting name at position " + start);
            }
            return pos;
        }

        /** Word characters are ASCII letters, digits and '-'. */
        private boolean isWordChar(byte c) {
            return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '-' || (c >= '0' && c <= '9');
        }

        /** Returns the id for the named symbol, assigning the next free id on first use. */
        private int getSymbolId(byte[] src, int start, int length) {
            String name = new String(src, start, length, StandardCharsets.UTF_8);
            Integer id = symbolIds.get(name);
            if (id == null) {
                id = symbolIds.size();
                symbolIds.put(name, id);
            }
            return id.intValue();
        }

        /**
         * Parses {@code seq ("|" seq)*} into a new rule stored under {@code ruleId}.
         *
         * @return the position of the first byte that does not belong to the alternates
         */
        private int parseAlternates(byte[] src, int pos, String ruleName, int ruleId, boolean nested) {
            List<Pair<Integer, Integer>> rule = new ArrayList<>();
            pos = parseSequence(ruleName, src, pos, rule, nested);
            while (pos < src.length && src[pos] == '|') {
                rule.add(new Pair<>(ALT, 0));
                pos = parseSequence(ruleName, src, parseSpace(src, pos + 1, true), rule, nested);
            }
            rule.add(new Pair<>(END, 0));
            addRule(ruleId, rule);
            return pos;
        }

        /** Stores {@code rule} at index {@code id}, padding {@link #rules} with nulls as needed. */
        private void addRule(int id, List<Pair<Integer, Integer>> rule) {
            while (rules.size() <= id) {
                rules.add(null);
            }
            rules.set(id, rule);
        }

        /**
         * Parses a sequence of items (string literals, character classes, rule references,
         * parenthesized groups, each optionally followed by {@code * + ?}) and appends their
         * elements to {@code out}.
         *
         * @return the position of the first byte that does not start an item
         */
        private int parseSequence(
                String ruleName, byte[] src, int start, List<Pair<Integer, Integer>> out, boolean nested) {
            // Index in `out` where the most recent item starts; the target of * + ?.
            int lastSymStart = out.size();
            int pos = start;
            while (pos < src.length) {
                byte c = src[pos];
                if (c == '"') {
                    // Literal string: one CHAR element per character.
                    pos++;
                    lastSymStart = out.size();
                    while (peek(src, pos) != '"') {
                        Pair<Integer, Integer> ch = parseChar(src, pos);
                        pos = ch.b;
                        out.add(new Pair<>(CHAR, ch.a));
                    }
                    pos = parseSpace(src, pos + 1, nested);
                } else if (c == '[') {
                    // Character class: first char is CHAR/CHAR_NOT, the rest CHAR_ALT, ranges add
                    // a CHAR_RNG_UPPER after their lower bound.
                    pos++;
                    int startType = CHAR;
                    if (peek(src, pos) == '^') {
                        pos++;
                        startType = CHAR_NOT;
                    }
                    lastSymStart = out.size();
                    while (peek(src, pos) != ']') {
                        Pair<Integer, Integer> ch = parseChar(src, pos);
                        pos = ch.b;
                        int type = lastSymStart < out.size() ? CHAR_ALT : startType;
                        out.add(new Pair<>(type, ch.a));
                        if (peek(src, pos) == '-' && peek(src, pos + 1) != ']') {
                            Pair<Integer, Integer> upper = parseChar(src, pos + 1);
                            pos = upper.b;
                            out.add(new Pair<>(CHAR_RNG_UPPER, upper.a));
                        }
                    }
                    pos = parseSpace(src, pos + 1, nested);
                } else if (isWordChar(c)) {
                    // Reference to a (possibly not yet defined) rule.
                    int nameEnd = parseName(src, pos);
                    int refId = getSymbolId(src, pos, nameEnd - pos);
                    pos = parseSpace(src, nameEnd, nested);
                    lastSymStart = out.size();
                    out.add(new Pair<>(RULE_REF, refId));
                } else if (c == '(') {
                    // Group: parse nested alternates into a synthesized rule and reference it.
                    pos = parseSpace(src, pos + 1, true);
                    int subRuleId = generateSymbolId(ruleName);
                    pos = parseAlternates(src, pos, ruleName, subRuleId, true);
                    lastSymStart = out.size();
                    out.add(new Pair<>(RULE_REF, subRuleId));
                    if (peek(src, pos) != ')') {
                        throw new RuntimeException("expecting ')' at " + pos);
                    }
                    pos = parseSpace(src, pos + 1, nested);
                } else if (c == '*' || c == '+' || c == '?') {
                    if (lastSymStart == out.size()) {
                        throw new RuntimeException("expecting preceding item to */+/? at " + pos);
                    }
                    // Rewrite the preceding item S into a synthesized rule S':
                    //   S* -> S' ::= S S' |      S+ -> S' ::= S S' | S      S? -> S' ::= S |
                    int subRuleId = generateSymbolId(ruleName);
                    List<Pair<Integer, Integer>> lastSymbol = out.subList(lastSymStart, out.size());
                    List<Pair<Integer, Integer>> subRule = new ArrayList<>(lastSymbol);
                    if (c == '*' || c == '+') {
                        subRule.add(new Pair<>(RULE_REF, subRuleId));
                    }
                    subRule.add(new Pair<>(ALT, 0));
                    if (c == '+') {
                        subRule.addAll(lastSymbol);
                    }
                    subRule.add(new Pair<>(END, 0));
                    addRule(subRuleId, subRule);
                    // Replace S in the current sequence with a reference to S'.
                    lastSymbol.clear();
                    out.add(new Pair<>(RULE_REF, subRuleId));
                    pos = parseSpace(src, pos + 1, nested);
                } else {
                    break;
                }
            }
            return pos;
        }

        /**
         * Parses one (possibly escaped) character.
         *
         * @return (code point, position after the character)
         * @throws RuntimeException at end of input or on an unknown escape
         */
        private Pair<Integer, Integer> parseChar(byte[] src, int pos) {
            if (pos >= src.length) {
                throw new RuntimeException("Unexpected end of input at position " + pos);
            }
            if (src[pos] != '\\') {
                return decodeUtf8(src, pos);
            }
            int escaped = peek(src, pos + 1);
            switch (escaped) {
                case '"':
                case '[':
                case '\\':
                case ']':
                    return new Pair<>(escaped, pos + 2);
                case 'U':
                    return parseHex(src, pos + 2, 8);
                case 'n':
                    return new Pair<>((int) '\n', pos + 2);
                case 'r':
                    return new Pair<>((int) '\r', pos + 2);
                case 't':
                    return new Pair<>((int) '\t', pos + 2);
                case 'u':
                    return parseHex(src, pos + 2, 4);
                case 'x':
                    return parseHex(src, pos + 2, 2);
                default:
                    throw new RuntimeException("Unknown escape at " + pos);
            }
        }

        /**
         * Parses exactly {@code digits} hexadecimal characters starting at {@code start}.
         *
         * @return (value, position after the digits)
         */
        private Pair<Integer, Integer> parseHex(byte[] src, int start, int digits) {
            int end = start + digits;
            int pos = start;
            int value = 0;
            while (pos < end && pos < src.length && src[pos] != 0) {
                byte c = src[pos];
                int digit;
                if (c >= 'a' && c <= 'f') {
                    digit = c - 'a' + 10;
                } else if (c >= 'A' && c <= 'F') {
                    digit = c - 'A' + 10;
                } else if (c >= '0' && c <= '9') {
                    digit = c - '0';
                } else {
                    break;
                }
                value = (value << 4) + digit;
                pos++;
            }
            if (pos != end) {
                throw new RuntimeException("Expecting " + digits + " hex chars at " + pos);
            }
            return new Pair<>(value, pos);
        }

        /**
         * Decodes one UTF-8 sequence starting at {@code start}, stopping early at end of input or
         * a NUL byte.
         *
         * @return (code point, position after the sequence)
         */
        private Pair<Integer, Integer> decodeUtf8(byte[] src, int start) {
            // Mask to unsigned: a signed byte >= 0x80 would otherwise yield a negative index.
            int first = src[start] & 0xFF;
            int length = UTF8_LENGTH_BY_HIGH_NIBBLE[first >> 4];
            int value = first & ((1 << (8 - length)) - 1);
            int end = start + length;
            int pos = start + 1;
            while (pos < end && pos < src.length && src[pos] != 0) {
                value = (value << 6) + (src[pos] & 0x3F);
                pos++;
            }
            return new Pair<>(value, pos);
        }

        /** Registers a fresh synthesized symbol named {@code base_<id>} and returns its id. */
        private int generateSymbolId(String base) {
            int id = symbolIds.size();
            symbolIds.put(base + "_" + id, id);
            return id;
        }

        /** Renders every rule as {@code index: name ::= elements}, one per line; cached. */
        @Override
        public String toString() {
            if (string == null) {
                Map<Integer, String> names = new HashMap<>();
                for (Map.Entry<String, Integer> entry : symbolIds.entrySet()) {
                    names.put(entry.getValue(), entry.getKey());
                }
                StringBuilder sb = new StringBuilder();
                for (int i = 0; i < rules.size(); i++) {
                    sb.append(i).append(": ");
                    appendRule(sb, names, i, rules.get(i));
                }
                string = sb.toString();
            }
            return string;
        }

        /** Appends a human-readable form of one rule, followed by a newline. */
        private void appendRule(
                StringBuilder sb, Map<Integer, String> names, int ruleId, List<Pair<Integer, Integer>> rule) {
            if (rule == null || rule.isEmpty() || rule.get(rule.size() - 1).a.intValue() != END) {
                throw new RuntimeException("Malformed rule, does not end with LLAMA_GRETYPE_END: " + ruleId);
            }
            sb.append(names.get(ruleId)).append(" ::= ");
            int last = rule.size() - 1;
            for (int i = 0; i < last; i++) {
                Pair<Integer, Integer> element = rule.get(i);
                int value = element.b.intValue();
                switch (element.a.intValue()) {
                    case END:
                        throw new RuntimeException("Unexpected end of rule: " + ruleId + "," + i);
                    case ALT:
                        sb.append("| ");
                        break;
                    case RULE_REF:
                        sb.append(names.get(value)).append(" ");
                        break;
                    case CHAR:
                        sb.append("[");
                        printGrammarChar(sb, value);
                        break;
                    case CHAR_NOT:
                        sb.append("[^");
                        printGrammarChar(sb, value);
                        break;
                    case CHAR_RNG_UPPER:
                        if (i == 0 || !isCharElement(rule.get(i - 1))) {
                            throw new RuntimeException(
                                    "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + ruleId + "," + i);
                        }
                        sb.append("-");
                        printGrammarChar(sb, value);
                        break;
                    case CHAR_ALT:
                        if (i == 0 || !isCharElement(rule.get(i - 1))) {
                            throw new RuntimeException(
                                    "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + ruleId + "," + i);
                        }
                        printGrammarChar(sb, value);
                        break;
                    default:
                        throw new RuntimeException("Unknown type: " + element.a);
                }
                // Close the bracket unless the next element's TYPE continues this character class.
                if (isCharElement(element)) {
                    int nextType = rule.get(i + 1).a.intValue();
                    if (nextType != CHAR_ALT && nextType != CHAR_RNG_UPPER) {
                        sb.append("] ");
                    }
                }
            }
            sb.append("\n");
        }

        /** Appends printable ASCII (0x20-0x7F) as-is and everything else as {@code <U+XXXX>}. */
        private void printGrammarChar(StringBuilder sb, int c) {
            if (c < 32 || c > 127) {
                sb.append(String.format("<U+%04X>", c));
            } else {
                sb.append((char) c);
            }
        }

        /** True for the element types that appear inside a character class. */
        private boolean isCharElement(Pair<Integer, Integer> element) {
            int type = element.a.intValue();
            return type == CHAR || type == CHAR_NOT || type == CHAR_RNG_UPPER || type == CHAR_ALT;
        }
    }

    /**
     * Reads and parses a grammar file (UTF-8).
     *
     * @throws IOException if the file cannot be read
     */
    public LlamaGrammar(File file) throws IOException {
        this(file.toPath());
    }

    /**
     * Reads and parses a grammar file (UTF-8).
     *
     * @throws IOException if the file cannot be read
     */
    public LlamaGrammar(Path path) throws IOException {
        this(Files.readString(path, StandardCharsets.UTF_8));
    }

    /**
     * Parses grammar text and initializes the native grammar.
     *
     * @throws RuntimeException on a syntax error, an undefined rule reference, or a missing
     *     {@code root} rule
     */
    public LlamaGrammar(String str) {
        this.state = new ParseState(str);
        this.foreign = this.state.create();
    }

    /** Returns a human-readable listing of the parsed rules. */
    @Override
    public String toString() {
        return this.state.toString();
    }
}
