Update app.py
app.py CHANGED
@@ -7,42 +7,46 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 model_name = "gpt2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
+model.eval()
+
+torch.set_num_threads(2)
+
 
 def min_p_sampling(logits, pbase=0.1):
     """
     Perform min-p sampling on the logits.
-
+
     Args:
         logits (torch.Tensor): 1D tensor of logits for the next token.
         pbase (float): Base probability to scale pmax.
-
+
     Returns:
         int: The sampled token index.
     """
     # Convert logits to probabilities.
     probs = torch.softmax(logits, dim=-1)
-
+
     # 1. Find maximum probability.
     pmax = probs.max()
-
+
     # 2. Compute the dynamic threshold.
     pscaled = pbase * pmax
-
+
     # 3. Create a mask of tokens with probability >= pscaled.
     mask = probs >= pscaled
     # In the unlikely event that no token meets the threshold, use the full distribution.
     if mask.sum() == 0:
         mask = torch.ones_like(probs, dtype=torch.bool)
-
-    # Zero out probabilities not meeting the threshold.
+
     probs_filtered = probs * mask.float()
-
+
     # 4. Normalize and sample.
     probs_normalized = probs_filtered / probs_filtered.sum()
     sampled_index = torch.multinomial(probs_normalized, num_samples=1)
-
+
     return sampled_index.item()
 
+
 def generate_completion(prompt, strategy, params):
     """
     Generate a complete answer using model.generate with specified parameters.
@@ -51,33 +55,35 @@ def generate_completion(prompt, strategy, params):
     encoded = tokenizer(prompt, return_tensors="pt")
     input_ids = encoded["input_ids"]
     attention_mask = encoded["attention_mask"]
-
+
     # Generate the output.
-    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=50, **params)
+    output_ids = model.generate(
+        input_ids, attention_mask=attention_mask, max_length=50, **params
+    )
     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
+
 def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
-    """
-    Generate a complete answer using a token-by-token loop with min-p sampling.
-    """
-    # Encode the prompt.
     input_ids = tokenizer.encode(prompt, return_tensors="pt")
-    # (the 14 lines of the previous generation loop are not recoverable from this view)
+    past = None
+    with torch.no_grad():
+        for _ in range(max_length - input_ids.size(1)):
+            # Only pass the last token if past is available
+            outputs = (
+                model(input_ids[:, -1:], past_key_values=past)
+                if past is not None
+                else model(input_ids)
+            )
+            past = outputs.past_key_values
+            logits = outputs.logits[:, -1, :]
+
+            next_token = min_p_sampling(logits, pbase=pbase)
+            input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1)
+            if next_token == tokenizer.eos_token_id:
+                break
     return tokenizer.decode(input_ids[0], skip_special_tokens=True)
 
+
 def generate_all(prompt):
     """
     Run multiple decoding strategies concurrently and yield updates as each completes.
@@ -86,29 +92,58 @@ def generate_all(prompt):
     # For the default strategies, we use model.generate; for "Min‑p Sampling" we use our custom function.
     methods = {
         "Greedy": {"type": "default", "params": {"do_sample": False}},
-        "Top-k Sampling": {...},
-        "Top-p Sampling": {...},
-        "Beam Search": {...},
+        "Top-k Sampling": {
+            "type": "default",
+            "params": {"do_sample": True, "top_k": 50},
+        },
+        "Top-p Sampling": {
+            "type": "default",
+            "params": {"do_sample": True, "top_p": 0.95},
+        },
+        "Beam Search": {
+            "type": "default",
+            "params": {"num_beams": 5, "early_stopping": True},
+        },
+        "Eta Sampling": {
+            "type": "default",
+            "params": {"do_sample": True, "eta_cutoff": 0.3},
+        },
+        "Epsilon Sampling": {
+            "type": "default",
+            "params": {"do_sample": True, "epsilon_cutoff": 0.2},
+        },
         "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
     }
-
+
     # Define the order for display.
-    method_order = [...]
+    method_order = [
+        "Greedy",
+        "Top-k Sampling",
+        "Top-p Sampling",
+        "Beam Search",
+        "Min-p Sampling",
+        "Eta Sampling",
+        "Epsilon Sampling",
+    ]
     results = {method: None for method in methods}
-
+
     # Yield an initial placeholder state.
     yield tuple("Processing..." for _ in method_order)
-
+
    # Use a thread pool to run each generation concurrently.
     with concurrent.futures.ThreadPoolExecutor() as executor:
         future_to_method = {}
         for method, info in methods.items():
             if info["type"] == "default":
-                future = executor.submit(generate_completion, prompt, method, info["params"])
+                future = executor.submit(
+                    generate_completion, prompt, method, info["params"]
+                )
             elif info["type"] == "min_p":
-                future = executor.submit(generate_min_p_completion, prompt, info["pbase"])
+                future = executor.submit(
+                    generate_min_p_completion, prompt, info["pbase"]
+                )
             future_to_method[future] = method
-
+
         # As each future completes, update its result and yield the current state.
         for future in concurrent.futures.as_completed(future_to_method):
             method = future_to_method[future]
@@ -118,7 +153,11 @@ def generate_all(prompt):
                 result = f"Error: {exc}"
             results[method] = result
             # Yield the results in the pre-defined order; pending methods show "Processing..."
-            yield tuple(results[m] if results[m] is not None else "Processing..." for m in method_order)
+            yield tuple(
+                results[m] if results[m] is not None else "Processing..."
+                for m in method_order
+            )
+
 
 # Create the Gradio interface.
 interface = gr.Interface(
@@ -130,10 +169,12 @@ interface = gr.Interface(
         gr.Textbox(label="Top-p Sampling"),
         gr.Textbox(label="Beam Search"),
         gr.Textbox(label="Min-p Sampling"),
+        gr.Textbox(label="Eta Sampling"),
+        gr.Textbox(label="Epsilon Sampling"),
     ],
     title="Decoding Methods Comparison",
-    description="Each decoding method's final answer is printed as soon as it is done.",
+    description="Each decoding method's final answer is printed as soon as it is done. Model used: GPT-2.",
 )
 
 if __name__ == "__main__":
-    interface.launch()
+    interface.launch()
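For anyone poking at the change outside the Space, the sketch below isolates the min-p rule that min_p_sampling implements: keep only tokens whose probability is at least pbase * pmax, renormalize, and draw from what survives. It is a minimal, self-contained illustration; the toy logits, the helper name min_p_filter, and the pbase=0.1 value are assumptions made for the example, not part of the Space's code.

    import torch

    def min_p_filter(logits: torch.Tensor, pbase: float = 0.1) -> torch.Tensor:
        """Renormalized probabilities after min-p filtering (illustrative only)."""
        probs = torch.softmax(logits, dim=-1)
        threshold = pbase * probs.max()   # dynamic cutoff: pscaled = pbase * pmax
        mask = probs >= threshold         # keep tokens at or above the cutoff
        filtered = probs * mask.float()   # zero out everything below it
        return filtered / filtered.sum()  # renormalize over the surviving tokens

    # Toy distribution over four "tokens" with one dominant candidate.
    logits = torch.tensor([3.0, 1.0, 0.5, -2.0])
    probs_kept = min_p_filter(logits, pbase=0.1)
    print(probs_kept)                     # tokens under the cutoff get probability 0
    print(torch.multinomial(probs_kept, num_samples=1).item())

A quick end-to-end check of the updated loop, assuming app.py is importable locally and the GPT-2 weights can be downloaded, would be something like `python -c "from app import generate_min_p_completion; print(generate_min_p_completion('The capital of France is', max_length=30))"`; launching the full comparison UI is just `python app.py` once gradio, torch, and transformers are installed.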