kgourgou committed
Commit 0044085 · verified · 1 Parent(s): 19e5757

Update app.py

Files changed (1):
  1. app.py +17 -11
app.py CHANGED
@@ -1,13 +1,13 @@
-"""
-Fun little experiment.
-"""
-
-
 import gradio as gr
 import torch
 import concurrent.futures
+import threading
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
+# Create a lock to serialize access to the model
+model_lock = threading.Lock()
+
+# Load the model and tokenizer (using GPT-2 as an example)
 model_name = "gpt2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -52,13 +52,18 @@ def generate_completion(prompt, strategy, params):
     Generate a complete answer using model.generate with specified parameters.
     """
     # Encode the prompt and get the attention mask.
-    tokenizer.pad_token = tokenizer.eos_token
     encoded = tokenizer(prompt, return_tensors="pt", padding=True)
     input_ids = encoded["input_ids"]
     attention_mask = encoded["attention_mask"]
 
-    # Generate the output.
-    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=50, **params)
+    # Use the lock when calling the model
+    with model_lock:
+        output_ids = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            max_length=50,
+            **params
+        )
     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
@@ -70,7 +75,9 @@ def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
 
     # Generate up to max_length tokens.
     for _ in range(max_length - input_ids.size(1)):
-        outputs = model(input_ids)
+        # Lock the model call to ensure thread safety.
+        with model_lock:
+            outputs = model(input_ids)
         logits = outputs.logits[:, -1, :]  # Get logits for the last token.
         next_token = min_p_sampling(logits, pbase=pbase)
 
@@ -88,7 +95,6 @@ def generate_all(prompt):
     Run multiple decoding strategies concurrently and yield updates as each completes.
     """
     # Define each decoding strategy and its parameters.
-    # For the default strategies, we use model.generate; for "Min-p Sampling" we use our custom function.
     methods = {
         "Greedy": {"type": "default", "params": {"do_sample": False}},
         "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
@@ -137,7 +143,7 @@ interface = gr.Interface(
         gr.Textbox(label="Min-p Sampling"),
     ],
     title="Decoding Methods Comparison",
-    description="Each decoding method's final answer is printed as soon as it is done. This uses GPT2."
+    description="Each decoding method's final answer is printed as soon as it is done, including custom min-p sampling."
 )
 
 if __name__ == "__main__":
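
Note: the min_p_sampling helper called in the generate_min_p_completion loop is defined outside the changed hunks. For reference, a minimal sketch of min-p sampling as commonly defined (keep only tokens whose probability is at least pbase times that of the most likely token, then renormalize and sample); the actual helper in app.py may differ:

import torch

def min_p_sampling(logits, pbase=0.1):
    # Probabilities for the last position; logits has shape (1, vocab_size).
    probs = torch.softmax(logits, dim=-1)
    # Min-p rule: discard tokens whose probability falls below
    # pbase times the probability of the most likely token.
    threshold = pbase * probs.max(dim=-1, keepdim=True).values
    probs = probs.masked_fill(probs < threshold, 0.0)
    # Renormalize the survivors and sample one token id, shape (1, 1).
    probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.multinomial(probs, num_samples=1)

The sampling loop would then append the result, e.g. input_ids = torch.cat([input_ids, next_token], dim=1), typically breaking early once the EOS token is produced.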
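The body of generate_all is also mostly outside the changed hunks; only its docstring and the start of the methods table are shown. A hypothetical reconstruction of the dispatch that docstring describes, using the concurrent.futures import at the top of the file (the "Min-p Sampling" entry and all structure below are assumptions, not taken from app.py):

import concurrent.futures

def generate_all(prompt):
    # Strategy table from the diff; the min-p entry is assumed.
    methods = {
        "Greedy": {"type": "default", "params": {"do_sample": False}},
        "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
        "Min-p Sampling": {"type": "custom", "params": {}},
    }
    results = {name: "" for name in methods}
    # Run each strategy in its own thread; yield a snapshot of all
    # results whenever one strategy finishes.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {}
        for name, spec in methods.items():
            if spec["type"] == "default":
                fut = executor.submit(generate_completion, prompt, name, spec["params"])
            else:  # the custom min-p strategy
                fut = executor.submit(generate_min_p_completion, prompt)
            futures[fut] = name
        for fut in concurrent.futures.as_completed(futures):
            results[futures[fut]] = fut.result()
            yield tuple(results[name] for name in methods)

Because every model call now acquires model_lock, the threads do not truly run in parallel: the default strategies hold the lock for a whole model.generate call, while the min-p loop releases it after each forward pass. The lock trades throughput for thread safety around the shared GPT-2 instance.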
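The final context line suggests the usual entry point. A minimal sketch, assuming the interface object built above; on some Gradio versions, queueing must be enabled for a generator fn like generate_all to stream partial updates:

if __name__ == "__main__":
    # interface.queue()  # needed on older Gradio versions for streaming
    interface.launch()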