Update app.py
app.py CHANGED
@@ -3,42 +3,113 @@ import torch
 import concurrent.futures
 from transformers import AutoTokenizer, AutoModelForCausalLM

-# Load
+# Load the model and tokenizer (using GPT-2 as an example)
 model_name = "gpt2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)

+def min_p_sampling(logits, pbase=0.1):
+    """
+    Perform min-p sampling on the logits.
+
+    Args:
+        logits (torch.Tensor): 1D tensor of logits for the next token.
+        pbase (float): Base probability to scale pmax.
+
+    Returns:
+        int: The sampled token index.
+    """
+    # Convert logits to probabilities.
+    probs = torch.softmax(logits, dim=-1)
+
+    # 1. Find maximum probability.
+    pmax = probs.max()
+
+    # 2. Compute the dynamic threshold.
+    pscaled = pbase * pmax
+
+    # 3. Create a mask of tokens with probability >= pscaled.
+    mask = probs >= pscaled
+    # In the unlikely event that no token meets the threshold, use the full distribution.
+    if mask.sum() == 0:
+        mask = torch.ones_like(probs, dtype=torch.bool)
+
+    # Zero out probabilities not meeting the threshold.
+    probs_filtered = probs * mask.float()
+
+    # 4. Normalize and sample.
+    probs_normalized = probs_filtered / probs_filtered.sum()
+    sampled_index = torch.multinomial(probs_normalized, num_samples=1)
+
+    return sampled_index.item()
+
 def generate_completion(prompt, strategy, params):
-    """
-
-
-
+    """
+    Generate a complete answer using model.generate with specified parameters.
+    """
+    # Encode the prompt and get the attention mask.
+    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
+    input_ids = encoded["input_ids"]
+    attention_mask = encoded["attention_mask"]
+
+    # Generate the output.
+    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=50, **params)
     return tokenizer.decode(output_ids[0], skip_special_tokens=True)

+def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
+    """
+    Generate a complete answer using a token-by-token loop with min-p sampling.
+    """
+    # Encode the prompt.
+    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+    # Generate up to max_length tokens.
+    for _ in range(max_length - input_ids.size(1)):
+        outputs = model(input_ids)
+        logits = outputs.logits[:, -1, :]  # Get logits for the last token.
+        next_token = min_p_sampling(logits, pbase=pbase)
+
+        # Append the new token.
+        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1)
+
+        # Stop if the end-of-sequence token is generated.
+        if next_token == tokenizer.eos_token_id:
+            break
+
+    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
 def generate_all(prompt):
-
+    """
+    Run multiple decoding strategies concurrently and yield updates as each completes.
+    """
+    # Define each decoding strategy and its parameters.
+    # For the default strategies, we use model.generate; for "Min-p Sampling" we use our custom function.
     methods = {
-        "Greedy": {"params": {"do_sample": False}},
-        "Top-k Sampling": {"params": {"do_sample": True, "top_k": 50}},
-        "Top-p Sampling": {"params": {"do_sample": True, "top_p": 0.95}},
-        "Beam Search": {"params": {"num_beams": 5, "early_stopping": True}},
+        "Greedy": {"type": "default", "params": {"do_sample": False}},
+        "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
+        "Top-p Sampling": {"type": "default", "params": {"do_sample": True, "top_p": 0.95}},
+        "Beam Search": {"type": "default", "params": {"num_beams": 5, "early_stopping": True}},
+        "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
     }
-    # This list defines the order in which results are displayed.
-    method_order = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search"]

-    #
+    # Define the order for display.
+    method_order = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search", "Min-p Sampling"]
     results = {method: None for method in methods}

-    # Yield an initial state
+    # Yield an initial placeholder state.
     yield tuple("Processing..." for _ in method_order)

-    # Use
+    # Use a thread pool to run each generation concurrently.
     with concurrent.futures.ThreadPoolExecutor() as executor:
-        future_to_method = {
-
-
-
-
+        future_to_method = {}
+        for method, info in methods.items():
+            if info["type"] == "default":
+                future = executor.submit(generate_completion, prompt, method, info["params"])
+            elif info["type"] == "min_p":
+                future = executor.submit(generate_min_p_completion, prompt, info["pbase"])
+            future_to_method[future] = method
+
+        # As each future completes, update its result and yield the current state.
         for future in concurrent.futures.as_completed(future_to_method):
             method = future_to_method[future]
             try:
@@ -46,10 +117,10 @@ def generate_all(prompt):
             except Exception as exc:
                 result = f"Error: {exc}"
             results[method] = result
-            # Yield the results in the
+            # Yield the results in the pre-defined order; pending methods show "Processing..."
             yield tuple(results[m] if results[m] is not None else "Processing..." for m in method_order)

-# Create
+# Create the Gradio interface.
 interface = gr.Interface(
     fn=generate_all,
     inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
@@ -58,9 +129,10 @@ interface = gr.Interface(
         gr.Textbox(label="Top-k Sampling"),
         gr.Textbox(label="Top-p Sampling"),
         gr.Textbox(label="Beam Search"),
+        gr.Textbox(label="Min-p Sampling"),
     ],
-    title="Decoding Methods
-    description="Each decoding method's
+    title="Decoding Methods Comparison",
+    description="Each decoding method's final answer is printed as soon as it is done, including custom min-p sampling."
 )

 if __name__ == "__main__":
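
For reference, the min-p rule added in this commit keeps only tokens whose probability is at least pbase * pmax, renormalizes the survivors, and samples from them. Below is a minimal standalone sketch of that filtering step; the toy logits and variable names are illustrative only and are not taken from the app.

import torch

# Toy logits over a 5-token vocabulary (illustrative values, not from the app).
logits = torch.tensor([2.0, 1.5, 0.5, -1.0, -3.0])
pbase = 0.1

probs = torch.softmax(logits, dim=-1)
threshold = pbase * probs.max()          # dynamic threshold: pscaled = pbase * pmax
mask = probs >= threshold                # keep tokens at or above the threshold
filtered = probs * mask.float()          # zero out the rest
normalized = filtered / filtered.sum()   # renormalize the surviving tokens
sampled = torch.multinomial(normalized, num_samples=1).item()
print(mask.tolist(), sampled)

Raising pbase prunes more of the low-probability tail and makes sampling more conservative; lowering it approaches plain sampling over the full distribution.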