File size: 2,759 Bytes
79e3511
 
d80fbc9
79e3511
 
 
 
 
 
 
 
d80fbc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79e3511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d80fbc9
79e3511
 
 
 
 
 
d80fbc9
 
 
 
 
79e3511
d80fbc9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import logging
import json
from typing import Dict, List
from pydantic.dataclasses import dataclass

from transformers import PreTrainedTokenizerFast
from tokenizers.decoders import Decoder

logger = logging.getLogger(__name__)


# fmt: off
# Jinja2 chat template (see link below). It renders a flat prompt of the form
#   <|bos|><rating>...</rating><copyright>...</copyright><character>...</character><general>...
# where each section falls back to a default ("rating:sfw, rating:general" for
# rating, empty string otherwise) when the key is absent or None in `messages`.
# NOTE: `messages` here is expected to be a dict-like mapping of section name ->
# tag string, not the usual list-of-messages chat format.
# https://huggingface.co/docs/transformers/main/en/chat_templating
PROMPT_TEMPLATE = (
    "{{ '<|bos|>' }}"

    "{{ '<rating>' }}"
    "{% if 'rating' not in messages or messages['rating'] is none %}"
    "{{ 'rating:sfw, rating:general' }}"
    "{% else %}"
    "{{ messages['rating'] }}"
    "{% endif %}"
    "{{ '</rating>' }}"

    "{{ '<copyright>' }}"
    "{% if 'copyright' not in messages or messages['copyright'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['copyright'] }}"
    "{% endif %}"
    "{{ '</copyright>' }}"

    "{{ '<character>' }}"
    "{% if 'character' not in messages or messages['character'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['character'] }}"
    "{% endif %}"
    "{{ '</character>' }}"

    # <general> is intentionally left unclosed: generation continues from here.
    "{{ '<general>' }}"
    "{% if 'general' not in messages or messages['general'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['general'] }}"
    "{% endif %}"
).strip()
# fmt: on


@dataclass
class Category:
    """A single tag category and the special-token ids that delimit it."""

    # Category name (e.g. "rating", "copyright", "character", "general")
    name: str
    # Token id of the category's opening special token (e.g. <rating>)
    bos_token_id: int
    # Token id of the category's closing special token (e.g. </rating>)
    eos_token_id: int


@dataclass
class TagCategoryConfig:
    """Validated mapping of tag categories to their member token ids."""

    # Category key -> Category record (name + delimiter token ids)
    categories: Dict[str, Category]
    # Category key -> ids of all tag tokens belonging to that category
    category_to_token_ids: Dict[str, List[int]]


def load_tag_category_config(config_json: str) -> TagCategoryConfig:
    """Load and validate a tag-category configuration from a JSON file.

    Args:
        config_json: Path to a JSON file whose top-level object matches the
            ``TagCategoryConfig`` schema.

    Returns:
        The parsed and pydantic-validated ``TagCategoryConfig``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Open as UTF-8 text and stream-parse instead of reading raw bytes into
    # memory and re-parsing with json.loads.
    with open(config_json, "r", encoding="utf-8") as file:
        config = TagCategoryConfig(**json.load(file))

    return config


class DartDecoder:
    """Custom tokenizers decoder that joins tag tokens with ", ".

    Adjacent ordinary tokens are separated by ", "; no separator is inserted
    immediately before or after a special token (category delimiters etc.).
    """

    def __init__(self, special_tokens: List[str]):
        # Public attribute kept as a list for backward compatibility.
        self.special_tokens = list(special_tokens)
        # Set for O(1) membership tests in decode_chain (was an O(n) list scan
        # per token).
        self._special_token_set = set(self.special_tokens)

    def decode_chain(self, tokens: List[str]) -> List[str]:
        """Return tokens with ", " prefixed to each token that follows an
        ordinary (non-special) token, leaving special-token boundaries bare."""
        new_tokens: List[str] = []
        prev_is_special = False

        for i, token in enumerate(tokens):
            is_special = token in self._special_token_set

            # No separator for the first token, or when this token or the
            # previous token is special.
            if i == 0 or is_special or prev_is_special:
                new_tokens.append(token)
            else:
                new_tokens.append(f", {token}")

            prev_is_special = is_special

        return new_tokens


class DartTokenizer(PreTrainedTokenizerFast):
    """Dart tokenizer"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Install a custom decoder so decoded tag sequences are joined with
        # ", " except around special tokens.
        special_tokens = list(self.get_added_vocab().keys())
        custom_decoder = DartDecoder(special_tokens)
        self._tokenizer.decoder = Decoder.custom(custom_decoder)  # type: ignore

    @property
    def default_chat_template(self):
        """
        Danbooru Tags Transformer uses special format prompt to generate danbooru tags.
        """

        return PROMPT_TEMPLATE