Update app.py
app.py
CHANGED
@@ -1,12 +1,8 @@
 import gradio as gr
 import torch
 import concurrent.futures
-import threading
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Create a lock to serialize access to the model
-model_lock = threading.Lock()
-
 # Load the model and tokenizer (using GPT-2 as an example)
 model_name = "gpt2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -52,19 +48,12 @@ def generate_completion(prompt, strategy, params):
     Generate a complete answer using model.generate with specified parameters.
     """
     # Encode the prompt and get the attention mask.
-
-    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
+    encoded = tokenizer(prompt, return_tensors="pt")
     input_ids = encoded["input_ids"]
     attention_mask = encoded["attention_mask"]
 
-    #
-
-    output_ids = model.generate(
-        input_ids,
-        attention_mask=attention_mask,
-        max_length=50,
-        **params
-    )
+    # Generate the output.
+    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=50, **params)
     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
@@ -76,9 +65,7 @@ def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
 
     # Generate up to max_length tokens.
     for _ in range(max_length - input_ids.size(1)):
-
-        with model_lock:
-            outputs = model(input_ids)
+        outputs = model(input_ids)
         logits = outputs.logits[:, -1, :]  # Get logits for the last token.
         next_token = min_p_sampling(logits, pbase=pbase)
 
@@ -96,6 +83,7 @@ def generate_all(prompt):
     Run multiple decoding strategies concurrently and yield updates as each completes.
     """
     # Define each decoding strategy and its parameters.
+    # For the default strategies, we use model.generate; for "Min-p Sampling" we use our custom function.
     methods = {
         "Greedy": {"type": "default", "params": {"do_sample": False}},
         "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
@@ -144,7 +132,7 @@ interface = gr.Interface(
         gr.Textbox(label="Min-p Sampling"),
     ],
     title="Decoding Methods Comparison",
-    description="
+    description="Each decoding method's final answer is printed as soon as it is done, including custom min-p sampling."
 )
 
 if __name__ == "__main__":
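The sampling loop above calls a custom min_p_sampling helper that is defined elsewhere in app.py and is not part of this diff. For reference, here is a minimal sketch of min-p sampling, assuming pbase sets the cutoff as a fraction of the most likely token's probability; the file's actual implementation may differ:

import torch

def min_p_sampling(logits, pbase=0.1):
    # Keep only tokens whose probability is at least pbase times the
    # probability of the most likely token, then sample among the survivors.
    probs = torch.softmax(logits, dim=-1)
    max_prob = probs.max(dim=-1, keepdim=True).values
    probs = probs.masked_fill(probs < pbase * max_prob, 0.0)
    probs = probs / probs.sum(dim=-1, keepdim=True)  # renormalize survivors
    return torch.multinomial(probs, num_samples=1)   # shape: (batch, 1)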
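The generate_all docstring describes running every strategy concurrently and yielding updates as each completes. A minimal sketch of that pattern with concurrent.futures, using illustrative names rather than the file's actual ones:

import concurrent.futures

def run_all(prompt, strategies):
    # strategies: dict mapping a label to a callable prompt -> str.
    results = {name: "" for name in strategies}
    with concurrent.futures.ThreadPoolExecutor() as pool:
        futures = {pool.submit(fn, prompt): name for name, fn in strategies.items()}
        for future in concurrent.futures.as_completed(futures):
            results[futures[future]] = future.result()
            yield tuple(results.values())  # one update per finished strategy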