from transformers import AutoTokenizer, TFGPT2LMHeadModel, pipeline
from transformers.pipelines import TextGenerationPipeline
from typing import Union


class QuoteGenerator:
    """Generate quotes conditioned on comma-separated tags with a TF GPT-2 model.

    The tokenizer and model are loaded with ``from_pretrained`` when the object
    is constructed. The text-generation pipeline is built by
    :meth:`load_generator`; :meth:`generate_quote` builds it on first use if it
    has not been built yet.
    """

    def __init__(self, model_name: str = 'gruhit13/quote-generator-v2'):
        """Load the tokenizer and TensorFlow GPT-2 model for ``model_name``.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.model_name = model_name
        # Assigned explicitly: a bare annotation does not create the attribute,
        # which previously made generate_quote raise AttributeError whenever
        # load_generator had not been called first.
        self.quote_generator: Union[None, TextGenerationPipeline] = None
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = TFGPT2LMHeadModel.from_pretrained(self.model_name)
        # Tags used by preprocess_tags when the caller passes None.
        self.default_tags = 'love,life'
        print("Model has been loaded")

    def load_generator(self) -> None:
        """Build the text-generation pipeline from the loaded model and tokenizer."""
        self.quote_generator = pipeline('text-generation', model=self.model, tokenizer=self.tokenizer)
        print("Pipeline has been generated")

    def preprocess_tags(self, tags: Union[None, str] = None) -> str:
        """Build the generation prompt: the tokenizer's BOS token, the tags, then ':'.

        Args:
            tags: Comma-separated tag string; ``None`` selects ``self.default_tags``.

        Returns:
            The prompt string fed to the pipeline.
        """
        if tags is None:
            tags = self.default_tags
        # NOTE(review): assumes the tokenizer defines a string bos_token; a
        # tokenizer without one would make this concatenation fail.
        return self.tokenizer.bos_token + tags + ':'

    def generate_quote(self, tags: Union[None, str], max_new_tokens: int, do_sample: bool,
                       num_beams: int, top_k: int, top_p: float, temperature: float) -> str:
        """Generate a single quote for ``tags``.

        The pipeline is built lazily if :meth:`load_generator` has not been
        called yet.

        Args:
            tags: Comma-separated tags, or ``None`` for the default tags.
            max_new_tokens, do_sample, num_beams, top_k, top_p, temperature:
                Forwarded unchanged to the text-generation pipeline call.

        Returns:
            The ``'generated_text'`` of the first pipeline output.
            NOTE(review): with the pipeline's default settings this text
            presumably still contains the prompt (BOS token + tags + ':') —
            confirm before displaying it to users.
        """
        if self.quote_generator is None:
            self.load_generator()
        prompt = self.preprocess_tags(tags)
        output = self.quote_generator(
            prompt,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
        )
        return output[0]['generated_text']