kgourgou committed
Commit 2e20bf5 · verified · 1 Parent(s): 974587f

Update app.py

Files changed (1)
  app.py  +96 -24
app.py CHANGED
@@ -3,42 +3,113 @@ import torch
 import concurrent.futures
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Load your model and tokenizer (using GPT-2 as an example)
+# Load the model and tokenizer (using GPT-2 as an example)
 model_name = "gpt2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
 
+def min_p_sampling(logits, pbase=0.1):
+    """
+    Perform min-p sampling on the logits.
+
+    Args:
+        logits (torch.Tensor): 1D tensor of logits for the next token.
+        pbase (float): Base probability to scale pmax.
+
+    Returns:
+        int: The sampled token index.
+    """
+    # Convert logits to probabilities.
+    probs = torch.softmax(logits, dim=-1)
+
+    # 1. Find maximum probability.
+    pmax = probs.max()
+
+    # 2. Compute the dynamic threshold.
+    pscaled = pbase * pmax
+
+    # 3. Create a mask of tokens with probability >= pscaled.
+    mask = probs >= pscaled
+    # In the unlikely event that no token meets the threshold, use the full distribution.
+    if mask.sum() == 0:
+        mask = torch.ones_like(probs, dtype=torch.bool)
+
+    # Zero out probabilities not meeting the threshold.
+    probs_filtered = probs * mask.float()
+
+    # 4. Normalize and sample.
+    probs_normalized = probs_filtered / probs_filtered.sum()
+    sampled_index = torch.multinomial(probs_normalized, num_samples=1)
+
+    return sampled_index.item()
+
 def generate_completion(prompt, strategy, params):
-    """Generate a complete answer using the specified decoding strategy."""
-    input_ids = tokenizer.encode(prompt, return_tensors="pt")
-    # Adjust generation parameters as needed.
-    output_ids = model.generate(input_ids, max_length=50, **params)
+    """
+    Generate a complete answer using model.generate with specified parameters.
+    """
+    # Encode the prompt and get the attention mask.
+    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
+    input_ids = encoded["input_ids"]
+    attention_mask = encoded["attention_mask"]
+
+    # Generate the output.
+    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=50, **params)
     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
+def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
+    """
+    Generate a complete answer using a token-by-token loop with min-p sampling.
+    """
+    # Encode the prompt.
+    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+    # Generate up to max_length tokens.
+    for _ in range(max_length - input_ids.size(1)):
+        outputs = model(input_ids)
+        logits = outputs.logits[:, -1, :]  # Get logits for the last token.
+        next_token = min_p_sampling(logits, pbase=pbase)
+
+        # Append the new token.
+        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1)
+
+        # Stop if the end-of-sequence token is generated.
+        if next_token == tokenizer.eos_token_id:
+            break
+
+    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
 def generate_all(prompt):
-    # Define decoding strategies and their corresponding parameters.
+    """
+    Run multiple decoding strategies concurrently and yield updates as each completes.
+    """
+    # Define each decoding strategy and its parameters.
+    # For the default strategies, we use model.generate; for "Min-p Sampling" we use our custom function.
     methods = {
-        "Greedy": {"params": {"do_sample": False}},
-        "Top-k Sampling": {"params": {"do_sample": True, "top_k": 50}},
-        "Top-p Sampling": {"params": {"do_sample": True, "top_p": 0.95}},
-        "Beam Search": {"params": {"num_beams": 5, "early_stopping": True}},
+        "Greedy": {"type": "default", "params": {"do_sample": False}},
+        "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
+        "Top-p Sampling": {"type": "default", "params": {"do_sample": True, "top_p": 0.95}},
+        "Beam Search": {"type": "default", "params": {"num_beams": 5, "early_stopping": True}},
+        "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
     }
-    # This list defines the order in which results are displayed.
-    method_order = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search"]
 
-    # Dictionary to store the final answer for each method (initially None)
+    # Define the order for display.
+    method_order = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search", "Min-p Sampling"]
     results = {method: None for method in methods}
 
-    # Yield an initial state so the UI shows placeholders.
+    # Yield an initial placeholder state.
     yield tuple("Processing..." for _ in method_order)
 
-    # Use ThreadPoolExecutor to run each generation concurrently.
+    # Use a thread pool to run each generation concurrently.
     with concurrent.futures.ThreadPoolExecutor() as executor:
-        future_to_method = {
-            executor.submit(generate_completion, prompt, method, methods[method]["params"]): method
-            for method in methods
-        }
-        # As soon as a method finishes, update its result and yield the current state.
+        future_to_method = {}
+        for method, info in methods.items():
+            if info["type"] == "default":
+                future = executor.submit(generate_completion, prompt, method, info["params"])
+            elif info["type"] == "min_p":
+                future = executor.submit(generate_min_p_completion, prompt, info["pbase"])
+            future_to_method[future] = method
+
+        # As each future completes, update its result and yield the current state.
         for future in concurrent.futures.as_completed(future_to_method):
             method = future_to_method[future]
             try:
@@ -46,10 +117,10 @@ def generate_all(prompt):
             except Exception as exc:
                 result = f"Error: {exc}"
             results[method] = result
-            # Yield the results in the specified order; methods still processing show "Processing..."
+            # Yield the results in the pre-defined order; pending methods show "Processing..."
            yield tuple(results[m] if results[m] is not None else "Processing..." for m in method_order)
 
-# Create a Gradio interface that uses the generator function.
+# Create the Gradio interface.
 interface = gr.Interface(
     fn=generate_all,
     inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
@@ -58,9 +129,10 @@ interface = gr.Interface(
         gr.Textbox(label="Top-k Sampling"),
         gr.Textbox(label="Top-p Sampling"),
         gr.Textbox(label="Beam Search"),
+        gr.Textbox(label="Min-p Sampling"),
     ],
-    title="Decoding Methods Results",
-    description="Each decoding method's complete answer is printed as soon as it's done."
+    title="Decoding Methods Comparison",
+    description="Each decoding method's final answer is printed as soon as it is done, including custom min-p sampling."
 )
 
 if __name__ == "__main__":
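For orientation while reading the diff: the new min_p_sampling helper keeps only tokens whose probability is at least pbase * pmax and renormalizes before sampling. Below is a standalone sketch of that filtering step, not part of the commit; the logits and pbase values are made up purely for illustration.

import torch

# Hypothetical next-token logits over a 6-token vocabulary (illustrative values only).
logits = torch.tensor([2.0, 1.5, 0.2, -1.0, -3.0, -5.0])
pbase = 0.1

probs = torch.softmax(logits, dim=-1)
threshold = pbase * probs.max()       # dynamic cutoff: pbase * pmax
mask = probs >= threshold             # keep tokens at or above the cutoff
filtered = torch.where(mask, probs, torch.zeros_like(probs))
filtered = filtered / filtered.sum()  # renormalize the surviving tokens

print("probabilities:", probs)
print("kept tokens:  ", mask.tolist())
print("sampled index:", torch.multinomial(filtered, num_samples=1).item())

With these made-up logits, pmax is roughly 0.55, so the cutoff sits near 0.055 and only the three most likely tokens survive before renormalization and sampling.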
 
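A rough usage sketch, again not part of the commit, assuming app.py can be imported as a module (the import will load GPT-2 and build the interface, while the __main__ guard keeps it from launching):

from app import generate_min_p_completion, generate_all

# One completion through the custom token-by-token min-p loop.
print(generate_min_p_completion("The meaning of life is", pbase=0.1, max_length=30))

# generate_all is a generator: each yielded tuple has one entry per method, in
# method_order, and methods that have not finished yet show "Processing...".
for snapshot in generate_all("The meaning of life is"):
    print(snapshot)

If the installed transformers version supports it, min-p is also available natively as a generation argument (model.generate(..., do_sample=True, min_p=0.1)), which could replace the manual token loop.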