gruhit-patel commited on
Commit
7bbb8c5
1 Parent(s): 8bc8944

Updated backend to handled text-generation parameters

Browse files
Files changed (2) hide show
  1. QuoteGenerator.py +12 -7
  2. main.py +19 -6
QuoteGenerator.py CHANGED
@@ -3,7 +3,7 @@ from transformers.pipelines import TextGenerationPipeline
3
  from typing import Union
4
 
5
  class QuoteGenerator():
6
- def __init__(self, model_name:str='gruhit13/quote_generator_v1'):
7
  self.model_name = model_name
8
  self.quote_generator: TextGenerationPipeline
9
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
@@ -21,14 +21,19 @@ class QuoteGenerator():
21
 
22
  return self.tokenizer.bos_token + tags + '<bot>:'
23
 
24
- def generate_quote(self, tags:Union[None, str]=None,
25
- min_length: int=3, max_length:int=60,
26
- top_p:float=0.9, top_k:int=5):
27
 
28
  tags = self.preprocess_tags(tags)
29
- output = self.quote_generator(tags, min_length=min_length, max_length=max_length,
30
- temperature=1.0, top_k=5, top_p=top_p, early_stopping=True,
31
- num_beams=4)
 
 
 
 
 
 
32
 
33
  return output[0]['generated_text']
34
 
 
3
  from typing import Union
4
 
5
  class QuoteGenerator():
6
+ def __init__(self, model_name:str='gruhit13/quote-generator-v2'):
7
  self.model_name = model_name
8
  self.quote_generator: TextGenerationPipeline
9
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
21
 
22
  return self.tokenizer.bos_token + tags + '<bot>:'
23
 
24
+ def generate_quote(self, tags:Union[None, str], max_new_tokens: int, do_sample: bool,
25
+ num_beams: int, top_k: int, top_p: float, temperature: float):
 
26
 
27
  tags = self.preprocess_tags(tags)
28
+ output = self.quote_generator(
29
+ tags,
30
+ max_new_tokens=max_new_tokens,
31
+ num_beams=num_beams,
32
+ temperature=temperature,
33
+ top_k=top_k,
34
+ top_p=top_p,
35
+ do_sample = do_sample
36
+ )
37
 
38
  return output[0]['generated_text']
39
 
main.py CHANGED
@@ -8,7 +8,6 @@ import os
8
 
9
  # API to key to validate the Referer
10
  API_KEY = os.getenv('API_KEY')
11
-
12
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
13
 
14
  # Function to check of the incoming API call is from valid host or not
@@ -20,7 +19,13 @@ def api_key_auth(api_key:str = Depends(oauth2_scheme)):
20
  )
21
 
22
  class QuoteRequest(BaseModel):
23
- tags: Union[None, str]
 
 
 
 
 
 
24
 
25
  app = FastAPI()
26
 
@@ -37,11 +42,19 @@ quote_generator = QuoteGenerator()
37
  quote_generator.load_generator()
38
 
39
  @app.post("/", dependencies=[Depends(api_key_auth)])
40
- def root(request: Request):
41
- return {"message": "This is the website for quote-generator"}
 
42
 
43
  @app.post("/generate_quote", dependencies=[Depends(api_key_auth)])
44
  def generate_quote(req: QuoteRequest):
45
- print("Tags: ", req.tags)
46
- generated_quote_oup = quote_generator.generate_quote(req.tags)
 
 
 
 
 
 
 
47
  return {'quote': generated_quote_oup}
 
8
 
9
  # API to key to validate the Referer
10
  API_KEY = os.getenv('API_KEY')
 
11
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
12
 
13
  # Function to check of the incoming API call is from valid host or not
 
19
  )
20
 
21
  class QuoteRequest(BaseModel):
22
+ tags: Union[None, str] = None
23
+ do_sample: bool = False
24
+ max_new_tokens: int = 16
25
+ num_beams: int = 1
26
+ top_k: int = 50
27
+ top_p: float = 1.0
28
+ temperature: float = 1.0
29
 
30
  app = FastAPI()
31
 
 
42
  quote_generator.load_generator()
43
 
44
  @app.post("/", dependencies=[Depends(api_key_auth)])
45
+ def root(request: QuoteRequest):
46
+ print("Incoming request\n", request.__dict__)
47
+ return {"quote": "<bot>:A beautiful quote generated by bot"}
48
 
49
  @app.post("/generate_quote", dependencies=[Depends(api_key_auth)])
50
  def generate_quote(req: QuoteRequest):
51
+ generated_quote_oup = quote_generator.generate_quote(
52
+ tags = req.tags,
53
+ max_new_tokens = req.max_new_tokens,
54
+ num_beams = req.num_beams,
55
+ temperature = req.temperature,
56
+ top_k = req.top_k,
57
+ top_p = req.top_p,
58
+ do_sample = req.do_sample
59
+ )
60
  return {'quote': generated_quote_oup}