"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""

# flake8: noqa
from pathlib import Path
import sys
from ctypes import *  # type: ignore
from enum import Enum
from itertools import islice, groupby
from typing import (
    Any,
    Callable,
    Dict,
    Set,
    Generic,
    List,
    Optional,
    OrderedDict,
    TextIO,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import llama_cpp.llama_cpp as llama_cpp

# Type aliases
# Short module-level names for the grammar types exposed by the low-level
# ``llama_cpp`` bindings.
llama_grammar_element = llama_cpp.llama_grammar_element
llama_grammar_element_p = llama_cpp.llama_grammar_element_p
llama_grammar_p = llama_cpp.llama_grammar_p

# Type variables
# ``Ptr`` is bound to ``const_char_p`` (forward reference by name) so that the
# pointer-arithmetic overloads can express "returns the same pointer type".
Ptr = TypeVar("Ptr", bound="const_char_p")
# Generic parameters for the std.vector / std.map emulation classes.
T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")
W = TypeVar("W")


class Sentinel:
    """Marks the past-the-end position of a ``std.map.iterator``.

    ``std.map.end()`` builds an iterator whose key is a ``Sentinel``
    instance, and an iterator switches its key to one once it is exhausted.
    """


class LlamaGrammar:
    """Owns a native llama grammar built from a parsed GBNF grammar.

    Keeps references to all the ctypes arrays handed to
    ``llama_grammar_init`` so that they are not garbage collected by Python
    while the native grammar exists.
    """

    def __del__(self) -> None:
        """Free the grammar pointer when the object is deleted."""
        # ``getattr`` keeps finalization safe when ``__init__`` raised before
        # ``self.grammar`` was assigned (e.g. the grammar had no "root" rule).
        grammar = getattr(self, "grammar", None)
        if grammar is not None:
            llama_cpp.llama_grammar_free(grammar)
            self.grammar = None

    def __init__(
        self,
        parsed_grammar: "parse_state",
    ) -> None:
        """Initialize the grammar pointer from the parsed state.

        Raises:
            KeyError: if the parsed grammar defines no ``root`` rule.
        """
        self._grammar_rules = (
            parsed_grammar.c_rules()
        )  # type: std.vector[std.vector[LlamaGrammarElement]]
        self._n_rules = self._grammar_rules.size()  # type: int
        self._start_rule_index = parsed_grammar.symbol_ids.at("root")  # type: int
        self.init()

    @classmethod
    def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
        """Convert a GBNF grammar to a Llama grammar.

        When *verbose* is true the parsed grammar is printed to stderr.

        Raises:
            ValueError: if parsing produced no rules.
        """
        parsed_grammar = parse(const_char_p(grammar))  # type: parse_state
        if parsed_grammar.rules.empty():
            raise ValueError(
                f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty"
            )
        if verbose:
            print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
            print_grammar(sys.stderr, parsed_grammar)
            print(file=sys.stderr)
        return cls(parsed_grammar)

    @classmethod
    def from_json_schema(
        cls,
        json_schema: str,
        verbose: bool = True,
    ) -> "LlamaGrammar":
        """Convert a JSON schema to a Llama grammar."""
        return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose)

    @classmethod
    def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
        """Read a GBNF grammar file (UTF-8) and convert it to a Llama grammar.

        Raises:
            Exception: if the file cannot be read; the original error is
                chained as ``__cause__``.
            ValueError: if the file is empty.
        """
        try:
            # GBNF literals are Unicode; decode explicitly rather than with the
            # platform's locale encoding so code points are the same everywhere.
            with open(file, encoding="utf-8") as f:
                grammar = f.read()
        except Exception as err:
            raise Exception(
                f"{cls.from_file.__name__}: error reading grammar file: {err}"
            ) from err

        if grammar:
            return cls.from_string(grammar, verbose=verbose)

        raise ValueError(
            f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
        )

    def init(self) -> None:
        """Build the ctypes rule arrays and create the native grammar.

        Every intermediate ctypes object is stored on ``self`` so it stays
        alive as long as this wrapper does.
        """
        # Step 1: Convert LlamaGrammarElement to llama_grammar_element
        self._element_lists = [
            [
                llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value))
                for elem in subvector
            ]
            for subvector in self._grammar_rules
        ]  # type: List[List[llama_grammar_element]]

        # Step 2: Convert each list to llama_grammar_element array and get pointer
        self._element_arrays = [
            (llama_grammar_element * len(sublist))(*sublist)
            for sublist in self._element_lists
        ]  # type: List[Array[llama_grammar_element]]

        # Step 3: Get pointer of each array
        self._element_array_pointers = [
            cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays
        ]  # type: List[llama_grammar_element_p]

        # Step 4: Make array of these pointers and get its pointer
        self._rules = (llama_grammar_element_p * len(self._element_array_pointers))(
            *self._element_array_pointers
        )
        self.grammar = llama_cpp.llama_grammar_init(
            self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index)
        )

    def reset(self) -> None:
        """Free the current native grammar (if any) and rebuild it from the stored rules."""
        if self.grammar is not None:
            llama_cpp.llama_grammar_free(self.grammar)
            # Clear the handle before rebuilding so that a failure inside
            # ``init`` cannot leave a freed pointer for ``__del__`` to free again.
            self.grammar = None
        self.init()


class LlamaGrammarElement:
    """Python-side grammar element: a (type, value) pair.

    Converted into a ctypes ``llama_grammar_element`` when the native grammar
    is built.
    """

    def __init__(self, type: "llama_gretype", value: int):
        self.type = type  # element kind (a llama_gretype member)
        self.value = value  # Unicode code point or rule ID

    def __repr__(self) -> str:
        # A readable repr makes dumps of parsed rule vectors debuggable.
        return f"{self.__class__.__name__}(type={self.type!r}, value={self.value!r})"


class const_char_p:
    """C++ implementation of const char *.

    Models a read-only pointer into a Python string: ``value`` is the whole
    string and ``pos`` the current offset. Reading at or past the end yields
    ``""``, which stands in for the C string's NUL terminator.
    """

    def __init__(self, value: Union[str, Ptr], move: Optional[int] = None):
        if isinstance(value, const_char_p):
            # We're copying an existing const_char_p
            self.value = value.value
            self.pos = value.pos + (move or 0)
            return

        # We're creating a new const_char_p
        self.value = value
        self.pos = move or 0

    def __str__(self) -> str:
        """Return the remaining text from ``pos`` to the end."""
        assert self.value is not None, "null pointer"
        return self.value[self.pos :]

    def __getitem__(self, index: int) -> str:
        """Return the character ``index`` places past ``pos`` ("" at/after the end).

        Non-negative offsets index ``value`` directly. Slicing off the whole
        remaining suffix first made every character read O(n), and therefore
        whole-grammar parsing O(n^2).
        """
        assert self.value is not None, "null pointer"
        if index >= 0 and self.pos >= 0:
            absolute = self.pos + index
            return self.value[absolute] if absolute < len(self.value) else ""
        # Negative offsets keep the original suffix-relative semantics.
        value = str(self)
        return value[index] if index < len(value) else ""

    @overload
    def __add__(self: Ptr, other: int) -> Ptr:
        ...

    @overload
    def __add__(self: Ptr, other: Ptr) -> int:
        ...

    def __add__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]:
        """Advance by an int (new pointer) or sum two offsets (int)."""
        return (
            self.__class__(self.value, self.pos + other)
            if isinstance(other, int)
            else self.pos + other.pos
        )

    @overload
    def __sub__(self: Ptr, other: int) -> Ptr:
        ...

    @overload
    def __sub__(self: Ptr, other: Ptr) -> int:
        ...

    def __sub__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]:
        """Move back by an int (new pointer) or return the distance between two pointers."""
        return (
            self.__class__(self.value, self.pos - other)
            if isinstance(other, int)
            else self.pos - other.pos
        )

    def __eq__(self: Ptr, other: Ptr) -> bool:
        assert self.value == other.value, "comparing pointers from different strings"
        return self.pos == other.pos

    def __lt__(self: Ptr, other: Ptr) -> bool:
        assert self.value == other.value, "comparing pointers from different strings"
        return self.pos < other.pos

    def __gt__(self: Ptr, other: Ptr) -> bool:
        assert self.value == other.value, "comparing pointers from different strings"
        return self.pos > other.pos


class std:
    """Minimal Python emulations of the C++ standard-library pieces the parser uses."""

    @staticmethod
    def string(ptr: const_char_p, length: Optional[int] = None) -> str:
        """C++ implementation of std::string constructor.

        Returns the text from *ptr* to the end of its string, truncated to
        *length* characters when *length* is given.
        """
        value = str(ptr)
        if length is not None:
            value = value[:length]
        return value

    class vector(Generic[T], List[T]):
        """C++ implementation of std::vector.

        Every mutating helper bumps ``_version`` so that iterators created
        before the change raise instead of silently reading stale positions.
        """

        class iterator:
            """Index-based iterator into a ``std.vector``."""

            def __init__(self, vector: "std.vector[T]", index: int):
                self._vector = vector
                self._index = index
                self._version = vector._version

            def _check_version(self):
                if self._version != self._vector._version:
                    raise RuntimeError("Iterator used after vector was modified.")

            def __iter__(self):
                return self

            def __next__(self) -> T:
                self._check_version()
                if self._index >= self._vector.size():
                    raise StopIteration
                value = self._vector[self._index]
                self._index += 1
                return value

            def __add__(self, value: int) -> "std.vector[T].iterator":
                return self.__class__(self._vector, self._index + value)

            def __sub__(self, value: int) -> "std.vector[T].iterator":
                return self.__class__(self._vector, self._index - value)

        def __init__(self):
            self._version = 0

        def modify(self):
            # This is a bit of a hack to make sure iterators are invalidated
            self._version += 1

        def push_back(self, value: T) -> None:
            self.modify()
            self.append(value)

        def pop_back(self) -> None:
            self.modify()
            if not self.empty():
                self.pop()

        def back(self) -> T:
            return self[-1]

        def size(self) -> int:
            return len(self)

        def clear(self) -> None:
            self.modify()
            super().clear()

        def empty(self) -> bool:
            return self.size() == 0

        def data(self) -> "std.vector[T]":
            return self

        def resize(
            self,
            new_size: int,
            fill_value_factory: Optional[Callable[[], T]] = None,
        ) -> None:
            """Grow (with fresh ``fill_value_factory()`` values) or truncate to *new_size*.

            Raises:
                ValueError: when growing without a ``fill_value_factory``.
            """
            if new_size > self.size():
                if fill_value_factory is None:
                    raise ValueError("A fill value factory function must be provided.")
                self.reserve(new_size, fill_value_factory)
            elif new_size < self.size():
                # Truncation is a mutation too: invalidate outstanding iterators.
                self.modify()
                del self[new_size:]

        def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None:
            """Extend to *capacity* elements, each a fresh ``fill_value_factory()``.

            Unlike C++ ``reserve`` this changes ``size()``; ``resize`` relies on that.
            """
            missing = capacity - self.size()
            if missing > 0:
                self.modify()
                # Call the factory once per slot (as C++ value-initialization
                # does) rather than repeating one object, which aliased every
                # new slot to the same, possibly mutable, value.
                self.extend(fill_value_factory() for _ in range(missing))

        def front(self) -> T:
            if not self.empty():
                return self[0]
            else:
                raise IndexError("Vector is empty.")

        def assign(self, count: int, value: T) -> None:
            self.clear()
            self.extend([value] * count)

        def insert(
            self,
            pos: "std.vector[T].iterator",
            first: "std.vector[T].iterator",
            last: "std.vector[T].iterator",
        ) -> None:
            """Insert the range [first, last) before *pos*."""
            # Materialize the source range first so inserting a vector's own
            # elements into itself is safe.
            items = list(islice(first._vector, first._index, last._index))
            self.modify()
            self[pos._index : pos._index] = items

        def begin(self) -> "std.vector[T].iterator":
            return self.iterator(self, 0)

        def end(self) -> "std.vector[T].iterator":
            return self.iterator(self, self.size())

    class map(Generic[T, U], OrderedDict[T, U]):
        """C++ implementation of std::map."""

        class iterator(Generic[V, W]):
            """Iterator positioned at a key of a ``std.map`` (or at end via ``Sentinel``)."""

            def __init__(self, _map: "std.map[T, U]", key: Union[T, Sentinel]):
                self._map = _map
                self.iter = iter(_map)
                self.key = key
                self._advance()

            def _sanitize_key(self) -> T:
                if isinstance(self.key, Sentinel):
                    raise StopIteration
                return self.key

            def _advance(self) -> None:
                """Position ``self.iter`` just past ``self.key`` (or mark the end)."""
                try:
                    while next(self.iter) != self.key:
                        pass
                except StopIteration:
                    self.key = Sentinel()

            def __next__(self) -> Tuple[T, U]:
                key = self._sanitize_key()
                if key not in self._map:
                    raise StopIteration
                value = self._map[key]
                # ``self.iter`` already sits just past ``key``; step to the
                # following key. Re-running ``_advance`` searched for the
                # consumed key again and so always ended after one element.
                self.key = next(self.iter, Sentinel())
                return key, value

            def get(self) -> Tuple[T, U]:
                key = self._sanitize_key()
                return key, self._map[key]

            @property
            def first(self) -> T:
                return self._sanitize_key()

            @property
            def second(self) -> U:
                return self._map[self._sanitize_key()]

        def insert(
            self, key: T, value: U
        ) -> Tuple["std.map[T, U].iterator[T, U]", bool]:
            """Insert *key* -> *value* unless present; return (iterator to key, inserted?)."""
            if key in self:
                return self.iterator(self, key), False
            else:
                self[key] = value
                return self.iterator(self, key), True

        def find(self, key: T) -> "std.map[T, U].iterator[T, U]":
            if key in self:
                return self.iterator(self, key)
            else:
                return self.end()

        def at(self, key: T) -> U:
            if key in self:
                return self[key]
            else:
                raise KeyError("The provided key is not found in the map.")

        def erase(self, iterator: "std.map[T, U].iterator[T, U]") -> None:
            key = iterator.first
            if key in self:
                del self[key]

        def size(self) -> int:
            return len(self)

        def empty(self) -> bool:
            return self.size() == 0

        def lower_bound(self, key: T) -> "std.map[T, U].iterator[T, U]":
            """Return an iterator to the smallest key not less than *key*.

            Raises:
                ValueError: if every key is less than *key*.
                TypeError: if the keys cannot be ordered.
            """
            try:
                keys = sorted(self.keys())  # type: ignore
                for k in keys:
                    if k >= key:
                        return self.iterator(self, k)
                raise ValueError("No key found that is not less than the input key")
            except TypeError as err:
                raise TypeError("Keys of type T cannot be sorted.") from err

        def begin(self) -> "std.map[T, U].iterator[T, U]":
            # As in C++, begin() of an empty map equals end() (previously a
            # bare StopIteration escaped from next()).
            if self.empty():
                return self.end()
            return self.iterator(self, next(iter(self)))

        def end(self) -> "std.map[T, U].iterator[T, U]":
            return self.iterator(self, Sentinel())


# // grammar element type
# enum llama_gretype {
#     // end of rule definition
#     LLAMA_GRETYPE_END            = 0,

#     // start of alternate definition for rule
#     LLAMA_GRETYPE_ALT            = 1,

#     // non-terminal element: reference to rule
#     LLAMA_GRETYPE_RULE_REF       = 2,

#     // terminal element: character (code point)
#     LLAMA_GRETYPE_CHAR           = 3,

#     // inverse char(s) ([^a], [^a-b] [^abc])
#     LLAMA_GRETYPE_CHAR_NOT       = 4,

#     // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
#     // be an inclusive range ([a-z])
#     LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,


#     // modifies a preceding LLAMA_GRETYPE_CHAR or
#     // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
#     LLAMA_GRETYPE_CHAR_ALT       = 6,
# };
class llama_gretype(Enum):
    """grammar element type

    Python mirror of the C ``enum llama_gretype``. The integer values are
    passed to the native library when the grammar is built, so they must stay
    identical to the C definition.
    """

    LLAMA_GRETYPE_END = 0  # end of rule definition
    LLAMA_GRETYPE_ALT = 1  # start of alternate definition for rule
    LLAMA_GRETYPE_RULE_REF = 2  # non-terminal element: reference to rule
    LLAMA_GRETYPE_CHAR = 3  # terminal element: character (code point)
    LLAMA_GRETYPE_CHAR_NOT = 4  # inverse char(s) ([^a], [^a-b] [^abc])
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5  # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to be an inclusive range ([a-z])
    LLAMA_GRETYPE_CHAR_ALT = 6  # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])


# struct parse_state {
#     std::map<std::string, uint32_t>                 symbol_ids;
#     std::vector<std::vector<llama_grammar_element>> rules;
#     std::vector<const llama_grammar_element *> c_rules();
# };
class parse_state:
    """State accumulated while parsing a GBNF grammar.

    ``symbol_ids`` maps each rule name to its numeric ID; ``rules`` holds the
    element sequence of each rule, indexed by rule ID.
    """

    def __init__(self):
        self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector()
        self.symbol_ids: std.map[str, int] = std.map()

    # std::vector<const llama_grammar_element *> parse_state::c_rules() {
    #     std::vector<const llama_grammar_element *> ret;
    #     for (const auto & rule : rules) {
    #         ret.push_back(rule.data());
    #     }
    #     return ret;
    # }
    def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]:
        """Return a new vector holding each rule's ``data()`` in rule-ID order."""
        collected = std.vector()  # type: std.vector[std.vector[LlamaGrammarElement]]
        for rule_data in (each_rule.data() for each_rule in self.rules):
            collected.push_back(rule_data)
        return collected

    def __repr__(self) -> str:
        n_symbols = len(self.symbol_ids)
        n_rules = len(self.rules)
        return f"parse_state(symbol_ids={n_symbols}, rules={n_rules})"


# struct llama_grammar {
#     const std::vector<std::vector<llama_grammar_element>>   rules;
#     std::vector<std::vector<const llama_grammar_element *>> stacks;
# };
# class llama_grammar:
#     def __init__(
#         self,
#         rules: std.vector[std.vector[llama_grammar_element]],
#         stacks: std.vector[std.vector[llama_grammar_element]],
#     ):
#         self.rules = rules
#         self.stacks = stacks


# uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
#     uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
#     auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
#     return result.first->second;
# }
def get_symbol_id(state: parse_state, src: const_char_p, len: int) -> int:
    """Return the ID of the symbol named by the first *len* characters at *src*.

    An unseen name is registered with the next free ID (the current number of
    symbols); an already-known name keeps its existing ID.
    """
    candidate_id = state.symbol_ids.size()  # type: int
    symbol_name = std.string(src, len)
    position, _was_inserted = state.symbol_ids.insert(symbol_name, candidate_id)
    return position.second  # type: ignore


# uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
#     uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
#     state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
#     return next_id;
# }
def generate_symbol_id(state: parse_state, base_name: str) -> int:
    """Register a synthesized symbol ``<base_name>_<id>`` and return its new ID.

    The ID is the number of symbols known before the call.
    """
    fresh_id = state.symbol_ids.size()  # type: int
    synthesized_name = base_name + "_" + str(fresh_id)
    state.symbol_ids[synthesized_name] = fresh_id
    return fresh_id


# void add_rule(
#         parse_state & state,
#         uint32_t      rule_id,
#         const std::vector<llama_grammar_element> & rule) {
#     if (state.rules.size() <= rule_id) {
#         state.rules.resize(rule_id + 1);
#     }
#     state.rules[rule_id] = rule;
# }
def add_rule(
    state: parse_state,
    rule_id: int,
    rule: std.vector[LlamaGrammarElement],
) -> None:
    """Store *rule* as the definition of *rule_id*.

    ``state.rules`` is first grown with empty vectors if it does not yet have
    a slot for *rule_id*.
    """
    required_size = rule_id + 1
    if state.rules.size() < required_size:
        state.rules.resize(
            required_size,
            fill_value_factory=std.vector[LlamaGrammarElement],
        )
    state.rules[rule_id] = rule


# std::pair<uint32_t, const char *> decode_utf8(const char * src) {
#     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
#     uint8_t  first_byte = static_cast<uint8_t>(*src);
#     uint8_t  highbits   = first_byte >> 4;
#     int      len        = lookup[highbits];
#     uint8_t  mask       = (1 << (8 - len)) - 1;
#     uint32_t value      = first_byte & mask;
#     const char * end    = src + len; // may overrun!
#     const char * pos    = src + 1;
#     for ( ; pos < end && *pos; pos++) {
#         value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
#     }
#     return std::make_pair(value, pos);
# }
def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]:
    """Return (code point of the character at *src*, pointer just past it).

    Python strings are already decoded, so unlike the C++ original there is
    no multi-byte sequence to assemble: each character is one code point.
    """
    code_point = ord(src[0])
    return code_point, src + 1


# bool is_word_char(char c) {
#     return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
# }
def is_word_char(c: str) -> bool:
    """Return True if *c* may appear in a rule name: ASCII letter, digit or '-'."""
    if c == "-":
        return True
    word_ranges = (("a", "z"), ("A", "Z"), ("0", "9"))
    return any(low <= c <= high for low, high in word_ranges)


# std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
#     const char * pos   = src;
#     const char * end   = src + size;
#     uint32_t     value = 0;
#     for ( ; pos < end && *pos; pos++) {
#         value <<= 4;
#         char c = *pos;
#         if ('a' <= c && c <= 'f') {
#             value += c - 'a' + 10;
#         } else if ('A' <= c && c <= 'F') {
#             value += c - 'A' + 10;
#         } else if ('0' <= c && c <= '9') {
#             value += c - '0';
#         } else {
#             break;
#         }
#     }
#     if (pos != end) {
#         throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
#     }
#     return std::make_pair(value, pos);
# }
def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]:
    """Parse exactly *size* ASCII hex digits starting at *src*.

    Returns (value, pointer just past the digits).

    Raises:
        RuntimeError: if fewer than *size* hex digits are present.
    """

    def digit_value(ch: str) -> int:
        # Value of one ASCII hex digit, or -1 when *ch* is not one.
        if "a" <= ch <= "f":
            return ord(ch) - ord("a") + 10
        if "A" <= ch <= "F":
            return ord(ch) - ord("A") + 10
        if "0" <= ch <= "9":
            return ord(ch) - ord("0")
        return -1

    cursor = const_char_p(src)  # type: const_char_p
    limit = src + size  # type: const_char_p
    result = 0  # type: int
    while cursor < limit and cursor[0]:
        result <<= 4
        digit = digit_value(cursor[0])
        if digit < 0:
            break
        result += digit
        cursor += 1
    if cursor != limit:
        raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src))
    return (result, cursor)


# std::pair<uint32_t, const char *> parse_char(const char * src) {
#     if (*src == '\\') {
#         switch (src[1]) {
#             case 'x': return parse_hex(src + 2, 2);
#             case 'u': return parse_hex(src + 2, 4);
#             case 'U': return parse_hex(src + 2, 8);
#             case 't': return std::make_pair('\t', src + 2);
#             case 'r': return std::make_pair('\r', src + 2);
#             case 'n': return std::make_pair('\n', src + 2);
#             case '\\':
#             case '"':
#             case '[':
#             case ']':
#                 return std::make_pair(src[1], src + 2);
#             default:
#                 throw std::runtime_error(std::string("unknown escape at ") + src);
#         }
#     } else if (*src) {
#         return decode_utf8(src);
#     }
#     throw std::runtime_error("unexpected end of input");
# }
def parse_char(src: const_char_p) -> Tuple[int, const_char_p]:
    """Parse one (possibly backslash-escaped) character at *src*.

    Returns (code point, pointer just past the character).

    Raises:
        RuntimeError: on an unknown escape or at end of input.
    """
    # Escapes followed by a fixed number of hex digits.
    hex_escape_widths = {"x": 2, "u": 4, "U": 8}
    # Single-character escapes and the character each one denotes.
    simple_escapes = {
        "t": "\t",
        "r": "\r",
        "n": "\n",
        "\\": "\\",
        '"': '"',
        "[": "[",
        "]": "]",
    }

    if src[0] != "\\":
        if not src[0]:
            raise RuntimeError("unexpected end of input")
        return decode_utf8(src)

    escape = src[1]  # type: str
    if escape in hex_escape_widths:
        return parse_hex(src + 2, hex_escape_widths[escape])
    if escape in simple_escapes:
        return (ord(simple_escapes[escape]), src + 2)
    raise RuntimeError("unknown escape at " + str(src))


# const char * parse_name(const char * src) {
#     const char * pos = src;
#     while (is_word_char(*pos)) {
#         pos++;
#     }
#     if (pos == src) {
#         throw std::runtime_error(std::string("expecting name at ") + src);
#     }
#     return pos;
# }
def parse_name(src: const_char_p) -> const_char_p:
    """Return a pointer just past the rule name that starts at *src*.

    Raises:
        RuntimeError: if *src* does not start with a word character.
    """
    if not is_word_char(src[0]):
        raise RuntimeError("expecting name at " + str(src))
    name_end = const_char_p(src)  # type: const_char_p
    while is_word_char(name_end[0]):
        name_end += 1
    return name_end


# const char * parse_space(const char * src, bool newline_ok) {
#     const char * pos = src;
#     while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
#             (newline_ok && (*pos == '\r' || *pos == '\n'))) {
#         if (*pos == '#') {
#             while (*pos && *pos != '\r' && *pos != '\n') {
#                 pos++;
#             }
#         } else {
#             pos++;
#         }
#     }
#     return pos;
# }
def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p:
    """Skip spaces, tabs and '#' comments, plus CR/LF when *newline_ok* is true.

    Returns a pointer to the first character that is not skipped. A comment
    runs up to the next CR/LF or to the end of input.
    """
    # const_char_p arithmetic always returns a new object, so binding the
    # caller's pointer directly is the Python analogue of ``pos = src``.
    pos = src
    while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")):
        if pos[0] == "#":
            # At end of input ``pos[0]`` is "" (never None). Testing
            # truthiness mirrors the C++ ``*pos`` check and stops a trailing
            # comment with no newline from looping forever.
            while pos[0] and pos[0] not in ("\r", "\n"):
                pos += 1
        else:
            pos += 1
    return pos


# const char * parse_sequence(
#         parse_state                        & state,
#         const char                         * src,
#         const std::string                  & rule_name,
#         std::vector<llama_grammar_element> & out_elements,
#         bool                                 is_nested) {
def parse_sequence(
    state: parse_state,
    src: const_char_p,
    rule_name: str,
    out_elements: std.vector[LlamaGrammarElement],
    is_nested: bool,
) -> const_char_p:
    """Parse a sequence of symbols at *src*, appending elements to *out_elements*.

    Handles literal strings, character classes, rule references,
    parenthesized groups and the ``*``/``+``/``?`` repetition operators. It
    stops at end of input or at the first character that cannot start or
    modify a symbol (e.g. ``|``, ``)``, or a newline when *is_nested* is
    false, since trailing whitespace is then skipped without newlines), and
    returns a pointer to that character.

    Groups and repetitions are rewritten into synthesized rules registered in
    *state* (named after *rule_name*) and referenced from *out_elements* via
    ``LLAMA_GRETYPE_RULE_REF`` elements.

    Raises:
        RuntimeError: on a missing ``)``, a repetition operator with no
            preceding symbol, or any error raised by ``parse_char``.
    """
    # size_t last_sym_start = out_elements.size();
    # const char * pos = src;
    # last_sym_start marks where the most recent symbol begins in
    # out_elements, so a following */+/? knows which elements to wrap.
    last_sym_start = out_elements.size()  # type: int
    pos = const_char_p(src)  # type: const_char_p
    # while (*pos) {
    while pos[0]:
        # if (*pos == '"') { // literal string
        #     pos++;
        #     last_sym_start = out_elements.size();
        #     while (*pos != '"') {
        #         auto char_pair = parse_char(pos);
        #                 pos       = char_pair.second;
        #         out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
        #     }
        #     pos = parse_space(pos + 1, is_nested);
        if pos[0] == '"':  # literal string
            pos += 1
            last_sym_start = out_elements.size()
            while pos[0] != '"':
                char_pair = parse_char(pos)  # type: Tuple[int, const_char_p]
                pos = char_pair[1]
                out_elements.push_back(
                    LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0])
                )
            pos = parse_space(pos + 1, is_nested)
        # } else if (*pos == '[') { // char range(s)
        #     pos++;
        #     enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
        elif pos[0] == "[":  # char range(s)
            pos += 1
            start_type = llama_gretype.LLAMA_GRETYPE_CHAR  # type: llama_gretype
            # if (*pos == '^') {
            #     pos++;
            #     start_type = LLAMA_GRETYPE_CHAR_NOT;
            # }
            # last_sym_start = out_elements.size();
            if pos[0] == "^":
                pos += 1
                start_type = llama_gretype.LLAMA_GRETYPE_CHAR_NOT
            last_sym_start = out_elements.size()
            # while (*pos != ']') {
            #     auto char_pair = parse_char(pos);
            #             pos       = char_pair.second;
            #     enum llama_gretype type = last_sym_start < out_elements.size()
            #         ? LLAMA_GRETYPE_CHAR_ALT
            #         : start_type;
            #     out_elements.push_back({type, char_pair.first});
            while pos[0] != "]":
                char_pair = parse_char(pos)  # type: Tuple[int, const_char_p]
                pos = char_pair[1]
                # The first char of the class carries start_type (CHAR or
                # CHAR_NOT); every later one is an alternate (CHAR_ALT).
                type = (
                    llama_gretype.LLAMA_GRETYPE_CHAR_ALT
                    if last_sym_start < out_elements.size()
                    else start_type
                )  # type: llama_gretype
                out_elements.push_back(LlamaGrammarElement(type, char_pair[0]))
                #     if (pos[0] == '-' && pos[1] != ']') {
                #         auto endchar_pair = parse_char(pos + 1);
                #                 pos          = endchar_pair.second;
                #         out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
                #     }
                # }
                if pos[0] == "-" and pos[1] != "]":
                    endchar_pair = parse_char(pos + 1)  # type: Tuple[int, const_char_p]
                    pos = endchar_pair[1]
                    out_elements.push_back(
                        LlamaGrammarElement(
                            llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
                            endchar_pair[0],
                        )
                    )
            # pos = parse_space(pos + 1, is_nested);
            pos = parse_space(pos + 1, is_nested)
        # } else if (is_word_char(*pos)) { // rule reference
        #     const char * name_end    = parse_name(pos);
        #     uint32_t     ref_rule_id = get_symbol_id(state, pos, name_end - pos);
        #     pos = parse_space(name_end, is_nested);
        #     last_sym_start = out_elements.size();
        #     out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
        elif is_word_char(pos[0]):  # rule reference
            name_end = parse_name(pos)  # type: const_char_p
            ref_rule_id = get_symbol_id(state, pos, name_end - pos)  # type: int
            pos = parse_space(name_end, is_nested)
            last_sym_start = out_elements.size()
            out_elements.push_back(
                LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id)
            )
        # } else if (*pos == '(') { // grouping
        #     // parse nested alternates into synthesized rule
        #     pos = parse_space(pos + 1, true);
        #     uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
        #     pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
        #     last_sym_start = out_elements.size();
        #     // output reference to synthesized rule
        #     out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
        #     if (*pos != ')') {
        #         throw std::runtime_error(std::string("expecting ')' at ") + pos);
        #     }
        #     pos = parse_space(pos + 1, is_nested);
        elif pos[0] == "(":  # grouping
            # parse nested alternates into synthesized rule
            pos = parse_space(pos + 1, True)
            sub_rule_id = generate_symbol_id(state, rule_name)  # type: int
            pos = parse_alternates(state, pos, rule_name, sub_rule_id, True)
            last_sym_start = out_elements.size()
            # output reference to synthesized rule
            out_elements.push_back(
                LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
            )
            if pos[0] != ")":
                raise RuntimeError("expecting ')' at " + str(pos))
            pos = parse_space(pos + 1, is_nested)
        # } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
        #     if (last_sym_start == out_elements.size()) {
        #         throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
        #     }
        elif pos[0] in ("*", "+", "?"):  # repetition operator
            if last_sym_start == out_elements.size():
                raise RuntimeError("expecting preceding item to */+/? at " + str(pos))
            # // apply transformation to previous symbol (last_sym_start to end) according to
            # // rewrite rules:
            # // S* --> S' ::= S S' |
            # // S+ --> S' ::= S S' | S
            # // S? --> S' ::= S |
            # uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
            # std::vector<llama_grammar_element> sub_rule;
            # // add preceding symbol to generated rule
            # sub_rule.insert(
            #     sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
            sub_rule_id = generate_symbol_id(state, rule_name)  # type: int
            sub_rule = std.vector[
                LlamaGrammarElement
            ]()  # type: std.vector[LlamaGrammarElement]
            sub_rule.insert(
                sub_rule.end(),
                out_elements.begin() + last_sym_start,
                out_elements.end(),
            )
            # if (*pos == '*' || *pos == '+') {
            #     // cause generated rule to recurse
            #     sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
            # }
            # // mark start of alternate def
            # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
            if pos[0] in ("*", "+"):
                sub_rule.push_back(
                    LlamaGrammarElement(
                        llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id
                    )
                )
            sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
            # if (*pos == '+') {
            #     // add preceding symbol as alternate only for '+' (otherwise empty)
            #     sub_rule.insert(
            #         sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
            # }
            # sub_rule.push_back({LLAMA_GRETYPE_END, 0});
            # add_rule(state, sub_rule_id, sub_rule);
            # // in original rule, replace previous symbol with reference to generated rule
            # out_elements.resize(last_sym_start);
            # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
            # pos = parse_space(pos + 1, is_nested);
            if pos[0] == "+":
                # add preceding symbol as alternate only for '+' (otherwise empty)
                sub_rule.insert(
                    sub_rule.end(),
                    out_elements.begin() + last_sym_start,
                    out_elements.end(),
                )
            sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
            add_rule(state, sub_rule_id, sub_rule)
            # in original rule, replace previous symbol with reference to generated rule
            out_elements.resize(last_sym_start)
            out_elements.push_back(
                LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
            )
            pos = parse_space(pos + 1, is_nested)
        # } else {
        #     break;
        # }
        else:
            # Not the start of a symbol (e.g. '|', ')' or a newline): the
            # caller decides what comes next.
            break
    #     }
    #     return pos;
    # }
    return pos


# const char * parse_alternates(
#         parse_state       & state,
#         const char        * src,
#         const std::string & rule_name,
#         uint32_t            rule_id,
#         bool                is_nested) {
#     std::vector<llama_grammar_element> rule;
#     const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
#     while (*pos == '|') {
#         rule.push_back({LLAMA_GRETYPE_ALT, 0});
#         pos = parse_space(pos + 1, true);
#         pos = parse_sequence(state, pos, rule_name, rule, is_nested);
#     }
#     rule.push_back({LLAMA_GRETYPE_END, 0});
#     add_rule(state, rule_id, rule);
#     return pos;
# }
def parse_alternates(
    state: parse_state,
    src: const_char_p,
    rule_name: str,
    rule_id: int,
    is_nested: bool,
) -> const_char_p:
    """Parse ``sequence ('|' sequence)*`` at *src* and register it as rule *rule_id*.

    Alternatives are separated by ``LLAMA_GRETYPE_ALT`` elements, and the rule is
    closed with ``LLAMA_GRETYPE_END`` before being stored via ``add_rule``.
    Returns the position produced by the final ``parse_sequence`` call.
    """
    elements = std.vector()  # type: std.vector[LlamaGrammarElement]
    cursor = parse_sequence(state, src, rule_name, elements, is_nested)  # type: const_char_p
    while cursor[0] == "|":
        elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
        # Step over the '|' before parsing the next alternative.
        after_bar = parse_space(cursor + 1, True)
        cursor = parse_sequence(state, after_bar, rule_name, elements, is_nested)
    elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
    add_rule(state, rule_id, elements)
    return cursor


# const char * parse_rule(parse_state & state, const char * src) {
#     const char * name_end = parse_name(src);
#     const char * pos      = parse_space(name_end, false);
#     size_t       name_len = name_end - src;
#     uint32_t     rule_id  = get_symbol_id(state, src, name_len);
#     const std::string name(src, name_len);

#     if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
#         throw std::runtime_error(std::string("expecting ::= at ") + pos);
#     }
#     pos = parse_space(pos + 3, true);

#     pos = parse_alternates(state, pos, name, rule_id, false);


#     if (*pos == '\r') {
#         pos += pos[1] == '\n' ? 2 : 1;
#     } else if (*pos == '\n') {
#         pos++;
#     } else if (*pos) {
#         throw std::runtime_error(std::string("expecting newline or end at ") + pos);
#     }
#     return parse_space(pos, true);
# }
def parse_rule(state: parse_state, src: const_char_p) -> const_char_p:
    """Parse a single ``name ::= alternates`` rule that starts at *src*.

    The rule body is registered through ``parse_alternates`` under the symbol id
    obtained from ``get_symbol_id``. After the body, exactly one line terminator
    (CRLF, lone CR or LF) is consumed, or the input must end there.

    Raises:
        RuntimeError: if ``::=`` does not follow the rule name, or if the body is
            followed by something other than a newline or the end of input.
    """
    after_name = parse_name(src)  # type: const_char_p
    cursor = parse_space(after_name, False)  # type: const_char_p
    length = after_name - src  # type: int
    symbol_id = get_symbol_id(state, src, length)  # type: int
    rule_name = std.string(src, length)  # type: str

    # The rule name must be followed by the definition operator.
    if cursor[0] != ":" or cursor[1] != ":" or cursor[2] != "=":
        raise RuntimeError("expecting ::= at " + str(cursor))

    cursor = parse_alternates(
        state, parse_space(cursor + 3, True), rule_name, symbol_id, False
    )  # type: const_char_p

    lead = cursor[0]
    if lead == "\r":
        # Treat CRLF as a single terminator.
        cursor += 2 if cursor[1] == "\n" else 1
    elif lead == "\n":
        cursor += 1
    elif lead:
        raise RuntimeError("expecting newline or end at " + str(cursor))
    return parse_space(cursor, True)


# parse_state parse(const char * src) {
#     try {
#         parse_state state;
#         const char * pos = parse_space(src, true);
#         while (*pos) {
#             pos = parse_rule(state, pos);
#         }
#         return state;
#     } catch (const std::exception & err) {
#         fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
#         return parse_state();
#     }
# }
def parse(src: const_char_p) -> parse_state:
    """Parse a complete GBNF grammar into a ``parse_state``.

    Rules are parsed one after another until the end of input. As in the C++
    ``parse``, any error is reported on stderr and an empty ``parse_state`` is
    returned instead of raising. Callers detect failure by checking for empty
    ``rules``.
    """
    try:
        state = parse_state()  # type: parse_state
        pos = parse_space(src, True)  # type: const_char_p
        while pos[0]:
            pos = parse_rule(state, pos)
        return state
    except Exception as err:
        # Report on stderr like the C++ original (fprintf(stderr, ...)), so the
        # diagnostic does not end up mixed into regular stdout output.
        print(f"{parse.__name__}: error parsing grammar: {err}", file=sys.stderr)
        return parse_state()


# void print_grammar_char(FILE * file, uint32_t c) {
#     if (0x20 <= c && c <= 0x7f) {
#         fprintf(file, "%c", static_cast<char>(c));
#     } else {
#         // cop out of encoding UTF-8
#         fprintf(file, "<U+%04X>", c);
#     }
# }
def print_grammar_char(file: TextIO, c: int) -> None:
    """Write code point *c* to *file*.

    Code points 0x20..0x7F are written as the character itself. Anything else is
    written as ``<U+XXXX>`` (at least four upper-case hex digits) rather than being
    UTF-8 encoded.
    """
    in_printable_range = 0x20 <= c <= 0x7F
    file.write(chr(c) if in_printable_range else f"<U+{c:04X}>")


# bool is_char_element(llama_grammar_element elem) {
#     switch (elem.type) {
#         case LLAMA_GRETYPE_CHAR:           return true;
#         case LLAMA_GRETYPE_CHAR_NOT:       return true;
#         case LLAMA_GRETYPE_CHAR_ALT:       return true;
#         case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
#         default:                           return false;
#     }
# }
def is_char_element(elem: LlamaGrammarElement) -> bool:
    """Return True when *elem* is a character-class element.

    The character-class types are CHAR, CHAR_NOT, CHAR_ALT and CHAR_RNG_UPPER.
    """
    char_kinds = (
        llama_gretype.LLAMA_GRETYPE_CHAR,
        llama_gretype.LLAMA_GRETYPE_CHAR_NOT,
        llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
        llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
    )
    kind = elem.type
    return kind in char_kinds


# void print_rule(
#         FILE     * file,
#         uint32_t   rule_id,
#         const std::vector<llama_grammar_element> & rule,
#         const std::map<uint32_t, std::string>    & symbol_id_names) {
def print_rule(
    file: TextIO,
    rule_id: int,
    rule: std.vector[LlamaGrammarElement],
    symbol_id_names: std.map[int, str],
) -> None:
    """Write rule *rule_id* to *file* as ``name ::= ...`` followed by a newline.

    Character elements are rendered inside ``[...]``. The closing ``]`` is emitted
    only when the next element does not continue the class (CHAR_ALT or
    CHAR_RNG_UPPER). Rule references are printed by name, looked up in
    *symbol_id_names*.

    Raises:
        RuntimeError: if the rule is empty or does not end with END, if END occurs
            before the last position, or if CHAR_ALT / CHAR_RNG_UPPER appears
            without a preceding character element.
    """
    #     if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
    #         throw std::runtime_error(
    #             "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
    #     }
    #     fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
    if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
        raise RuntimeError(
            "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
        )
    print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")

    kinds = llama_gretype

    def check_preceded_by_char(label: str, index: int) -> None:
        # CHAR_ALT / CHAR_RNG_UPPER extend the previous char element's class, so
        # they are invalid as the first element or after a non-char element.
        if index == 0 or not is_char_element(rule[index - 1]):
            raise RuntimeError(
                label
                + " without preceding char: "
                + str(rule_id)
                + ","
                + str(index)
            )

    #     for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
    #         llama_grammar_element elem = rule[i];
    #         switch (elem.type) {
    #             case LLAMA_GRETYPE_END:
    #                 throw std::runtime_error(
    #                     "unexpected end of rule: " + std::to_string(rule_id) + "," +
    #                     std::to_string(i));
    #             case LLAMA_GRETYPE_ALT:
    #                 fprintf(file, "| ");
    #                 break;
    #             case LLAMA_GRETYPE_RULE_REF:
    #                 fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
    #                 break;
    #             case LLAMA_GRETYPE_CHAR:
    #                 fprintf(file, "[");
    #                 print_grammar_char(file, elem.value);
    #                 break;
    #             case LLAMA_GRETYPE_CHAR_NOT:
    #                 fprintf(file, "[^");
    #                 print_grammar_char(file, elem.value);
    #                 break;
    #             case LLAMA_GRETYPE_CHAR_RNG_UPPER:
    #                 if (i == 0 || !is_char_element(rule[i - 1])) {
    #                     throw std::runtime_error(
    #                         "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
    #                         std::to_string(rule_id) + "," + std::to_string(i));
    #                 }
    #                 fprintf(file, "-");
    #                 print_grammar_char(file, elem.value);
    #                 break;
    #             case LLAMA_GRETYPE_CHAR_ALT:
    #                 if (i == 0 || !is_char_element(rule[i - 1])) {
    #                     throw std::runtime_error(
    #                         "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
    #                         std::to_string(rule_id) + "," + std::to_string(i));
    #                 }
    #                 print_grammar_char(file, elem.value);
    #                 break;
    #         }
    for index, element in enumerate(rule[:-1]):
        kind = element.type  # type: llama_gretype
        if kind is kinds.LLAMA_GRETYPE_END:
            raise RuntimeError(
                "unexpected end of rule: " + str(rule_id) + "," + str(index)
            )
        elif kind is kinds.LLAMA_GRETYPE_ALT:
            print("| ", file=file, end="")
        elif kind is kinds.LLAMA_GRETYPE_RULE_REF:
            print(f"{symbol_id_names.at(element.value)} ", file=file, end="")
        elif kind is kinds.LLAMA_GRETYPE_CHAR:
            print("[", file=file, end="")
            print_grammar_char(file, element.value)
        elif kind is kinds.LLAMA_GRETYPE_CHAR_NOT:
            print("[^", file=file, end="")
            print_grammar_char(file, element.value)
        elif kind is kinds.LLAMA_GRETYPE_CHAR_RNG_UPPER:
            check_preceded_by_char("LLAMA_GRETYPE_CHAR_RNG_UPPER", index)
            print("-", file=file, end="")
            print_grammar_char(file, element.value)
        elif kind is kinds.LLAMA_GRETYPE_CHAR_ALT:
            check_preceded_by_char("LLAMA_GRETYPE_CHAR_ALT", index)
            print_grammar_char(file, element.value)

        # if (is_char_element(elem)) {
        #     switch (rule[i + 1].type) {
        #         case LLAMA_GRETYPE_CHAR_ALT:
        #         case LLAMA_GRETYPE_CHAR_RNG_UPPER:
        #             break;
        #         default:
        #             fprintf(file, "] ");
        # Close the character class unless the next element continues it.
        if is_char_element(element) and rule[index + 1].type not in (
            kinds.LLAMA_GRETYPE_CHAR_ALT,
            kinds.LLAMA_GRETYPE_CHAR_RNG_UPPER,
        ):
            print("] ", file=file, end="")
    #     fprintf(file, "\n");
    print(file=file)


# void print_grammar(FILE * file, const parse_state & state) {
#     try {
#         std::map<uint32_t, std::string> symbol_id_names;
#         for (auto kv : state.symbol_ids) {
#             symbol_id_names[kv.second] = kv.first;
#         }
#         for (size_t i = 0, end = state.rules.size(); i < end; i++) {
#             // fprintf(file, "%zu: ", i);
#             // print_rule_binary(file, state.rules[i]);
#             print_rule(file, i, state.rules[i], symbol_id_names);
#             // fprintf(file, "\n");
#         }
#     } catch (const std::exception & err) {
#         fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
#     }
# }
def print_grammar(file: TextIO, state: parse_state) -> None:
    """Write every rule of *state* to *file*, one per line.

    Symbol ids are mapped back to their names by inverting ``state.symbol_ids``.
    Following the C++ try/catch, any error raised while printing is reported on
    stderr instead of being propagated.
    """
    try:
        names_by_id = std.map()  # type: std.map[int, str]
        for symbol_name, symbol_id in state.symbol_ids.items():
            names_by_id[symbol_id] = symbol_name

        for rule_index, rule in enumerate(state.rules):
            print_rule(file, rule_index, rule, names_by_id)
    except Exception as err:
        print(
            f"{print_grammar.__name__}: error printing grammar: {err}",
            file=sys.stderr,
        )


"""llama.cpp gbnf rules from vendor/llama.cpp/grammars"""

# Sample GBNF grammars. Each constant holds grammar source text in a raw string,
# so backslash escapes such as \n reach the grammar parser unchanged.

# Arithmetic: one or more lines of the form `expr = term`.
ARITHMETIC_GBNF = r"""
root  ::= (expr "=" ws term "\n")+
expr  ::= term ([-+*/] term)*
term  ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num   ::= [0-9]+ ws
ws    ::= [ \t\n]*
"""

# A small C-like language: function declarations whose bodies contain
# assignments, calls, return/while/for/if statements and comments.
C_GBNF = r"""
root ::= (declaration)*

declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"

dataType  ::= "int" ws | "float" ws | "char" ws
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*

parameter ::= dataType identifier

statement ::=
    ( dataType identifier ws "=" ws expression ";" ) |
    ( identifier ws "=" ws expression ";" ) |
    ( identifier ws "(" argList? ")" ";" ) |
    ( "return" ws expression ";" ) |
    ( "while" "(" condition ")" "{" statement* "}" ) |
    ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
    ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
    ( singleLineComment ) |
    ( multiLineComment )

forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
forUpdate ::= identifier ws "=" ws expression

condition ::= expression relationOperator expression
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")

expression ::= term (("+" | "-") term)*
term ::= factor(("*" | "/") factor)*

factor ::= identifier | number | unaryTerm | funcCall | parenExpression
unaryTerm ::= "-" factor
funcCall ::= identifier "(" argList? ")"
parenExpression ::= "(" ws expression ws ")"

argList ::= expression ("," ws expression)*

number ::= [0-9]+

singleLineComment ::= "//" [^\n]* "\n"
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"

ws ::= ([ \t\n]+)
"""

# NOTE(review): this text is byte-identical to JAPANESE_GBNF and is a JSON
# grammar (object/array/string/number rules); it does not describe chess moves.
# It presumably should mirror vendor/llama.cpp/grammars/chess.gbnf; confirm
# before relying on it.
CHESS_GBNF = r"""
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""

# NOTE(review): this text is byte-identical to CHESS_GBNF and is a JSON grammar;
# nothing in it is specific to Japanese text. It presumably should mirror
# vendor/llama.cpp/grammars/japanese.gbnf; confirm before relying on it.
JAPANESE_GBNF = r"""
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""

# JSON grammar whose root is an array: elements are separated by ",\n" and no
# whitespace is allowed after the closing "]" of the root array.
JSON_ARR_GBNF = r"""
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
# Useful for generating JSON arrays

root   ::= arr
value  ::= object | array | string | number | ("true" | "false" | "null") ws

arr  ::=
  "[\n" ws (
            value
    (",\n" ws value)*
  )? "]"

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""


# Generic JSON grammar rooted at an object. Unescaped string characters exclude
# '"', backslash, \x7F and \x00-\x1F.
JSON_GBNF = r"""
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

ws ::= ([ \t\n] ws)?
"""

# Bullet list: one or more lines, each "- " followed by text that contains no
# line-break characters, terminated by "\n".
LIST_GBNF = r"""
root ::= item+

# Excludes various line break characters
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
"""

"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
import json
import re
from typing import List, Optional

# Whitespace is constrained to a single optional space char to prevent the model
# from "running away" in whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'


# Runs of characters that are not allowed in GBNF rule names; `_add_rule`
# replaces each run with '-'.
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
# Characters that must be escaped inside a double-quoted GBNF literal, and their
# escaped spellings.
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}


def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False):
    if not separator_rule:
        if min_items == 0 and max_items == 1:
            return f'{item_rule}?'
        elif min_items == 1 and max_items is None:
            return f'{item_rule}+'

    result = ''

    if min_items > 0:
        if item_rule_is_literal and separator_rule is None:
            result = '"' + (item_rule[1:-1] * min_items) + '"'
        else:
            result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items)

    def opt_repetitions(up_to_n, prefix_with_sep=False):
        '''
            - n=4, no sep:             '(a (a (a (a)?)?)?)?'
            - n=4, sep=',', prefix:    '("," a ("," a ("," a ("," a)?)?)?)?'
            - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
        '''

        content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule
        if up_to_n == 0:
            return ''
        elif up_to_n == 1:
            return f'({content})?'
        elif separator_rule and not prefix_with_sep:
            return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?'
        else:
            return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n)

    if min_items > 0 and max_items != min_items:
        result += ' '

    if max_items is not None:
        result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
    else:
        item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'

        if min_items == 0 and separator_rule:
            result = f'({item_rule} {item_operator}*)?'
        else:
            result += f'{item_operator}*'

    return result



class BuiltinRule:
    '''
        A predefined GBNF rule body together with the rules it depends on.

        Attributes:
            content: GBNF text for the rule's right-hand side.
            deps: names of other built-in rules referenced by `content`. Each
                instance gets its own list; an empty list is used when no deps are
                given.
    '''

    def __init__(self, content: str, deps: Optional[List[str]] = None):
        self.content = content
        # `or []` gives every instance a fresh list instead of a shared default.
        self.deps = deps or []

    def __repr__(self) -> str:
        return f'{type(self).__name__}(content={self.content!r}, deps={self.deps!r})'

# '[0-9]' repeated between 0 and 15 times, expressed as nested optional groups.
_up_to_15_digits = _build_repetition('[0-9]', 0, 15)

# Built-in GBNF rules for JSON primitive values, keyed by rule name. `deps` lists
# the other built-in rules a rule's content refers to; 'space' is not listed
# because SchemaConverter always defines it.
# - 'integral-part': a single digit, or a non-zero digit followed by up to 15
#   more digits, so multi-digit numbers cannot start with 0.
# - 'uuid': 8-4-4-4-12 hex-digit groups joined by "-"; the hex classes inside a
#   group are concatenated without separating spaces.
# - 'char': any character except '"' and backslash, or a backslash escape
#   (control characters are not excluded here).
PRIMITIVE_RULES = {
    'boolean'      : BuiltinRule('("true" | "false") space', []),
    'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []),
    'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []),
    'number'       : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
    'integer'      : BuiltinRule('("-"? integral-part) space', ['integral-part']),
    'value'        : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
    'object'       : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
    'array'        : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
    'uuid'         : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []),
    'char'         : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []),
    'string'       : BuiltinRule(r'"\"" char* "\"" space', ['char']),
    'null'         : BuiltinRule('"null" space', []),
}

# TODO: support "uri", "email" string formats
# Built-in GBNF rules for JSON Schema string `format`s. 'date', 'time' and
# 'date-time' match the bare text; the '*-string' variants wrap it in JSON double
# quotes followed by `space`. (In the 'date' rule, `\"0\"` is an ordinary Python
# escape inside a single-quoted string and yields `"0"`.)
STRING_FORMAT_RULES = {
    'date'            : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
    'time'            : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
    'date-time'       : BuiltinRule('date "T" time', ['date', 'time']),
    'date-string'     : BuiltinRule('"\\"" date "\\"" space', ['date']),
    'time-string'     : BuiltinRule('"\\"" time "\\"" space', ['time']),
    'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}

# Character classes used for the regex "." in `_visit_pattern`: DOTALL matches
# any code point, DOT matches anything except LF (\x0A) and CR (\x0D).
DOTALL = '[\\U00000000-\\U0010FFFF]'
DOT = '[^\\x0A\\x0D]'

# Rule names that schema-derived names must not reuse; `visit` appends '-' to them.
RESERVED_NAMES = {"root", "dot", *PRIMITIVE_RULES, *STRING_FORMAT_RULES}


# Regex metacharacters that end a run of literal text in `_visit_pattern`.
NON_LITERAL_SET = set('|.()[]{}*+?')
# Metacharacters whose backslash-escaped form becomes the plain character in a
# GBNF literal.
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')




class SchemaConverter:
    def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
        '''
            Store the conversion options and initialise the rule and ref tables.

            `allow_fetch` gates fetching remote "https://" refs in `resolve_refs`;
            `dotall` makes a regex "." in `_visit_pattern` also match line breaks;
            `raw_pattern` makes `_visit_pattern` emit the pattern without the
            surrounding JSON string quotes.
        '''
        # Options, stored as given.
        self._prop_order = prop_order
        self._allow_fetch = allow_fetch
        self._dotall = dotall
        self._raw_pattern = raw_pattern
        # rule name -> GBNF body; 'space' is always available to other rules.
        self._rules = dict(space=SPACE_RULE)
        # resolved ref URL -> referenced (sub)schema, filled in by resolve_refs.
        self._refs = dict()
        # refs currently being visited; lets _resolve_ref stop on recursion.
        self._refs_being_resolved = set()

    def _format_literal(self, literal):
        '''Quote `literal` as a GBNF string literal, escaping CR, LF and double quotes.'''
        def _escape(match):
            return GRAMMAR_LITERAL_ESCAPES.get(match.group(0))

        return '"' + GRAMMAR_LITERAL_ESCAPE_RE.sub(_escape, literal) + '"'

    def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
        '''
            Build a GBNF expression matching a non-empty string that is either a proper
            prefix of `literal` or diverges from it at one character and stops there:

                not_literal('a')   -> '([^a])'
                not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)'

            With `maybe_escaped_underscores`, a '_' position also excludes backslash
            from its negated class and accepts an optional backslash before "_".
            `dotall` is accepted but not used.

            NOTE(review): characters are interpolated into `[^...]` without escaping;
            ']' or a backslash in `literal` likely yields an invalid class - confirm.
        '''
        assert len(literal) > 0, 'Empty literal not supported'

        def alternatives_for(ch):
            if maybe_escaped_underscores and ch == '_':
                return f'[^{ch}\\\\] | "\\\\"? "{ch}"'
            return f'[^{ch}]'

        # Build from the innermost (last character) outward.
        nested = alternatives_for(literal[-1])
        for ch in reversed(literal[:-1]):
            nested = f'{alternatives_for(ch)} | {self._format_literal(ch)} ({nested})?'
        return f'({nested})'

    def _add_rule(self, name, rule):
        '''
            Register `rule` under a sanitised form of `name` and return the key used.

            Runs of invalid characters in `name` become '-'. If that key already holds
            a different rule, numeric suffixes 0, 1, ... are tried until a free key,
            or one already holding the same rule, is found.
        '''
        base = INVALID_RULE_CHARS_RE.sub('-', name)
        key = base
        if base in self._rules and self._rules[base] != rule:
            suffix = 0
            while True:
                key = f'{base}{suffix}'
                if key not in self._rules or self._rules[key] == rule:
                    break
                suffix += 1
        self._rules[key] = rule
        return key

    def resolve_refs(self, schema: dict, url: str):
        '''
            Resolves all $ref fields in the given schema, fetching any remote schemas,
            replacing $ref with absolute reference URL and populating self._refs with the
            respective referenced (sub)schema dictionaries.
        '''
        def visit(n: dict):
            if isinstance(n, list):
                return [visit(x) for x in n]
            elif isinstance(n, dict):
                ref = n.get('$ref')
                if ref is not None and ref not in self._refs:
                    if ref.startswith('https://'):
                        assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
                        import requests

                        frag_split = ref.split('#')
                        base_url = frag_split[0]

                        target = self._refs.get(base_url)
                        if target is None:
                            target = self.resolve_refs(requests.get(ref).json(), base_url)
                            self._refs[base_url] = target

                        if len(frag_split) == 1 or frag_split[-1] == '':
                            return target
                    elif ref.startswith('#/'):
                        target = schema
                        ref = f'{url}{ref}'
                        n['$ref'] = ref
                    else:
                        raise ValueError(f'Unsupported ref {ref}')

                    for sel in ref.split('#')[-1].split('/')[1:]:
                        assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
                        target = target[sel]

                    self._refs[ref] = target
                else:
                    for v in n.values():
                        visit(v)

            return n
        return visit(schema)

    def _generate_union_rule(self, name, alt_schemas):
        '''
            Visit each alternative schema and join the resulting rules with " | ".

            Alternative i is named f"{name}-{i}", or f"alternative-{i}" when `name` is
            empty.
        '''
        separator = '-' if name else 'alternative-'
        alternatives = []
        for index, alt_schema in enumerate(alt_schemas):
            alternatives.append(self.visit(alt_schema, f'{name}{separator}{index}'))
        return ' | '.join(alternatives)

    def _visit_pattern(self, pattern, name):
        '''
            Transforms a regular expression pattern into a GBNF rule.

            Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
            Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md

            Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.

            Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
            we define sub-rules to keep the output lean.

            Args:
                pattern: regex anchored with a leading '^' and a trailing '$'.
                name: base rule name; quantified non-literal sub-expressions are
                    registered as f'{name}-{k}'.

            Returns:
                The key under which the generated rule was registered via `_add_rule`.
                Unless `self._raw_pattern` is set, the rule matches the pattern wrapped
                in JSON double quotes, followed by `space`.

            Raises:
                AssertionError: unanchored pattern, unbalanced brackets/parentheses,
                    or unsupported "(?" group syntax.
                ValueError: non-integer bounds in a {...} quantifier.
        '''

        assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
        pattern = pattern[1:-1]
        # Non-literal sub-expression -> generated rule name, so a repeated
        # quantified sub-expression is emitted as a rule only once.
        sub_rule_ids = {}

        i = 0
        length = len(pattern)

        def to_rule(s: Tuple[str, bool]) -> str:
            (txt, is_literal) = s
            return "\"" + txt + "\"" if is_literal else txt

        def transform() -> Tuple[str, bool]:
            '''
                Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
            '''
            nonlocal i
            nonlocal pattern
            nonlocal sub_rule_ids

            start = i
            # For each component of this sequence, store its string representation and whether it's a literal.
            # We only need a flat structure here to apply repetition operators to the last item, and
            # to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
            # (GBNF's syntax is luckily very close to regular expressions!)
            seq: list[Tuple[str, bool]] = []

            def get_dot():
                if self._dotall:
                    rule = DOTALL
                else:
                    # Accept any character... except \n and \r line break chars (\x0A and \x0D)
                    rule = DOT
                return self._add_rule('dot', rule)

            def join_seq():
                nonlocal seq
                ret = []
                for is_literal, g in groupby(seq, lambda x: x[1]):
                    if is_literal:
                        ret.append((''.join(x[0] for x in g), True))
                    else:
                        ret.extend(g)
                if len(ret) == 1:
                    return ret[0]
                # Join the merged items (not the raw `seq`) so that adjacent
                # literals are emitted as a single quoted string.
                return (' '.join(to_rule(x) for x in ret), False)

            while i < length:
                c = pattern[i]
                if c == '.':
                    seq.append((get_dot(), False))
                    i += 1
                elif c == '(':
                    i += 1
                    if i < length:
                        assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
                    seq.append((f'({to_rule(transform())})', False))
                elif c == ')':
                    i += 1
                    assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
                    return join_seq()
                elif c == '[':
                    square_brackets = c
                    i += 1
                    while i < length and pattern[i] != ']':
                        if pattern[i] == '\\':
                            square_brackets += pattern[i:i+2]
                            i += 2
                        else:
                            square_brackets += pattern[i]
                            i += 1
                    assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
                    square_brackets += ']'
                    i += 1
                    seq.append((square_brackets, False))
                elif c == '|':
                    seq.append(('|', False))
                    i += 1
                elif c in ('*', '+', '?'):
                    seq[-1] = (to_rule(seq[-1]) + c, False)
                    i += 1
                elif c == '{':
                    curly_brackets = c
                    i += 1
                    while i < length and pattern[i] != '}':
                        curly_brackets += pattern[i]
                        i += 1
                    assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
                    curly_brackets += '}'
                    i += 1
                    nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
                    min_times = 0
                    max_times = None
                    try:
                        if len(nums) == 1:
                            min_times = int(nums[0])
                            max_times = min_times
                        else:
                            assert len(nums) == 2
                            min_times = int(nums[0]) if nums[0] else 0
                            max_times = int(nums[1]) if nums[1] else None
                    except ValueError:
                        raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')

                    (sub, sub_is_literal) = seq[-1]

                    if not sub_is_literal:
                        # Give the repeated sub-expression its own (shared) rule.
                        sub_id = sub_rule_ids.get(sub)
                        if sub_id is None:
                            sub_id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
                            sub_rule_ids[sub] = sub_id
                        sub = sub_id

                    seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False)
                else:
                    literal = ''
                    while i < length:
                        if pattern[i] == '\\' and i < length - 1:
                            next_char = pattern[i + 1]
                            if next_char in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
                                i += 1
                                literal += pattern[i]
                                i += 1
                            else:
                                literal += pattern[i:i+2]
                                i += 2
                        elif pattern[i] == '"' and not self._raw_pattern:
                            literal += '\\"'
                            i += 1
                        elif pattern[i] not in NON_LITERAL_SET and \
                                (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
                            # A trailing literal char is kept separate when it is
                            # followed by a quantifier, so the quantifier applies to it alone.
                            literal += pattern[i]
                            i += 1
                        else:
                            break
                    if literal:
                        seq.append((literal, True))
                    else:
                        # Only a stray ']' or '}' (no matching opener) gets here: the
                        # loop above made no progress, so without this the outer loop
                        # would spin forever. Treat it as a literal, as Python's `re` does.
                        seq.append((pattern[i], True))
                        i += 1

            return join_seq()

        body = to_rule(transform())
        if not self._raw_pattern:
            # Match the pattern inside a JSON string: `"\""` is GBNF for a literal quote.
            body = "\"\\\"\" " + body + " \"\\\"\" space"
        return self._add_rule(name, body)


    def _resolve_ref(self, ref):
        ref_name = ref.split('/')[-1]
        if ref_name not in self._rules and ref not in self._refs_being_resolved:
            self._refs_being_resolved.add(ref)
            resolved = self._refs[ref]
            ref_name = self.visit(resolved, ref_name)
            self._refs_being_resolved.remove(ref)
        return ref_name

    def _generate_constant_rule(self, value):
        '''Return a GBNF literal matching exactly the `json.dumps` serialization of `value`.'''
        serialized = json.dumps(value)
        return self._format_literal(serialized)

    def visit(self, schema, name):
        """Convert one JSON-schema node into grammar rules and return the rule name to reference.

        The node is matched against the supported constructs in this order:
        ``$ref``, ``oneOf``/``anyOf``, a list-valued ``type``, ``const``, ``enum``,
        objects (``properties`` or a non-``True`` ``additionalProperties``),
        ``allOf`` over objects, arrays (``items`` / ``prefixItems``), string
        ``pattern``, ``uuid`` formats, other known string formats, string length
        limits, generic objects / empty schemas, and finally primitive ``type``
        names from ``PRIMITIVE_RULES``.

        Args:
            schema: the JSON-schema dict to convert.
            name: preferred rule name. An empty name falls back to ``'root'``;
                names in ``RESERVED_NAMES`` get a trailing ``-``. It is also used
                as the prefix for the names of nested rules.

        Returns:
            The name of the registered rule, as returned by ``_add_rule``,
            ``_add_primitive`` or ``_visit_pattern``.

        Raises:
            AssertionError: (when assertions are enabled) if the node matches
                none of the supported constructs.
        """
        schema_type = schema.get('type')
        schema_format = schema.get('format')
        rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'

        if (ref := schema.get('$ref')) is not None:
            return self._add_rule(rule_name, self._resolve_ref(ref))

        elif 'oneOf' in schema or 'anyOf' in schema:
            return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))

        elif isinstance(schema_type, list):
            # {"type": ["a", "b"]} is treated as a union of the single-type schemas.
            return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))

        elif 'const' in schema:
            return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))

        elif 'enum' in schema:
            rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
            return self._add_rule(rule_name, rule)

        elif schema_type in (None, 'object') and \
             ('properties' in schema or \
              ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
            required = set(schema.get('required', []))
            properties = list(schema.get('properties', {}).items())
            return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))

        elif schema_type in (None, 'object') and 'allOf' in schema:
            # Merge the properties of every allOf component into one object rule.
            required = set()
            properties = []
            hybrid_name = name
            def add_component(comp_schema, is_required):
                if (ref := comp_schema.get('$ref')) is not None:
                    comp_schema = self._refs[ref]

                if 'properties' in comp_schema:
                    for prop_name, prop_schema in comp_schema['properties'].items():
                        properties.append((prop_name, prop_schema))
                        if is_required:
                            required.add(prop_name)

            for t in schema['allOf']:
                if 'anyOf' in t:
                    # Properties from anyOf alternatives are all made optional.
                    for tt in t['anyOf']:
                        add_component(tt, is_required=False)
                else:
                    add_component(t, is_required=True)

            return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))

        elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
            # Prefer a truthy 'items'; otherwise use 'prefixItems' when present.
            # An explicit but empty 'items' (e.g. `{}`) with no 'prefixItems' is
            # used as-is rather than raising KeyError.
            items = schema.get('items')
            if not items and 'prefixItems' in schema:
                items = schema['prefixItems']
            if isinstance(items, list):
                # Tuple form: one sub-rule per position, comma-separated.
                return self._add_rule(
                    rule_name,
                    '"[" space ' +
                    ' "," space '.join(
                        self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
                        for i, item in enumerate(items)) +
                    ' "]" space')
            else:
                item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
                min_items = schema.get("minItems", 0)
                max_items = schema.get("maxItems")
                return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')

        elif schema_type in (None, 'string') and 'pattern' in schema:
            return self._visit_pattern(schema['pattern'], rule_name)

        elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
            return self._add_primitive(
                'root' if rule_name == 'root' else schema_format,
                PRIMITIVE_RULES['uuid']
            )

        elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
            prim_name = f'{schema_format}-string'
            return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))

        elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
            char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
            min_len = schema.get('minLength', 0)
            max_len = schema.get('maxLength')

            return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')

        elif (schema_type == 'object') or (len(schema) == 0):
            return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))

        else:
            assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
            # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
            return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])

    def _add_primitive(self, name: str, rule: BuiltinRule):
        """Register ``rule.content`` under ``name`` together with its dependencies.

        Each entry of ``rule.deps`` is looked up in ``PRIMITIVE_RULES``, then in
        ``STRING_FORMAT_RULES``. Dependencies that are not yet in ``self._rules``
        are registered recursively.

        Returns:
            The rule name returned by ``_add_rule`` for ``name``.
        """
        registered_name = self._add_rule(name, rule.content)

        for dependency in rule.deps:
            dependency_rule = PRIMITIVE_RULES.get(dependency) or STRING_FORMAT_RULES.get(dependency)
            assert dependency_rule, f'Rule {dependency} not known'
            if dependency in self._rules:
                continue
            self._add_primitive(dependency, dependency_rule)

        return registered_name

    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
        """Build the grammar rule body for a JSON object.

        Args:
            properties: ``(property_name, property_schema)`` pairs in declaration order.
            required: names of the properties that must be present.
            name: prefix for the helper rule names created for each property.
            additional_properties: the schema's ``additionalProperties`` value.
                ``True`` or a dict schema allows extra string-keyed members (a
                dict also constrains their values); any other value disallows them.

        Returns:
            The rule body (not a rule name) matching ``{ ... }``. Required
            properties come first, comma-separated. They are followed by an
            optional comma-separated tail of optional properties and, if allowed,
            any number of additional members. Properties are ordered by
            ``self._prop_order`` and then by declaration order.
        """
        prop_order = self._prop_order
        # sort by position in prop_order (if specified) then by original order
        sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]

        # One "<name>-<prop>-kv" rule per property: `"prop" space ":" space <value>`.
        prop_kv_rule_names = {}
        for prop_name, prop_schema in properties:
            prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
            prop_kv_rule_names[prop_name] = self._add_rule(
                f'{name}{"-" if name else ""}{prop_name}-kv',
                fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
            )
        required_props = [k for k in sorted_props if k in required]
        optional_props = [k for k in sorted_props if k not in required]

        if additional_properties == True or isinstance(additional_properties, dict):
            # "*" stands for "any other string key"; it is always the last optional entry.
            sub_name = f'{name}{"-" if name else ""}additional'
            value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
            prop_kv_rule_names["*"] = self._add_rule(
                f'{sub_name}-kv',
                self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
            )
            optional_props.append("*")

        rule = '"{" space '
        rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)

        if optional_props:
            rule += ' ('
            if required_props:
                rule += ' "," space ( '

            def get_recursive_refs(ks, first_is_optional):
                # Alternative starting with ks[0]. Every later key is optional and
                # comma-prefixed, chained through "<name>-<key>-rest" rules.
                [k, *rest] = ks
                kv_rule_name = prop_kv_rule_names[k]
                if k == '*':
                    if first_is_optional:
                        # After an earlier property: zero or more additional
                        # members, each preceded by a comma.
                        res = f'( "," space {kv_rule_name} )*'
                    else:
                        # First entry of the tail: one or more additional
                        # members, comma-separated.
                        res = self._add_rule(
                            f'{name}{"-" if name else ""}additional-kvs',
                            f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
                        )
                elif first_is_optional:
                    res = f'( "," space {kv_rule_name} )?'
                else:
                    res = kv_rule_name
                if len(rest) > 0:
                    res += ' ' + self._add_rule(
                        f'{name}{"-" if name else ""}{k}-rest',
                        get_recursive_refs(rest, first_is_optional=True)
                    )
                return res

            # One alternative per possible first optional property.
            rule += ' | '.join(
                get_recursive_refs(optional_props[i:], first_is_optional=False)
                for i in range(len(optional_props))
            )
            if required_props:
                rule += ' )'
            rule += ' )?'

        rule += ' "}" space'

        return rule

    def format_grammar(self):
        """Render every registered rule as a ``name ::= body`` line, ordered by rule name."""
        lines = []
        for rule_name in sorted(self._rules):
            lines.append(f'{rule_name} ::= {self._rules[rule_name]}')
        return '\n'.join(lines)
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
    """Convert a JSON schema, given as a JSON string, into GBNF grammar text.

    Args:
        schema: the JSON schema, serialized as a string.
        prop_order: optional property names. Object properties named here are
            emitted first, in this order; the rest keep their declaration order.

    Returns:
        The grammar text produced by the converter's ``format_grammar``.
    """
    parsed_schema = json.loads(schema)
    order_index = {}
    for position, prop_name in enumerate(prop_order or []):
        order_index[prop_name] = position
    converter = SchemaConverter(prop_order=order_index, allow_fetch=False, dotall=False, raw_pattern=False)
    resolved_schema = converter.resolve_refs(parsed_schema, "stdin")
    converter.visit(resolved_schema, "")
    return converter.format_grammar()