import re
from typing import Dict, List, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class PromptModifier:
    """Expand a short text prompt into richer Stable Diffusion prompts.

    Uses the ``Gustavosta/MagicPrompt-Stable-Diffusion`` causal language model
    with deterministic beam search, then removes blacklisted phrases from every
    generated prompt.

    Call :meth:`load` to fetch the model and tokenizer up front; otherwise the
    first call to :meth:`modify` loads them on demand.
    """

    # Hugging Face Hub id used for both the model and the tokenizer.
    _MODEL_NAME = "Gustavosta/MagicPrompt-Stable-Diffusion"
    # Beam width used when at most this many sequences are requested.
    _MIN_NUM_BEAMS = 4

    def __init__(self, num_of_sequences: Optional[int] = 4):
        """Create a modifier; the model is not loaded yet.

        Args:
            num_of_sequences: Default number of prompts returned by
                :meth:`modify` when a call does not specify one.
        """
        # Phrase -> replacement text; matched case-insensitively in the output.
        self.__blacklist: Dict[str, str] = {
            "alphonse mucha": "",
            "adolphe bouguereau": "",
        }
        self.__num_of_sequences = num_of_sequences
        # Set by load(); None means "not loaded yet".
        self.prompter_model = None
        self.prompter_tokenizer = None

    def load(self) -> None:
        """Load the MagicPrompt model and tokenizer via ``from_pretrained``."""
        self.prompter_model = AutoModelForCausalLM.from_pretrained(self._MODEL_NAME)
        self.prompter_tokenizer = AutoTokenizer.from_pretrained(self._MODEL_NAME)
        # Reuse EOS as the padding token.
        self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
        # Left padding keeps the real tokens adjacent to the generated ones,
        # which is what a decoder-only model needs for batched generation.
        self.prompter_tokenizer.padding_side = "left"

    def modify(self, text: str, num_of_sequences: Optional[int] = None) -> List[str]:
        """Generate expanded variants of ``text``.

        Args:
            text: Seed prompt; surrounding whitespace is stripped.
            num_of_sequences: Number of prompts to return. Falls back to the
                constructor's value when ``None`` (or 0).

        Returns:
            The decoded generations (for a decoder-only model these contain
            the seed prompt followed by the continuation), with blacklisted
            phrases removed.

        Raises:
            ValueError: If ``text`` is blank or the resolved sequence count is
                not a positive integer.
        """
        prompt = text.strip()
        if not prompt:
            raise ValueError("text must contain a non-whitespace prompt")

        count = num_of_sequences or self.__num_of_sequences
        if count is None or count < 1:
            raise ValueError(
                f"num_of_sequences must be a positive integer, got {count!r}"
            )

        if self.prompter_model is None or self.prompter_tokenizer is None:
            self.load()

        eos_id = self.prompter_tokenizer.eos_token_id
        generation_config = GenerationConfig(
            # Deterministic beam search (no sampling).
            do_sample=False,
            max_new_tokens=75,
            # Beam search cannot return more sequences than it has beams, so
            # widen the beam when more than _MIN_NUM_BEAMS results are wanted.
            num_beams=max(self._MIN_NUM_BEAMS, count),
            num_return_sequences=count,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
            # A negative length penalty favours shorter beams.
            length_penalty=-1.0,
        )

        encoded = self.prompter_tokenizer(prompt, return_tensors="pt")
        outputs = self.prompter_model.generate(
            input_ids=encoded["input_ids"],
            # Pass the mask explicitly: pad and EOS share an id, so generate()
            # cannot reliably infer it from input_ids alone.
            attention_mask=encoded.get("attention_mask"),
            generation_config=generation_config,
        )
        output_texts = self.prompter_tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )
        return self.__patch_blacklist_words(output_texts)

    def __patch_blacklist_words(self, texts: List[str]) -> List[str]:
        """Return ``texts`` with every blacklisted phrase replaced, ignoring case."""
        # Escape phrases so they match literally, not as regex syntax.
        patterns = [
            (re.compile(re.escape(phrase), re.IGNORECASE), replacement)
            for phrase, replacement in self.__blacklist.items()
        ]

        def replace_all(text: str) -> str:
            for pattern, replacement in patterns:
                # A callable replacement keeps any backslashes in it literal.
                text = pattern.sub(lambda _match, r=replacement: r, text)
            return text

        return [replace_all(text) for text in texts]