Ehren12 commited on
Commit
5e2686d
·
1 Parent(s): 97d741c

Dial model back

Browse files
Files changed (1) hide show
  1. main.py +3 -3
main.py CHANGED
@@ -13,8 +13,8 @@ from starlette.middleware.cors import CORSMiddleware
13
 
14
  load_dotenv()
15
 
16
- model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-3B", cache_dir="new_cache_dir/")
17
- tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B", cache_dir="new_cache_dir/")
18
  class UserMSGRequest(BaseModel):
19
  message: str
20
 
@@ -51,7 +51,7 @@ async def root(utterance: UserMSGRequest):
51
  inputs = tokenizer(utterance.message, return_tensors = "pt")
52
  results = model.generate(**inputs)
53
  response = tokenizer.batch_decode(results, skip_special_tokens=True)[0]
54
- r.set(utterance.message, response, 300)
55
  return response
56
 
57
  if __name__ == '__main__':
 
13
 
14
  load_dotenv()
15
 
16
+ model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill", cache_dir="new_cache_dir/")
17
+ tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill", cache_dir="new_cache_dir/")
18
  class UserMSGRequest(BaseModel):
19
  message: str
20
 
 
51
  inputs = tokenizer(utterance.message, return_tensors = "pt")
52
  results = model.generate(**inputs)
53
  response = tokenizer.batch_decode(results, skip_special_tokens=True)[0]
54
+ r.set(utterance.message, response, 250)
55
  return response
56
 
57
  if __name__ == '__main__':