Deepakvictor
commited on
Commit
•
3da6512
1
Parent(s):
32681ab
Create handler.py
Browse files- handler.py +17 -0
handler.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
3 |
+
import torch
|
4 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
5 |
+
class EndpointHandler:
|
6 |
+
def __init__(self, path=""):
|
7 |
+
self.tokenizer = AutoTokenizer.from_pretrained(path).to(device)
|
8 |
+
self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device))
|
9 |
+
|
10 |
+
def __call__(self, data: str) -> str:
|
11 |
+
inp = self.tokenizer(data, return_tensors="pt")
|
12 |
+
for q in inp:
|
13 |
+
inp[q] = inp[q].to(device)
|
14 |
+
with torch.inference_mode():
|
15 |
+
out= model.generate(**inp)
|
16 |
+
final_output = tokenizer.batch_decode(out,skip_special_tokens=True)
|
17 |
+
return {"translation": final_output[0]}
|