kgourgou committed
Commit c816679 · verified · 1 Parent(s): eab7f9b

Update app.py

Files changed (1)
app.py +117 -5
app.py CHANGED
@@ -14,7 +14,8 @@ torch.set_num_threads(2)
 
 def min_p_sampling(logits, pbase=0.1):
     """
-    Perform min-p sampling on the logits.
+    Perform min-p sampling on the logits. As described in
+    https://arxiv.org/abs/2407.01082
 
     Args:
         logits (torch.Tensor): 1D tensor of logits for the next token.
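This hunk only points the min_p_sampling docstring at the min-p paper; the function body itself sits outside the diff context. For reference, a minimal sketch of the min-p rule that paper describes (illustrative only, with a hypothetical helper name min_p_rule, not necessarily the body app.py actually uses):

import torch

def min_p_rule(logits: torch.Tensor, pbase: float = 0.1) -> int:
    # Turn logits into probabilities.
    probs = torch.softmax(logits, dim=-1)
    # Scale the cutoff by the top token's probability (the min-p rule).
    threshold = pbase * probs.max()
    # Zero out tokens below the scaled cutoff, renormalize, and sample.
    filtered = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    filtered = filtered / filtered.sum()
    return torch.multinomial(filtered, num_samples=1).item()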
@@ -47,6 +48,96 @@ def min_p_sampling(logits, pbase=0.1):
     return sampled_index.item()
 
 
+def generate_laconic_completion(prompt: str, n: int = 5, max_length: int = 100):
+    # generate n completions by sampling and return the shortest one
+    with torch.no_grad():
+        # Encode the prompt and get the attention mask.
+        encoded = tokenizer(prompt, return_tensors="pt")
+        input_ids = encoded["input_ids"]
+        attention_mask = encoded["attention_mask"]
+
+        # Generate the output.
+        outputs = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            max_length=max_length,
+            num_return_sequences=n,
+            do_sample=True,
+        )
+        completions = [
+            tokenizer.decode(output, skip_special_tokens=True) for output in outputs
+        ]
+        return min(completions, key=len)
+
+
+def generate_with_confidence(input_ids, max_length):
+    """
+    Generate a sequence using greedy decoding while returning the scores.
+    """
+    outputs = model.generate(
+        input_ids,
+        max_length=max_length,
+        do_sample=False,
+        output_scores=True,
+        return_dict_in_generate=True,
+    )
+    return outputs
+
+
+def compute_answer_confidence(outputs):
+    """
+    Compute the answer confidence over the generated tokens.
+    For each generated token, compute the difference between the top-1 and top-2 logits.
+    Returns the average difference.
+    """
+    diffs = []
+    for score in outputs.scores:
+        # Get top-2 logit values
+        top2 = torch.topk(score[0], 2)
+        diff = top2.values[0] - top2.values[1]
+        diffs.append(diff.item())
+
+    return sum(diffs) / len(diffs) if diffs else 0.0
+
+
+def cot_decoding(prompt, k=5, max_length=100):
+    """
+    Perform Chain-of-Thought (CoT) decoding by exploring top-k alternative paths.
+    """
+    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+    # Get logits for the next token
+    with torch.no_grad():
+        outputs = model(input_ids)
+        logits = outputs.logits[0, -1, :]
+
+    # Get top-k candidate tokens
+    topk = torch.topk(logits, k)
+    candidate_tokens = topk.indices
+
+    paths = []
+    for token in candidate_tokens:
+        # Append the candidate token to the prompt
+        new_input_ids = torch.cat([input_ids, token.view(1, 1)], dim=1)
+
+        # Generate a full sequence with output scores
+        gen_outputs = generate_with_confidence(
+            new_input_ids, max_length=new_input_ids.shape[1] + max_length
+        )
+
+        # Decode the generated sequence
+        generated_text = tokenizer.decode(
+            gen_outputs.sequences[0], skip_special_tokens=True
+        )
+
+        # Compute answer confidence
+        confidence = compute_answer_confidence(gen_outputs)
+
+        paths.append({"text": generated_text, "confidence": confidence})
+
+    return max(paths, key=lambda x: x["confidence"])["text"]
+
+
 def generate_completion(prompt, strategy, params):
     """
     Generate a complete answer using model.generate with specified parameters.
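The functions added above are wired into generate_all further down in the diff. As a standalone illustration of how they might be called, assuming the module-level GPT-2 model and tokenizer that app.py loads (the prompt here is made up):

if __name__ == "__main__":
    prompt = "Q: What is 8 times 7? A:"
    # Laconic decoding: sample five completions and keep the shortest.
    print(generate_laconic_completion(prompt, n=5, max_length=80))
    # CoT decoding: branch on the top-5 first tokens and keep the
    # highest-confidence continuation.
    print(cot_decoding(prompt, k=5, max_length=80))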
@@ -59,12 +150,12 @@ def generate_completion(prompt, strategy, params):
 
     # Generate the output.
     output_ids = model.generate(
-        input_ids, attention_mask=attention_mask, max_length=50, **params
+        input_ids, attention_mask=attention_mask, max_length=100, **params
     )
     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
-def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
+def generate_min_p_completion(prompt, pbase=0.1, max_length=100):
     input_ids = tokenizer.encode(prompt, return_tensors="pt")
     past = None
     with torch.no_grad():
@@ -94,7 +185,7 @@ def generate_all(prompt):
         "Greedy": {"type": "default", "params": {"do_sample": False}},
         "Top-k Sampling": {
             "type": "default",
-            "params": {"do_sample": True, "top_k": 50},
+            "params": {"do_sample": True, "top_k": 100},
         },
         "Top-p Sampling": {
             "type": "default",
@@ -113,6 +204,14 @@ def generate_all(prompt):
             "params": {"do_sample": True, "epsilon_cutoff": 0.2},
         },
         "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
+        "laconic": {
+            "type": "default",
+            "params": {"do_sample": True, "num_return_sequences": 5},
+        },
+        "COT Decoding": {
+            "type": "cot_decoding",
+            "params": {"k": 5, "max_length": 100},
+        },
     }
 
     # Define the order for display.
@@ -124,6 +223,8 @@ def generate_all(prompt):
         "Min-p Sampling",
         "Eta Sampling",
         "Epsilon Sampling",
+        "laconic",
+        "COT Decoding",
     ]
     results = {method: None for method in methods}
 
@@ -142,6 +243,11 @@ def generate_all(prompt):
                 future = executor.submit(
                     generate_min_p_completion, prompt, info["pbase"]
                 )
+            elif method == "laconic":
+                future = executor.submit(generate_laconic_completion, prompt)
+            elif method == "COT Decoding":
+                future = executor.submit(cot_decoding, prompt, **info["params"])
+
             future_to_method[future] = method
 
         # As each future completes, update its result and yield the current state.
@@ -169,9 +275,15 @@ interface = gr.Interface(
         gr.Textbox(label="Top-k Sampling"),
         gr.Textbox(label="Top-p Sampling"),
         gr.Textbox(label="Beam Search"),
-        gr.Textbox(label="Min-p Sampling"),
+        gr.Textbox(label="Min-p Sampling (as in https://arxiv.org/abs/2407.01082)"),
         gr.Textbox(label="Eta Sampling"),
         gr.Textbox(label="Epsilon Sampling"),
+        gr.Textbox(
+            label="laconic decoding (by Alex Dimakis, 2025, search for twitter thread)"
+        ),
+        gr.Textbox(
+            label="COT Decoding (Chain-of-Thought Reasoning without Prompting, Wang, Zhou, 2024)"
+        ),
     ],
     title="Decoding Methods Comparison",
     description="Each decoding method's final answer is printed as soon as it is done. Model used: GPT-2.",