taylorj94 committed
Commit 2f4df5a · 1 Parent(s): 0a82a06
Files changed (1): handler.py +79 -16
handler.py CHANGED
@@ -1,23 +1,86 @@
-from typing import Dict, List, Any
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+import torch
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    pipeline,
+    LogitsProcessor,
+    LogitsProcessorList
+)
+from typing import Any, List, Dict
+
+
+class FixedVocabLogitsProcessor(LogitsProcessor):
+    """
+    A custom LogitsProcessor that restricts the vocabulary
+    to a fixed set of token IDs, masking out everything else.
+    """
+
+    def __init__(self, allowed_ids: set[int], fill_value=float('-inf')):
+        """
+        Args:
+            allowed_ids (set[int]): Token IDs allowed for generation.
+            fill_value (float): Value used to mask disallowed tokens, default -inf.
+        """
+        self.allowed_ids = allowed_ids
+        self.fill_value = fill_value
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        Args:
+            input_ids: shape (batch_size, sequence_length)
+            scores: shape (batch_size, vocab_size) - pre-softmax logits for the next token
+        Returns:
+            scores: shape (batch_size, vocab_size) with masked logits
+        """
+        batch_size, vocab_size = scores.size()
+        for b in range(batch_size):
+            for token_id in range(vocab_size):
+                if token_id not in self.allowed_ids:
+                    scores[b, token_id] = self.fill_value
+        return scores
 
 
 class EndpointHandler:
     def __init__(self, path=""):
-        # load the model
+        # Load tokenizer and model
         tokenizer = AutoTokenizer.from_pretrained(path)
-        model = AutoModelForCausalLM.from_pretrained(path, device_map="auto")
-        # create inference pipeline
-        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=torch.float16)
+
+        # Define allowed tokens
+        words = ["Paris", "France", "Hello"]  # Customize as needed
+        allowed_ids = set()
+        for word in words:
+            for tid in tokenizer.encode(word, add_special_tokens=False):
+                allowed_ids.add(tid)
+            for tid in tokenizer.encode(" " + word, add_special_tokens=False):
+                allowed_ids.add(tid)
+
+        # Create custom logits processor
+        self.logits_processors = LogitsProcessorList([FixedVocabLogitsProcessor(allowed_ids=allowed_ids)])
 
-    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
+        self.tokenizer = tokenizer
+        self.model = model
+
+    def __call__(self, data: Any) -> List[Dict[str, str]]:
+        # Extract inputs and parameters
         inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", None)
-
-        # pass inputs with all kwargs in data
-        if parameters is not None:
-            prediction = self.pipeline(inputs, **parameters)
-        else:
-            prediction = self.pipeline(inputs)
-        # postprocess the prediction
-        return prediction
+        parameters = data.pop("parameters", {})
+
+        # Prepare input IDs
+        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to(self.model.device)
+
+        # Generate output
+        output_ids = self.model.generate(
+            input_ids=input_ids,
+            logits_processor=self.logits_processors,
+            max_length=parameters.get("max_length", 30),
+            num_beams=parameters.get("num_beams", 1),
+            do_sample=parameters.get("do_sample", False),
+            pad_token_id=self.tokenizer.eos_token_id,
+            no_repeat_ngram_size=parameters.get("no_repeat_ngram_size", 3)
+        )
+
+        # Decode the output
+        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+        return [{"generated_text": generated_text}]
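Side note (not part of the commit): the nested Python loop in `FixedVocabLogitsProcessor.__call__` visits every vocabulary entry for every batch item on every decoding step, which is slow for vocabularies of tens of thousands of tokens. A minimal vectorized sketch of the same masking idea, assuming the same `allowed_ids` set and standard PyTorch indexing (the class name here is hypothetical):

```python
import torch
from transformers import LogitsProcessor


class VectorizedFixedVocabLogitsProcessor(LogitsProcessor):
    """Hypothetical vectorized variant: same masking, no per-token Python loop."""

    def __init__(self, allowed_ids: set[int], fill_value: float = float("-inf")):
        # Precompute the allowed token indices once as a LongTensor.
        self.allowed = torch.tensor(sorted(allowed_ids), dtype=torch.long)
        self.fill_value = fill_value

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Build a fully masked copy, then restore only the allowed columns.
        masked = torch.full_like(scores, self.fill_value)
        allowed = self.allowed.to(scores.device)
        masked[:, allowed] = scores[:, allowed]
        return masked
```

The behavior matches the committed processor; only the per-step cost drops from O(batch × vocab) Python iterations to a single tensor index operation.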
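For reference, a sketch of how the handler could be exercised locally. The payload shape follows the `inputs`/`parameters` convention the committed `__call__` already expects; the `"/repository"` path is a placeholder (it is the model directory an Inference Endpoints container typically mounts, but any local model path works):

```python
# Hypothetical smoke test for the handler above.
handler = EndpointHandler(path="/repository")

payload = {
    "inputs": "Hello",
    "parameters": {"max_length": 20, "num_beams": 1},
}

# Returns a list with one dict, e.g. [{"generated_text": "..."}],
# where generation is restricted to the allowed token IDs.
print(handler(payload))
```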