mrpintime committed
Commit fc2d4ef · verified · 1 Parent(s): 318bd50

Delete handler.py

Files changed (1): handler.py +0 -48
handler.py DELETED
@@ -1,48 +0,0 @@
- from typing import Dict, List, Any
-
- import torch
- from transformers import GPT2LMHeadModel, AutoTokenizer
-
-
- class EndpointHandler:
-     def __init__(self, path="mrpintime/GPTPoem"):
-         # Pick the best available device: CUDA GPU, Apple Silicon (MPS), or CPU.
-         if torch.cuda.is_available():
-             self.device = "cuda"
-         elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-             self.device = "mps"
-         else:
-             self.device = "cpu"
-         # torch.autocast only understands "cuda" and "cpu" device types.
-         self.device_type = "cuda" if self.device.startswith("cuda") else "cpu"
-
-         # Tokenizers are device-agnostic; only the model is moved to the device.
-         self.tokenizer = AutoTokenizer.from_pretrained('bolbolzaban/gpt2-persian')
-         self.model = GPT2LMHeadModel.from_pretrained(path).to(self.device)
-
-     def __call__(self, data: Dict[str, Any]) -> List[str]:
-         """
-         data args:
-             inputs (:obj:`str`): the prompt text to continue
-             parameters (:obj:`dict`): num_samples, max_new_tokens,
-                 temperature and (optionally) top_k
-         Return:
-             A :obj:`list` of decoded samples; will be serialized and returned
-         """
-         params = data['parameters']
-         start_ids = self.tokenizer.encode(data['inputs'], add_special_tokens=False)
-         # run generation
-         samples = []
-         with torch.no_grad(), torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
-             for k in range(int(params['num_samples'])):
-                 # restart from the prompt for every independent sample
-                 idx = torch.tensor(start_ids, device=self.device)[None, ...]
-                 for _ in range(int(params['max_new_tokens'])):
-                     # forward the model to get the logits for the index in the sequence
-                     logits = self.model(idx).logits
-                     # pluck the logits at the final step and scale by desired temperature
-                     logits = logits[:, -1, :] / float(params['temperature'])
-                     # optionally crop the logits to only the top k options
-                     if params.get('top_k') is not None:
-                         v, _ = torch.topk(logits, min(int(params['top_k']), logits.size(-1)))
-                         logits[logits < v[:, [-1]]] = -float('Inf')
-                     # apply softmax to convert logits to (normalized) probabilities
-                     probs = torch.nn.functional.softmax(logits, dim=-1)
-                     # sample from the distribution
-                     idx_next = torch.multinomial(probs, num_samples=1)
-                     # append sampled index to the running sequence and continue
-                     idx = torch.cat((idx, idx_next), dim=1)
-                 samples.append(self.tokenizer.decode(idx[0].tolist()))
-         return samples
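
For reference, a minimal sketch of how this handler could have been exercised locally before the deletion. The prompt string, the parameter values, and the `handler` module name are illustrative assumptions, not part of the commit; the payload shape follows the docstring above.

from handler import EndpointHandler  # assumes handler.py is on the import path

# Hypothetical request payload; field names match what __call__ reads.
payload = {
    "inputs": "سلام",  # placeholder Persian prompt to continue
    "parameters": {
        "num_samples": 2,       # independent completions to draw
        "max_new_tokens": 50,   # length of each completion
        "temperature": 0.8,     # sampling temperature
        "top_k": 40,            # optional top-k filtering
    },
}

handler = EndpointHandler(path="mrpintime/GPTPoem")
for sample in handler(payload):
    print(sample)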