Update app.py
app.py CHANGED
@@ -1,13 +1,13 @@
-"""
-Fun little experiment.
-"""
-
-
 import gradio as gr
 import torch
 import concurrent.futures
+import threading
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
+# Create a lock to serialize access to the model
+model_lock = threading.Lock()
+
+# Load the model and tokenizer (using GPT-2 as an example)
 model_name = "gpt2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
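Note: the new model_lock is there because generate_all fans the decoding strategies out to worker threads while they all share one model instance, and concurrent calls into a single transformers model are not guaranteed to be thread-safe. A minimal sketch of the pattern in isolation (the locked_call helper is hypothetical, not part of app.py):

import threading

model_lock = threading.Lock()

def locked_call(fn, *args, **kwargs):
    # Serialize model access: only one thread runs the model at a time.
    with model_lock:
        return fn(*args, **kwargs)

The trade-off is that generation itself is serialized; the threads only overlap the surrounding bookkeeping, so total latency stays roughly the sum of the individual strategies.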
@@ -52,13 +52,18 @@ def generate_completion(prompt, strategy, params):
     Generate a complete answer using model.generate with specified parameters.
     """
     # Encode the prompt and get the attention mask.
-    tokenizer.pad_token = tokenizer.eos_token
     encoded = tokenizer(prompt, return_tensors="pt", padding=True)
     input_ids = encoded["input_ids"]
     attention_mask = encoded["attention_mask"]
 
-    #
-    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=50, **params)
+    # Use the lock when calling the model
+    with model_lock:
+        output_ids = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            max_length=50,
+            **params
+        )
     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
@@ -70,7 +75,9 @@ def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
 
     # Generate up to max_length tokens.
     for _ in range(max_length - input_ids.size(1)):
-        outputs = model(input_ids)
+        # Lock the model call to ensure thread safety.
+        with model_lock:
+            outputs = model(input_ids)
         logits = outputs.logits[:, -1, :]  # Get logits for the last token.
         next_token = min_p_sampling(logits, pbase=pbase)
 
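Note: min_p_sampling is a custom helper defined elsewhere in app.py, outside the changed hunks. Min-p sampling keeps only the tokens whose probability is at least pbase times the probability of the most likely token, renormalizes, and samples from what remains. A sketch of what such a helper typically looks like (an illustration, not the file's actual code):

import torch

def min_p_sampling(logits, pbase=0.1):
    # Turn the last-token logits into probabilities.
    probs = torch.softmax(logits, dim=-1)
    # The cutoff scales with the top token's probability.
    threshold = pbase * probs.max(dim=-1, keepdim=True).values
    # Zero out everything below the cutoff and renormalize.
    probs = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    probs = probs / probs.sum(dim=-1, keepdim=True)
    # Returns shape (1, 1), ready to be concatenated onto input_ids.
    return torch.multinomial(probs, num_samples=1)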
@@ -88,7 +95,6 @@ def generate_all(prompt):
     Run multiple decoding strategies concurrently and yield updates as each completes.
     """
     # Define each decoding strategy and its parameters.
-    # For the default strategies, we use model.generate; for "Min-p Sampling" we use our custom function.
     methods = {
         "Greedy": {"type": "default", "params": {"do_sample": False}},
         "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
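Note: the body of generate_all between these hunks is unchanged and therefore not shown. Given the concurrent.futures import and the docstring, the dispatch presumably resembles the following sketch (names and wiring beyond generate_completion and generate_min_p_completion are assumptions):

import concurrent.futures

def generate_all_sketch(prompt, methods):
    # Placeholder text per strategy; overwrite each slot as its future finishes.
    results = {name: "Generating..." for name in methods}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {}
        for name, spec in methods.items():
            if spec["type"] == "default":
                fut = executor.submit(generate_completion, prompt, name, spec["params"])
            else:  # the custom min-p path
                fut = executor.submit(generate_min_p_completion, prompt, **spec["params"])
            futures[fut] = name
        for future in concurrent.futures.as_completed(futures):
            results[futures[future]] = future.result()
            # Yield one value per output textbox, in a fixed order.
            yield tuple(results[name] for name in methods)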
@@ -137,7 +143,7 @@ interface = gr.Interface(
         gr.Textbox(label="Min-p Sampling"),
     ],
     title="Decoding Methods Comparison",
-    description="Each decoding method's final answer is printed as soon as it is done."
+    description="Each decoding method's final answer is printed as soon as it is done, including custom min-p sampling."
 )
 
 if __name__ == "__main__":
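Note: because generate_all is a generator, gr.Interface streams each yielded tuple into the output textboxes as it arrives, which is what lets each method's answer show up as soon as it finishes; on older Gradio releases this requires enabling the queue, e.g. interface.queue().launch().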