Dial model back
main.py CHANGED
@@ -13,8 +13,8 @@ from starlette.middleware.cors import CORSMiddleware

 load_dotenv()

-model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-
-tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-
+model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill", cache_dir="new_cache_dir/")
+tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill", cache_dir="new_cache_dir/")
 class UserMSGRequest(BaseModel):
     message: str

@@ -51,7 +51,7 @@ async def root(utterance: UserMSGRequest):
     inputs = tokenizer(utterance.message, return_tensors = "pt")
     results = model.generate(**inputs)
     response = tokenizer.batch_decode(results, skip_special_tokens=True)[0]
-    r.set(utterance.message, response,
+    r.set(utterance.message, response, 250)
     return response

 if __name__ == '__main__':
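For context, a minimal sketch of how main.py might read after this commit. Only the model/tokenizer loading and the body of root() are taken verbatim from the diff; the imports, the CORS and Redis wiring, the POST route path, and the uvicorn port are assumptions filled in for illustration. In redis-py, the third positional argument to set() is the ex expiry in seconds, so r.set(utterance.message, response, 250) caches the generated reply under the user's message text for 250 seconds.

# Sketch of main.py after this commit (assumptions marked in comments).
import os

import redis
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizer

load_dotenv()

# "Dial model back": load the smaller 400M distilled checkpoint into a local cache dir.
model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill", cache_dir="new_cache_dir/")
tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill", cache_dir="new_cache_dir/")

class UserMSGRequest(BaseModel):
    message: str

app = FastAPI()
# Assumed CORS setup; the diff only shows that CORSMiddleware is imported.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# Assumed Redis connection built from environment variables loaded by load_dotenv().
r = redis.Redis(
    host=os.getenv("REDIS_HOST", "localhost"),
    port=int(os.getenv("REDIS_PORT", "6379")),
    password=os.getenv("REDIS_PASSWORD"),
    decode_responses=True,
)

@app.post("/")  # assumed route; the diff only shows the handler signature
async def root(utterance: UserMSGRequest):
    inputs = tokenizer(utterance.message, return_tensors="pt")
    results = model.generate(**inputs)
    response = tokenizer.batch_decode(results, skip_special_tokens=True)[0]
    # redis-py set(name, value, ex): third positional argument is the expiry
    # in seconds, so each generated reply is cached for 250 seconds.
    r.set(utterance.message, response, 250)
    return response

if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=7860)  # 7860 is an assumption (common Spaces port)

As shown in the diff, the handler only writes to the cache; reading it back (e.g. an r.get(utterance.message) check before generating) would be the natural complement, but that is not part of this change.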