File size: 2,070 Bytes
465f2e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
""" Functions for loading email generator."""
import torch
import transformers

# Use the first CUDA device when available; -1 tells HF pipelines to run on CPU.
DEVICE = 0 if torch.cuda.is_available() else -1

# Model tags this module knows how to load. (Fixed: the opt-350m tag was
# listed twice in the original.)
AVAILABLE_GENERATORS = [
    'pszemraj/opt-350m-email-generation',
    'postbot/gpt2-medium-emailgen',
    'sagorsarker/emailgenerator',
]
DEFAULT_GENERATOR = 'pszemraj/opt-350m-email-generation'


class EmailGenerator:
    """Thin wrapper around a HuggingFace text-generation pipeline for emails."""

    def __init__(self, model_tag: str) -> None:
        """Load the text-generation pipeline for *model_tag*.

        Args:
            model_tag (str): HuggingFace model identifier.
        """
        self.tag = model_tag
        # Greedy decoding (do_sample=False), fast tokenizer; DEVICE selects
        # GPU 0 when CUDA is available, otherwise CPU (-1).
        self.generator = transformers.pipeline(
            task='text-generation',
            model=model_tag,
            use_fast=True,
            do_sample=False,
            device=DEVICE,
        )

    def generate(self, prompt: str, max_tokens: int) -> str:
        """Generate a sample continuation for *prompt*.

        Args:
            prompt (str): Prompt text to condition the generator on.
            max_tokens (int): Passed to the pipeline as ``max_length``.
                NOTE(review): ``max_length`` appears to count prompt tokens
                as well as generated ones — confirm against callers.

        Returns:
            str: The generated text of the first (only) candidate.
        """
        candidates = self.generator(prompt, max_length=max_tokens)
        first = candidates[0]
        return first['generated_text']

    def __str__(self):
        return f'EmailGenerator({self.tag})'

def set_global_generator(model_tag: str = DEFAULT_GENERATOR):
    """Bind the module-level ``generator`` to a freshly built EmailGenerator.

    Args:
        model_tag (str): Model name; defaults to ``DEFAULT_GENERATOR``.
    """
    global generator
    generator = EmailGenerator(model_tag=model_tag)
        
def generator_exists():
    """Return True if the module-level ``generator`` variable is defined."""
    module_names = globals()
    return 'generator' in module_names

def generate_email(model_tag: str, prompt: str, max_tokens: int):
    """Generate an email, (re)loading the global generator when needed.

    The module-level generator is rebuilt whenever it is missing or was
    constructed from a different model tag; otherwise it is reused.

    Args:
        model_tag (str): Model name to generate with.
        prompt (str): Prompting text for the email generator.
        max_tokens (int): Maximum number of tokens to return.

    Returns:
        str: Generated email text.
    """
    needs_reload = not generator_exists() or generator.tag != model_tag
    if needs_reload:
        set_global_generator(model_tag=model_tag)

    return generator.generate(prompt, max_tokens=max_tokens)