# internals/pipelines/prompt_modifier.py
from typing import List

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class PromptModifier:
    """Expand a user prompt into several enriched Stable Diffusion prompts
    using the Gustavosta/MagicPrompt-Stable-Diffusion (GPT-2) model."""

    def __init__(self, num_of_sequences: int = 4):
        # Phrases that must never appear in generated prompts; each maps
        # to its replacement (here, the empty string).
        self.__blacklist = {"alphonse mucha": "", "adolphe bouguereau": ""}
        self.__num_of_sequences = num_of_sequences

    def load(self):
        self.prompter_model = AutoModelForCausalLM.from_pretrained(
            "Gustavosta/MagicPrompt-Stable-Diffusion"
        )
        self.prompter_tokenizer = AutoTokenizer.from_pretrained(
            "Gustavosta/MagicPrompt-Stable-Diffusion"
        )
        # GPT-2 has no pad token; reuse EOS and left-pad so generation
        # continues directly from the end of the prompt.
        self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
        self.prompter_tokenizer.padding_side = "left"

    def modify(self, text: str) -> List[str]:
        eos_id = self.prompter_tokenizer.eos_token_id
        # restricted_words_list = ["octane", "cyber"]
        # restricted_words_token_ids = prompter_tokenizer(
        #     restricted_words_list, add_special_tokens=False
        # ).input_ids
        generation_config = GenerationConfig(
            do_sample=False,  # deterministic beam search rather than sampling
            max_new_tokens=75,
            num_beams=4,
            # num_return_sequences must not exceed num_beams (4)
            num_return_sequences=self.__num_of_sequences,
            eos_token_id=eos_id,
            pad_token_id=eos_id,  # GPT-2 has no pad token; reuse EOS
            length_penalty=-1.0,  # negative penalty favors shorter completions
        )
        input_ids = self.prompter_tokenizer(text.strip(), return_tensors="pt").input_ids
        outputs = self.prompter_model.generate(
            input_ids, generation_config=generation_config
        )
        output_texts = self.prompter_tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )
        output_texts = self.__patch_blacklist_words(output_texts)
        return output_texts

    def __patch_blacklist_words(self, texts: List[str]) -> List[str]:
        # Strip blacklisted phrases from each generated prompt.
        # Note: matching is case-sensitive.
        def replace_all(text, dic):
            for i, j in dic.items():
                text = text.replace(i, j)
            return text

        return [replace_all(text, self.__blacklist) for text in texts]
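

# A minimal usage sketch, not part of the original module: it assumes the
# file is run directly and that the MagicPrompt weights can be downloaded
# from the Hugging Face Hub. `load()` must be called before `modify()`.
if __name__ == "__main__":
    modifier = PromptModifier(num_of_sequences=2)
    modifier.load()  # fetches model and tokenizer from the Hub
    for prompt in modifier.modify("a castle on a hill"):
        print(prompt)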