kgourgou committed
Commit 58318c4 · verified · 1 Parent(s): 2034a72

Update app.py

Files changed (1)
  1. app.py +83 -42
app.py CHANGED
@@ -7,42 +7,46 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 model_name = "gpt2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
+model.eval()
+
+torch.set_num_threads(2)
+
 
 def min_p_sampling(logits, pbase=0.1):
     """
     Perform min-p sampling on the logits.
-
+
     Args:
         logits (torch.Tensor): 1D tensor of logits for the next token.
         pbase (float): Base probability to scale pmax.
-
+
     Returns:
         int: The sampled token index.
     """
     # Convert logits to probabilities.
     probs = torch.softmax(logits, dim=-1)
-
+
     # 1. Find maximum probability.
     pmax = probs.max()
-
+
     # 2. Compute the dynamic threshold.
     pscaled = pbase * pmax
-
+
     # 3. Create a mask of tokens with probability >= pscaled.
     mask = probs >= pscaled
     # In the unlikely event that no token meets the threshold, use the full distribution.
     if mask.sum() == 0:
         mask = torch.ones_like(probs, dtype=torch.bool)
-
-    # Zero out probabilities not meeting the threshold.
+
     probs_filtered = probs * mask.float()
-
+
     # 4. Normalize and sample.
     probs_normalized = probs_filtered / probs_filtered.sum()
     sampled_index = torch.multinomial(probs_normalized, num_samples=1)
-
+
     return sampled_index.item()
 
+
 def generate_completion(prompt, strategy, params):
     """
     Generate a complete answer using model.generate with specified parameters.
@@ -51,33 +55,35 @@ def generate_completion(prompt, strategy, params):
     encoded = tokenizer(prompt, return_tensors="pt")
     input_ids = encoded["input_ids"]
     attention_mask = encoded["attention_mask"]
-
+
     # Generate the output.
-    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=50, **params)
+    output_ids = model.generate(
+        input_ids, attention_mask=attention_mask, max_length=50, **params
+    )
     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
+
 def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
-    """
-    Generate a complete answer using a token-by-token loop with min-p sampling.
-    """
-    # Encode the prompt.
     input_ids = tokenizer.encode(prompt, return_tensors="pt")
-
-    # Generate up to max_length tokens.
-    for _ in range(max_length - input_ids.size(1)):
-        outputs = model(input_ids)
-        logits = outputs.logits[:, -1, :]  # Get logits for the last token.
-        next_token = min_p_sampling(logits, pbase=pbase)
-
-        # Append the new token.
-        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1)
-
-        # Stop if the end-of-sequence token is generated.
-        if next_token == tokenizer.eos_token_id:
-            break
-
+    past = None
+    with torch.no_grad():
+        for _ in range(max_length - input_ids.size(1)):
+            # Only pass the last token if past is available
+            outputs = (
+                model(input_ids[:, -1:], past_key_values=past)
+                if past is not None
+                else model(input_ids)
+            )
+            past = outputs.past_key_values
+            logits = outputs.logits[:, -1, :]
+
+            next_token = min_p_sampling(logits, pbase=pbase)
+            input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1)
+            if next_token == tokenizer.eos_token_id:
+                break
     return tokenizer.decode(input_ids[0], skip_special_tokens=True)
 
+
 def generate_all(prompt):
     """
     Run multiple decoding strategies concurrently and yield updates as each completes.
@@ -86,29 +92,58 @@ def generate_all(prompt):
     # For the default strategies, we use model.generate; for "Min‑p Sampling" we use our custom function.
     methods = {
         "Greedy": {"type": "default", "params": {"do_sample": False}},
-        "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
-        "Top-p Sampling": {"type": "default", "params": {"do_sample": True, "top_p": 0.95}},
-        "Beam Search": {"type": "default", "params": {"num_beams": 5, "early_stopping": True}},
+        "Top-k Sampling": {
+            "type": "default",
+            "params": {"do_sample": True, "top_k": 50},
+        },
+        "Top-p Sampling": {
+            "type": "default",
+            "params": {"do_sample": True, "top_p": 0.95},
+        },
+        "Beam Search": {
+            "type": "default",
+            "params": {"num_beams": 5, "early_stopping": True},
+        },
+        "Eta Sampling": {
+            "type": "default",
+            "params": {"do_sample": True, "eta_cutoff": 0.3},
+        },
+        "Epsilon Sampling": {
+            "type": "default",
+            "params": {"do_sample": True, "epsilon_cutoff": 0.2},
+        },
         "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
     }
-
+
     # Define the order for display.
-    method_order = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search", "Min-p Sampling"]
+    method_order = [
+        "Greedy",
+        "Top-k Sampling",
+        "Top-p Sampling",
+        "Beam Search",
+        "Min-p Sampling",
+        "Eta Sampling",
+        "Epsilon Sampling",
+    ]
     results = {method: None for method in methods}
-
+
     # Yield an initial placeholder state.
     yield tuple("Processing..." for _ in method_order)
-
+
     # Use a thread pool to run each generation concurrently.
     with concurrent.futures.ThreadPoolExecutor() as executor:
         future_to_method = {}
         for method, info in methods.items():
             if info["type"] == "default":
-                future = executor.submit(generate_completion, prompt, method, info["params"])
+                future = executor.submit(
+                    generate_completion, prompt, method, info["params"]
+                )
             elif info["type"] == "min_p":
-                future = executor.submit(generate_min_p_completion, prompt, info["pbase"])
+                future = executor.submit(
+                    generate_min_p_completion, prompt, info["pbase"]
+                )
             future_to_method[future] = method
-
+
         # As each future completes, update its result and yield the current state.
         for future in concurrent.futures.as_completed(future_to_method):
             method = future_to_method[future]
@@ -118,7 +153,11 @@ def generate_all(prompt):
                 result = f"Error: {exc}"
             results[method] = result
             # Yield the results in the pre-defined order; pending methods show "Processing..."
-            yield tuple(results[m] if results[m] is not None else "Processing..." for m in method_order)
+            yield tuple(
+                results[m] if results[m] is not None else "Processing..."
+                for m in method_order
+            )
+
 
 # Create the Gradio interface.
 interface = gr.Interface(
@@ -130,10 +169,12 @@ interface = gr.Interface(
         gr.Textbox(label="Top-p Sampling"),
         gr.Textbox(label="Beam Search"),
         gr.Textbox(label="Min-p Sampling"),
+        gr.Textbox(label="Eta Sampling"),
+        gr.Textbox(label="Epsilon Sampling"),
     ],
     title="Decoding Methods Comparison",
-    description="Each decoding method's final answer is printed as soon as it is done, including custom min-p sampling."
+    description="Each decoding method's final answer is printed as soon as it is done. Model used: GPT-2.",
 )
 
 if __name__ == "__main__":
-    interface.launch()
+    interface.launch()
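
Reviewer note, not part of the commit: the sketch below isolates the min-p rule used in min_p_sampling above (threshold = pbase * pmax, mask, renormalize, sample) and exercises it on a toy distribution, so the filtering behaviour can be checked without downloading GPT-2. The logit values, seed, and sample count are illustrative.

# Self-contained sanity check of the min-p rule; values below are illustrative.
import torch

torch.manual_seed(0)


def min_p_sampling(logits, pbase=0.1):
    probs = torch.softmax(logits, dim=-1)
    pmax = probs.max()                       # highest token probability
    pscaled = pbase * pmax                   # dynamic threshold
    mask = probs >= pscaled                  # keep tokens at or above the threshold
    if mask.sum() == 0:                      # fallback: keep the full distribution
        mask = torch.ones_like(probs, dtype=torch.bool)
    probs_filtered = probs * mask.float()
    probs_normalized = probs_filtered / probs_filtered.sum()
    return torch.multinomial(probs_normalized, num_samples=1).item()


# Peaked toy distribution: with pbase=0.1 only indices 0 and 1 clear the
# threshold, so indices 2-4 should never be sampled.
logits = torch.tensor([4.0, 3.5, 1.0, -2.0, -5.0])
counts = {i: 0 for i in range(logits.numel())}
for _ in range(1000):
    counts[min_p_sampling(logits, pbase=0.1)] += 1
print(counts)

In app.py the same function receives a (1, vocab_size) logits slice from outputs.logits[:, -1, :]; torch.multinomial handles that 2-D case as a single row, so .item() still returns one token id.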