kgourgou committed
Commit 2034a72 · verified · 1 Parent(s): 59f8a34

Update app.py

Files changed (1)
  1. app.py +6 -18
app.py CHANGED

@@ -1,12 +1,8 @@
 import gradio as gr
 import torch
 import concurrent.futures
-import threading
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Create a lock to serialize access to the model
-model_lock = threading.Lock()
-
 # Load the model and tokenizer (using GPT-2 as an example)
 model_name = "gpt2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -52,19 +48,12 @@ def generate_completion(prompt, strategy, params):
     Generate a complete answer using model.generate with specified parameters.
     """
     # Encode the prompt and get the attention mask.
-    tokenizer.pad_token = tokenizer.eos_token
-    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
+    encoded = tokenizer(prompt, return_tensors="pt")
     input_ids = encoded["input_ids"]
     attention_mask = encoded["attention_mask"]
 
-    # Use the lock when calling the model
-    with model_lock:
-        output_ids = model.generate(
-            input_ids,
-            attention_mask=attention_mask,
-            max_length=50,
-            **params
-        )
+    # Generate the output.
+    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=50, **params)
     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
@@ -76,9 +65,7 @@ def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
 
     # Generate up to max_length tokens.
     for _ in range(max_length - input_ids.size(1)):
-        # Lock the model call to ensure thread safety.
-        with model_lock:
-            outputs = model(input_ids)
+        outputs = model(input_ids)
         logits = outputs.logits[:, -1, :]  # Get logits for the last token.
         next_token = min_p_sampling(logits, pbase=pbase)
 
@@ -96,6 +83,7 @@ def generate_all(prompt):
     Run multiple decoding strategies concurrently and yield updates as each completes.
     """
     # Define each decoding strategy and its parameters.
+    # For the default strategies, we use model.generate; for "Min-p Sampling" we use our custom function.
    methods = {
         "Greedy": {"type": "default", "params": {"do_sample": False}},
         "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
@@ -144,7 +132,7 @@ interface = gr.Interface(
         gr.Textbox(label="Min-p Sampling"),
     ],
     title="Decoding Methods Comparison",
-    description="""This uses GPT2. min-p sampling is from Nguyen, M., et al, 2024, "Turning up the heat: Min-p sampling for creative and coherent llm outputs. arXiv preprint arXiv:2407.01082."""
+    description="Each decoding method's final answer is printed as soon as it is done, including custom min-p sampling."
 )
 
 if __name__ == "__main__":
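The loop in generate_min_p_completion calls min_p_sampling, which is defined elsewhere in app.py and does not appear in this diff. For reference, a minimal sketch of min-p sampling as described in the cited paper (arXiv:2407.01082) follows: keep only tokens whose probability is at least pbase times the top token's probability, renormalize, and sample. The body below is an illustrative assumption, not the app's actual implementation; only the signature is taken from the call site above.

import torch

def min_p_sampling(logits, pbase=0.1):
    # Illustrative sketch, not the app's code. logits: [batch, vocab_size]
    # scores for the last position.
    probs = torch.softmax(logits, dim=-1)
    # The cutoff scales with the top token's probability (the "min-p" rule).
    threshold = pbase * probs.max(dim=-1, keepdim=True).values
    # Zero out tokens below the cutoff and renormalize the survivors.
    filtered = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    # Draw one token id; shape [batch, 1], ready to append to input_ids.
    return torch.multinomial(filtered, num_samples=1)

Because the top token always passes its own scaled cutoff for pbase <= 1, the filtered distribution can never be empty, so no special casing is needed.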
 
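Per its docstring, generate_all runs the decoding strategies concurrently and yields an update as each one finishes. A sketch of that pattern with concurrent.futures is below; the function signatures come from the diff, but the strategy argument values and the run_all name are illustrative assumptions.

import concurrent.futures

def run_all(prompt):
    # One zero-argument callable per output textbox; the exact arguments
    # passed to generate_completion here are guesses based on the methods dict.
    tasks = {
        "Greedy": lambda: generate_completion(prompt, "default", {"do_sample": False}),
        "Top-k Sampling": lambda: generate_completion(prompt, "default", {"do_sample": True, "top_k": 50}),
        "Min-p Sampling": lambda: generate_min_p_completion(prompt, pbase=0.1),
    }
    results = {name: "" for name in tasks}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {executor.submit(fn): name for name, fn in tasks.items()}
        # Yield a full snapshot whenever one strategy completes, so the
        # Gradio interface can fill in each textbox as soon as it is ready.
        for done in concurrent.futures.as_completed(futures):
            results[futures[done]] = done.result()
            yield tuple(results[name] for name in tasks)

With model_lock gone, these threads enter model.generate and the manual forward loop at the same time; read-only inference on a shared torch module is generally thread-safe, which is presumably what this commit relies on.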