File size: 5,996 Bytes
fb4fac3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import re

from transformers import CLIPTokenizer, AutoTokenizer

from ..models import ModelManager


def tokenize_long_prompt(tokenizer, prompt):
    """Tokenize ``prompt`` into rows of exactly ``tokenizer.model_max_length`` ids.

    The prompt is tokenized once without a length cap to measure it. That length
    is rounded up to a whole number of rows. The prompt is then re-tokenized
    with padding/truncation to that size and reshaped so that each row is one
    ``model_max_length`` window.

    Args:
        tokenizer: callable tokenizer exposing a mutable ``model_max_length``
            attribute and returning an object with ``input_ids``
            (Hugging Face style).
        prompt: a single prompt string. NOTE(review): a list of prompts would
            produce a batch dimension > 1 and the reshape below would mix
            prompts together — callers appear to pass one string at a time.

    Returns:
        ``input_ids`` tensor of shape ``(num_sentence, model_max_length)``,
        with at least one row.
    """
    # Window size of one row.
    length = tokenizer.model_max_length

    # Temporarily lift the limit so the measuring pass does not warn about an
    # over-long sequence. Always restore it (even if tokenization fails) so the
    # caller's tokenizer is never left in the modified state.
    tokenizer.model_max_length = 99999999
    try:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    finally:
        tokenizer.model_max_length = length

    # Round the measured length up to a multiple of ``length``; never zero rows.
    max_length = max(length, (input_ids.shape[1] + length - 1) // length * length)

    # Tokenize again with a fixed, padded/truncated length.
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True
    ).input_ids

    # Split the single long sequence into ``length``-sized rows.
    num_sentence = input_ids.shape[1] // length
    return input_ids.reshape((num_sentence, length))


class BeautifulPrompt:
    """Extend a short prompt with text sampled from a prompt-refinement model.

    The raw prompt is substituted into ``self.template``. The model samples a
    continuation, and the decoded continuation (prompt tokens excluded) is
    appended to the raw prompt after ``", "``.
    """

    def __init__(self, tokenizer_path=None, model=None):
        # Tokenizer comes from ``tokenizer_path``; the model is supplied already loaded.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model
        self.template = (
            "Instruction: Give a simple description of the image to generate a drawing prompt.\n"
            "Input: {raw_prompt}\n"
            "Output:"
        )

    def __call__(self, raw_prompt):
        """Return ``raw_prompt`` followed by ``", "`` and the model's sampled continuation."""
        query = self.template.format(raw_prompt=raw_prompt)
        prompt_ids = self.tokenizer.encode(query, return_tensors='pt')
        prompt_ids = prompt_ids.to(self.model.device)

        sampling = dict(
            max_new_tokens=384,
            do_sample=True,
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.1,
            num_return_sequences=1,
        )
        generated = self.model.generate(prompt_ids, **sampling)

        # Drop the echoed prompt tokens; keep only the newly generated part.
        prompt_len = prompt_ids.size(1)
        decoded = self.tokenizer.batch_decode(generated[:, prompt_len:], skip_special_tokens=True)
        continuation = decoded[0].strip()
        return ", ".join((raw_prompt, continuation))
    

class Translator:
    """Rewrite (translate) a prompt with a text-generation model.

    NOTE(review): the whole generated sequence is decoded without slicing off
    the input, which suggests an encoder-decoder model — confirm against the
    model loaded by the caller.
    """

    def __init__(self, tokenizer_path=None, model=None):
        # Tokenizer comes from ``tokenizer_path``; the model is supplied already loaded.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model

    def __call__(self, prompt):
        """Return the first decoded sequence generated for ``prompt``."""
        token_ids = self.tokenizer.encode(prompt, return_tensors='pt')
        on_device = token_ids.to(self.model.device)
        generated = self.model.generate(on_device)
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        return decoded[0]
    

class Prompter:
    def __init__(self):
        self.tokenizer: CLIPTokenizer = None
        self.keyword_dict = {}
        self.translator: Translator = None
        self.beautiful_prompt: BeautifulPrompt = None

    def load_textual_inversion(self, textual_inversion_dict):
        self.keyword_dict = {}
        additional_tokens = []
        for keyword in textual_inversion_dict:
            tokens, _ = textual_inversion_dict[keyword]
            additional_tokens += tokens
            self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
        if self.tokenizer is not None:
            self.tokenizer.add_tokens(additional_tokens)

    def load_beautiful_prompt(self, model, model_path):
        model_folder = os.path.dirname(model_path)
        self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
        if model_folder.endswith("v2"):
            self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""

    def load_translator(self, model, model_path):
        model_folder = os.path.dirname(model_path)
        self.translator = Translator(tokenizer_path=model_folder, model=model)

    def load_from_model_manager(self, model_manager: ModelManager):
        self.load_textual_inversion(model_manager.textual_inversion_dict)
        if "translator" in model_manager.model:
            self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
        if "beautiful_prompt" in model_manager.model:
            self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])

    def add_textual_inversion_tokens(self, prompt):
        for keyword in self.keyword_dict:
            if keyword in prompt:
                prompt = prompt.replace(keyword, self.keyword_dict[keyword])
        return prompt

    def del_textual_inversion_tokens(self, prompt):
        for keyword in self.keyword_dict:
            if keyword in prompt:
                prompt = prompt.replace(keyword, "")
        return prompt

    def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
        if isinstance(prompt, list):
            prompt = [self.process_prompt(prompt_, positive=positive, require_pure_prompt=require_pure_prompt) for prompt_ in prompt]
            if require_pure_prompt:
                prompt, pure_prompt = [i[0] for i in prompt], [i[1] for i in prompt]
                return prompt, pure_prompt
            else:
                return prompt
        prompt, pure_prompt = self.add_textual_inversion_tokens(prompt), self.del_textual_inversion_tokens(prompt)
        if positive and self.translator is not None:
            prompt = self.translator(prompt)
            print(f"Your prompt is translated: \"{prompt}\"")
        if positive and self.beautiful_prompt is not None:
            prompt = self.beautiful_prompt(prompt)
            print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
        if require_pure_prompt:
            return prompt, pure_prompt
        else:
            return prompt