|
from transformers import AutoTokenizer, TFGPT2LMHeadModel, pipeline |
|
from transformers.pipelines import TextGenerationPipeline |
|
from typing import Union |
|
|
|
class QuoteGenerator:
    """Generate quotes from comma-separated tags with a TF GPT-2 LM-head model.

    The tokenizer and model are loaded with ``from_pretrained`` in ``__init__``.
    The text-generation pipeline is built by :meth:`load_generator`, or lazily
    on the first :meth:`generate_quote` call if it has not been built yet.
    """

    def __init__(self, model_name: str = 'gruhit13/quote_generator_v1'):
        """Load the tokenizer and ``TFGPT2LMHeadModel`` for *model_name*.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.model_name = model_name
        # Built by load_generator(). Start as None so generate_quote() can
        # detect a missing pipeline instead of raising AttributeError.
        self.quote_generator: Union[TextGenerationPipeline, None] = None
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = TFGPT2LMHeadModel.from_pretrained(self.model_name)
        # Tags used by preprocess_tags() when the caller passes tags=None.
        self.default_tags = 'love,life'
        print("Model has been loaded")

    def load_generator(self) -> None:
        """Build the 'text-generation' pipeline from the loaded model/tokenizer."""
        self.quote_generator = pipeline('text-generation', model=self.model, tokenizer=self.tokenizer)
        print("Pipeline has been generated")

    def preprocess_tags(self, tags: Union[None, str] = None) -> str:
        """Build the model prompt for *tags*.

        Args:
            tags: Tag string (e.g. ``'love,life'``); ``None`` selects
                ``self.default_tags``.

        Returns:
            ``bos_token + tags + '<bot>:'``. NOTE(review): this presumably
            mirrors the prompt format the model was fine-tuned on; confirm.
        """
        if tags is None:
            tags = self.default_tags
        return self.tokenizer.bos_token + tags + '<bot>:'

    def generate_quote(self, tags: Union[None, str] = None,
                       min_length: int = 3, max_length: int = 60,
                       top_p: float = 0.9, top_k: int = 5) -> str:
        """Generate a quote for *tags*.

        Builds the pipeline on first use if :meth:`load_generator` has not
        been called yet.

        Args:
            tags: Tag string; ``None`` uses ``self.default_tags``.
            min_length: Forwarded to the pipeline as ``min_length``.
            max_length: Forwarded to the pipeline as ``max_length``.
            top_p: Forwarded to the pipeline as ``top_p``.
            top_k: Forwarded to the pipeline as ``top_k`` (previously ignored
                in favour of a hard-coded 5).

        Returns:
            ``generated_text`` of the first pipeline result. NOTE(review): with
            the pipeline's default settings this likely still contains the
            prompt (bos token, tags, ``<bot>:``); confirm whether callers
            strip it.
        """
        generator = getattr(self, 'quote_generator', None)
        if generator is None:
            # Lazily build the pipeline rather than crashing when the caller
            # forgot to call load_generator() first.
            self.load_generator()
            generator = self.quote_generator

        prompt = self.preprocess_tags(tags)
        # NOTE(review): num_beams=4 without do_sample=True likely selects beam
        # search, in which case temperature/top_k/top_p have no effect (and
        # recent transformers versions warn). Left unchanged to preserve the
        # current output; confirm whether sampling was intended.
        output = generator(prompt, min_length=min_length, max_length=max_length,
                           temperature=1.0, top_k=top_k, top_p=top_p,
                           early_stopping=True, num_beams=4)
        return output[0]['generated_text']
|
|
|
|