Delete handler.py
Browse files- handler.py +0 -48
handler.py
DELETED
@@ -1,48 +0,0 @@
|
|
1 |
-
from typing import Dict, List, Any
|
2 |
-
from transformers import GPT2LMHeadModel, AutoTokenizer
|
3 |
-
import torch
|
4 |
-
|
5 |
-
class EndpointHandler():
    """Inference handler that samples text continuations from a GPT-2 model.

    The handler loads the ``mrpintime/GPTPoem`` weights together with the
    ``bolbolzaban/gpt2-persian`` tokenizer once at construction time and then
    serves requests through ``__call__``.
    """

    # Fallbacks used when a request omits the matching "parameters" entry.
    DEFAULT_NUM_SAMPLES = 1
    DEFAULT_MAX_NEW_TOKENS = 50
    DEFAULT_TEMPERATURE = 1.0

    def __init__(self, path="mrpintime/GPTPoem"):
        """Pick the best available device and load tokenizer and model.

        Args:
            path: Accepted for interface compatibility with callers that pass
                a model location. NOTE(review): the original code never used
                it and always loaded 'mrpintime/GPTPoem'; that behavior is
                preserved here -- confirm whether ``path`` should be honored.
        """
        if torch.cuda.is_available():
            self.device = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        # torch.autocast is only entered with "cuda" or "cpu" here; every
        # non-CUDA device (including "mps") falls back to the "cpu" type.
        self.device_type = "cuda" if self.device.startswith("cuda") else "cpu"

        # Tokenizers hold no tensors, so no device argument is passed.
        self.tokenizer = AutoTokenizer.from_pretrained('bolbolzaban/gpt2-persian')
        self.model = GPT2LMHeadModel.from_pretrained('mrpintime/GPTPoem').to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate one or more sampled continuations of the prompt.

        Args:
            data: Request payload with
                ``inputs`` (str): the prompt text (required);
                ``parameters`` (dict, optional) with the keys
                ``num_samples`` (int, default 1),
                ``max_new_tokens`` (int, default 50),
                ``temperature`` (float > 0, default 1.0) and
                ``top_k`` (int or None; None or a non-positive value
                disables top-k filtering).

        Returns:
            A list with one ``{"generated_text": str}`` dict per sample. Each
            text is the decoded prompt followed by the sampled tokens.

        Raises:
            KeyError: if ``inputs`` is missing from ``data``.
            ValueError: if the prompt tokenizes to nothing or the
                temperature is not positive.
        """
        # Parse and validate the request once instead of on every step.
        params = data.get('parameters') or {}
        num_samples = int(params.get('num_samples', self.DEFAULT_NUM_SAMPLES))
        max_new_tokens = int(params.get('max_new_tokens', self.DEFAULT_MAX_NEW_TOKENS))
        temperature = float(params.get('temperature', self.DEFAULT_TEMPERATURE))
        if temperature <= 0:
            raise ValueError("temperature must be a positive number")
        raw_top_k = params.get('top_k')
        top_k = int(raw_top_k) if raw_top_k is not None else 0

        start_ids = self.tokenizer.encode(data['inputs'], add_special_tokens=False)
        if not start_ids:
            raise ValueError("inputs must contain at least one token")
        # Explicit long dtype and device: the ids must live where the model does.
        prompt = torch.tensor(start_ids, dtype=torch.long, device=self.device)[None, ...]

        # Longest context the model accepts, when its config exposes one.
        max_context = getattr(getattr(self.model, 'config', None), 'n_positions', None)

        samples = []
        with torch.no_grad(), torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
            for _sample in range(num_samples):
                # Every sample starts over from the prompt; torch.cat below
                # builds new tensors, so ``prompt`` itself is never modified.
                idx = prompt
                for _step in range(max_new_tokens):
                    # Crop the running sequence if it outgrew the context window.
                    if max_context is not None and idx.size(1) > max_context:
                        idx_cond = idx[:, -max_context:]
                    else:
                        idx_cond = idx
                    # The model returns an output object, not a tuple.
                    logits = self.model(idx_cond).logits
                    # Keep only the final position and apply the temperature;
                    # float32 keeps the sampling maths out of reduced precision.
                    logits = logits[:, -1, :].float() / temperature
                    if top_k > 0:
                        kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                        logits[logits < kth_best] = -float('Inf')
                    probs = torch.nn.functional.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                    idx = torch.cat((idx, idx_next), dim=1)
                # Decode to text so the result can be serialized.
                samples.append({"generated_text": self.tokenizer.decode(idx[0].tolist())})
        return samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|