|
import re
from typing import List, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
|
|
|
class PromptModifier:
    """Expand a short text prompt into several richer prompts.

    Wraps a causal language model (MagicPrompt by default) that continues
    the user's prompt.  Blacklisted phrases are removed from the generated
    text.
    """

    # Hugging Face model id used when ``load`` is called without arguments.
    _DEFAULT_MODEL = "Gustavosta/MagicPrompt-Stable-Diffusion"
    # Beam width used for generation; widened when more sequences are requested,
    # because beam search cannot return more sequences than it has beams.
    _MIN_BEAMS = 4

    def __init__(self, num_of_sequences: Optional[int] = 4):
        """Configure the modifier.

        Args:
            num_of_sequences: How many prompts ``modify`` returns.  ``None``
                is treated as 1.

        Raises:
            ValueError: If ``num_of_sequences`` is less than 1.
        """
        if num_of_sequences is None:
            num_of_sequences = 1
        if num_of_sequences < 1:
            raise ValueError(
                f"num_of_sequences must be >= 1, got {num_of_sequences}"
            )

        # Case-insensitive phrase -> replacement map applied to every output.
        self.__blacklist = {"alphonse mucha": "", "adolphe bouguereau": ""}
        self.__num_of_sequences = num_of_sequences

        # Populated by ``load``; ``modify`` loads lazily if these are still None.
        self.prompter_model = None
        self.prompter_tokenizer = None

    def load(self, model_name: str = _DEFAULT_MODEL):
        """Load the causal LM and its tokenizer via ``from_pretrained``.

        Args:
            model_name: Hugging Face model id or local path.  Defaults to the
                MagicPrompt Stable Diffusion model.
        """
        self.prompter_model = AutoModelForCausalLM.from_pretrained(model_name)
        self.prompter_tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Reuse EOS as the pad token (presumably the tokenizer has none of
        # its own) and pad on the left, as decoder-only generation expects.
        self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
        self.prompter_tokenizer.padding_side = "left"

    def modify(self, text: str) -> List[str]:
        """Generate ``num_of_sequences`` expanded versions of ``text``.

        Uses deterministic beam search.  For a decoder-only model the
        generated ids include the prompt, so each returned string starts
        with the (stripped) input text.  Blacklisted phrases are removed
        before returning.

        Calls ``load()`` first if the model has not been loaded yet.
        """
        if self.prompter_model is None or self.prompter_tokenizer is None:
            self.load()

        eos_id = self.prompter_tokenizer.eos_token_id
        generation_config = GenerationConfig(
            do_sample=False,
            max_new_tokens=75,
            # num_return_sequences may not exceed num_beams in beam search.
            num_beams=max(self._MIN_BEAMS, self.__num_of_sequences),
            num_return_sequences=self.__num_of_sequences,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
            # A negative length penalty favours shorter beams.
            length_penalty=-1.0,
        )

        encoded = self.prompter_tokenizer(text.strip(), return_tensors="pt")
        device = self.prompter_model.device
        outputs = self.prompter_model.generate(
            input_ids=encoded.input_ids.to(device),
            # pad == eos, so the mask cannot be inferred; pass it explicitly.
            attention_mask=encoded.attention_mask.to(device),
            generation_config=generation_config,
        )

        output_texts = self.prompter_tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )
        return self.__patch_blacklist_words(output_texts)

    def __patch_blacklist_words(self, texts: List[str]) -> List[str]:
        """Apply every blacklist substitution to each text.

        Matching is case-insensitive, so capitalised variants such as
        "Alphonse Mucha" are replaced too.  Entries are applied in dict
        order, each one to the result of the previous substitution.
        """
        patterns = [
            (re.compile(re.escape(term), re.IGNORECASE), replacement)
            for term, replacement in self.__blacklist.items()
        ]

        def scrub(text: str) -> str:
            for pattern, replacement in patterns:
                # A callable replacement keeps re from interpreting
                # backslashes or group references in the replacement text.
                text = pattern.sub(lambda _match, r=replacement: r, text)
            return text

        return [scrub(text) for text in texts]
|
|