# Hugging Face Spaces status: Runtime error
""" Functions for loading email generator.""" | |
import torch | |
import transformers | |
# Device index for transformers.pipeline: 0 selects the first CUDA GPU when
# one is available; -1 selects the CPU.
DEVICE = 0 if torch.cuda.is_available() else -1

# Model checkpoint used when no model tag is supplied.
DEFAULT_GENERATOR = 'pszemraj/opt-350m-email-generation'

# Model tags callers may choose from. Each tag is listed exactly once so any
# set of choices built from this list contains no duplicates; the default
# model is referenced by name so the two constants cannot drift apart.
AVAILABLE_GENERATORS = [
    DEFAULT_GENERATOR,
    'postbot/gpt2-medium-emailgen',
    'sagorsarker/emailgenerator',
]
class EmailGenerator:
    """Wrapper holding one HuggingFace ``text-generation`` pipeline for emails.

    The pipeline is built once in ``__init__`` and reused by every call to
    ``generate``; ``tag`` remembers which model it was built from.
    """

    def __init__(self, model_tag: str) -> None:
        """Build the text-generation pipeline for ``model_tag``.

        Args:
            model_tag (str): HuggingFace model identifier to load.
        """
        self.tag = model_tag
        # Sampling is disabled and the pipeline runs on the module-level DEVICE.
        pipeline_options = dict(use_fast=True, do_sample=False, device=DEVICE)
        self.generator = transformers.pipeline(
            'text-generation', model_tag, **pipeline_options
        )

    def generate(self, prompt: str, max_tokens: int) -> str:
        """Run the pipeline on ``prompt`` and return the first result's text.

        Args:
            prompt (str): Text used to prompt the email generator.
            max_tokens (int): Forwarded to the pipeline as ``max_length``.
                NOTE(review): in transformers, ``max_length`` bounds prompt
                plus generated tokens, not only newly generated ones. Confirm
                callers don't expect ``max_new_tokens`` semantics.

        Returns:
            str: The ``'generated_text'`` field of the first pipeline result.
        """
        results = self.generator(prompt, max_length=max_tokens)
        first_result = results[0]
        return first_result['generated_text']

    def __str__(self):
        return 'EmailGenerator({})'.format(self.tag)
def set_global_generator(model_tag: str = DEFAULT_GENERATOR):
    """Build an EmailGenerator for ``model_tag`` and bind it to module-level ``generator``.

    Args:
        model_tag (str): HuggingFace model identifier; defaults to DEFAULT_GENERATOR.
    """
    global generator
    # Construct first, then rebind: if construction raises, the previous
    # binding (if any) is left untouched.
    fresh_generator = EmailGenerator(model_tag=model_tag)
    generator = fresh_generator
def generator_exists():
    """Return True once the module-level name ``generator`` has been bound.

    Only the presence of the name in this module's globals is checked, not
    the value bound to it.
    """
    module_namespace = globals()
    return 'generator' in module_namespace
def generate_email(model_tag: str, prompt: str, max_tokens: int):
    """Generate text for ``prompt`` using the model named by ``model_tag``.

    The shared module-level generator is (re)built first when none exists yet
    or when it was built for a different model tag.

    Args:
        model_tag (str): HuggingFace model identifier to generate with.
        prompt (str): Text used to prompt the email generator.
        max_tokens (int): Passed through to ``EmailGenerator.generate``.

    Returns:
        Whatever ``generator.generate`` returns for the prompt.
    """
    needs_reload = True
    if generator_exists():
        # A generator is already loaded; only swap it if the model differs.
        needs_reload = generator.tag != model_tag
    if needs_reload:
        set_global_generator(model_tag=model_tag)
    return generator.generate(prompt, max_tokens=max_tokens)