from transformers import CLIPTokenizer, AutoTokenizer
from ..models import ModelManager
import os

def tokenize_long_prompt(tokenizer, prompt):
    # Remember the tokenizer's native model_max_length (77 for CLIP).
    length = tokenizer.model_max_length
    # Temporarily raise model_max_length to a huge value so that tokenizing
    # a long prompt does not trigger the truncation warning.
    tokenizer.model_max_length = 99999999
    # Tokenize the full prompt without truncation.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Round the token count up to the nearest multiple of `length`,
    # e.g. 100 tokens with length=77 gives max_length=154.
    max_length = (input_ids.shape[1] + length - 1) // length * length
    # Restore the tokenizer's original model_max_length.
    tokenizer.model_max_length = length
    # Tokenize again, padding to the rounded-up length.
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True
    ).input_ids
    # Reshape to (num_sentence, length) so each chunk fits the text encoder.
    num_sentence = input_ids.shape[1] // length
    input_ids = input_ids.reshape((num_sentence, length))
    return input_ids
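
# Example usage (a minimal sketch; the checkpoint name is illustrative and
# not part of this module):
#
#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   input_ids = tokenize_long_prompt(tokenizer, "a cat, " * 60)
#   # input_ids.shape == (num_sentence, 77); each 77-token row can then be
#   # fed to the text encoder separately.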

class BeautifulPrompt:
    def __init__(self, tokenizer_path=None, model=None):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model
        self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'

    def __call__(self, raw_prompt):
        model_input = self.template.format(raw_prompt=raw_prompt)
        input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=384,
            do_sample=True,
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.1,
            num_return_sequences=1
        )
        # Decode only the newly generated tokens and append them to the raw prompt.
        prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
            outputs[:, input_ids.size(1):],
            skip_special_tokens=True
        )[0].strip()
        return prompt
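
# Example usage (a minimal sketch; `model` is assumed to be a causal LM
# already loaded elsewhere, and the path is illustrative):
#
#   bp = BeautifulPrompt(tokenizer_path="models/BeautifulPrompt", model=model)
#   bp("a cat")  # -> "a cat, <sampled embellishments>"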

class Translator:
    def __init__(self, tokenizer_path=None, model=None):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model

    def __call__(self, prompt):
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
        output_ids = self.model.generate(input_ids)
        prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        return prompt
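
# Example usage (a minimal sketch; `model` is assumed to be a seq2seq
# translation model, e.g. Chinese-to-English, loaded elsewhere):
#
#   translator = Translator(tokenizer_path="models/translator", model=model)
#   translator("一只可爱的猫")  # -> e.g. "a cute cat"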

class Prompter:
    def __init__(self):
        self.tokenizer: CLIPTokenizer = None
        self.keyword_dict = {}
        self.translator: Translator = None
        self.beautiful_prompt: BeautifulPrompt = None

    def load_textual_inversion(self, textual_inversion_dict):
        # Map each trigger keyword to its placeholder tokens, and register
        # those tokens with the CLIP tokenizer.
        self.keyword_dict = {}
        additional_tokens = []
        for keyword in textual_inversion_dict:
            tokens, _ = textual_inversion_dict[keyword]
            additional_tokens += tokens
            self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
        if self.tokenizer is not None:
            self.tokenizer.add_tokens(additional_tokens)

    def load_beautiful_prompt(self, model, model_path):
        model_folder = os.path.dirname(model_path)
        self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
        # The v2 variant expects a more detailed instruction template.
        if model_folder.endswith("v2"):
            self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""

    def load_translator(self, model, model_path):
        model_folder = os.path.dirname(model_path)
        self.translator = Translator(tokenizer_path=model_folder, model=model)

    def load_from_model_manager(self, model_manager: ModelManager):
        self.load_textual_inversion(model_manager.textual_inversion_dict)
        if "translator" in model_manager.model:
            self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
        if "beautiful_prompt" in model_manager.model:
            self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])

    def add_textual_inversion_tokens(self, prompt):
        # Replace each trigger keyword with its placeholder tokens.
        for keyword in self.keyword_dict:
            if keyword in prompt:
                prompt = prompt.replace(keyword, self.keyword_dict[keyword])
        return prompt

    def del_textual_inversion_tokens(self, prompt):
        # Strip trigger keywords to obtain a "pure" prompt.
        for keyword in self.keyword_dict:
            if keyword in prompt:
                prompt = prompt.replace(keyword, "")
        return prompt
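
    # Example (a minimal sketch; the keyword and tokens are hypothetical):
    # with keyword_dict == {"<cat>": " tok1 tok2 "},
    #   add_textual_inversion_tokens("a <cat> photo") -> "a  tok1 tok2  photo"
    #   del_textual_inversion_tokens("a <cat> photo") -> "a  photo"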

    def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
        # Handle batched prompts by processing each one recursively.
        if isinstance(prompt, list):
            prompt = [self.process_prompt(prompt_, positive=positive, require_pure_prompt=require_pure_prompt) for prompt_ in prompt]
            if require_pure_prompt:
                prompt, pure_prompt = [i[0] for i in prompt], [i[1] for i in prompt]
                return prompt, pure_prompt
            else:
                return prompt
        prompt, pure_prompt = self.add_textual_inversion_tokens(prompt), self.del_textual_inversion_tokens(prompt)
        # Translation and beautification are applied to positive prompts only.
        if positive and self.translator is not None:
            prompt = self.translator(prompt)
            print(f"Your prompt is translated: \"{prompt}\"")
        if positive and self.beautiful_prompt is not None:
            prompt = self.beautiful_prompt(prompt)
            print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
        if require_pure_prompt:
            return prompt, pure_prompt
        else:
            return prompt
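
# Example usage (a minimal sketch; assumes a ModelManager whose `model` dict
# contains "translator" and/or "beautiful_prompt" entries, as checked above):
#
#   prompter = Prompter()
#   prompter.load_from_model_manager(model_manager)
#   prompt = prompter.process_prompt("一只可爱的猫")
#   prompt, pure_prompt = prompter.process_prompt("a cat", require_pure_prompt=True)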