DiDustin committed (verified)
Commit cbc4fca · Parent(s): 4bb79cf

Update app.py


Update activation of top_p, and the required GPU time

Files changed (1): app.py (+27 -16)
app.py CHANGED
@@ -63,15 +63,15 @@ LANGUAGES = {
 loaded_models = {}
 loaded_tokenizers = {}
 
+
 @spaces.GPU(duration=60)
 def load_model_and_tokenizer(model_key):
     if model_key not in loaded_models:
         model_info = MODELS[model_key]
-        device = "cuda"
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         model = AutoModelForCausalLM.from_pretrained(
             model_info["model_name"],
-            token=HF_TOKEN,
-            torch_dtype=torch.float16
+            token=HF_TOKEN
         ).to(device)
         loaded_models[model_key] = model
 
 
@@ -84,26 +84,31 @@ def load_model_and_tokenizer(model_key):
         tokenizer.pad_token = tokenizer.eos_token
         loaded_tokenizers[model_key] = tokenizer
 
-@spaces.GPU(duration=140)
+
+@spaces.GPU(duration=120)
 def generate_text(model_choice, prompt, max_length, temperature, top_p, do_sample):
     load_model_and_tokenizer(model_choice)
 
     model = loaded_models[model_choice]
     tokenizer = loaded_tokenizers[model_choice]
-    device = "cuda"
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(device)
 
-    outputs = model.generate(
-        input_ids=inputs["input_ids"],
-        attention_mask=inputs["attention_mask"],
-        max_length=max_length,
-        temperature=temperature,
-        top_p=top_p,
-        repetition_penalty=1.2,
-        no_repeat_ngram_size=2,
-        do_sample=do_sample,
-    )
+    generation_kwargs = {
+        "input_ids": inputs["input_ids"],
+        "attention_mask": inputs["attention_mask"],
+        "max_length": max_length,
+        "temperature": temperature,
+        "repetition_penalty": 1.2,
+        "no_repeat_ngram_size": 2,
+        "do_sample": do_sample,
+    }
+
+    if do_sample:
+        generation_kwargs["top_p"] = top_p
+
+    outputs = model.generate(**generation_kwargs)
 
     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
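This hunk is the "activation of top_p" from the commit message: the generation arguments are built as a dict, and top_p is only included when do_sample is true, since greedy decoding ignores nucleus sampling and recent transformers releases warn when sampling flags are passed with do_sample=False. (temperature is still forwarded unconditionally and can trigger the same class of warning, so it could be gated the same way.) A self-contained sketch of the pattern:

def build_generation_kwargs(max_length, temperature, top_p, do_sample):
    # Mirrors the diff: sampling-only knobs are added conditionally.
    kwargs = {
        "max_length": max_length,
        "temperature": temperature,
        "do_sample": do_sample,
    }
    if do_sample:
        kwargs["top_p"] = top_p  # nucleus sampling; inert under greedy decoding
    return kwargs

print(build_generation_kwargs(64, 0.7, 0.9, do_sample=False))
# -> {'max_length': 64, 'temperature': 0.7, 'do_sample': False}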
 
@@ -136,7 +141,7 @@ def update_language(selected_language):
     )
 
 
-@spaces.GPU(duration=140)
+@spaces.GPU(duration=120)
 def wrapped_generate_text(model_choice, prompt, max_length, temperature, top_p, do_sample):
     return generate_text(model_choice, prompt, max_length, temperature, top_p, do_sample)
 
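The duration drop from 140 to 120 seconds is the "required GPU time" half of the commit message: on ZeroGPU Spaces, @spaces.GPU(duration=...) reserves a GPU for at most that long per call, so a tighter bound frees hardware sooner. A minimal sketch of the decorator, assuming the spaces package available on HF Spaces:

import spaces
import torch

@spaces.GPU(duration=120)  # hold a ZeroGPU device for at most ~120 s per call
def run_on_gpu(prompt: str) -> str:
    # CUDA is attached only while the decorated function runs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"ran on {device}: {prompt}"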
 
@@ -215,6 +220,12 @@ with gr.Blocks() as iface:
                  do_sample_checkbox, generate_button, output_text]
     )
 
+    do_sample_checkbox.change(
+        fn=lambda do_sample: gr.update(visible=do_sample),
+        inputs=[do_sample_checkbox],
+        outputs=[top_p_slider]
+    )
+
     generate_button.click(
         fn=wrapped_generate_text,
         inputs=[
 
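The new change handler keeps the UI consistent with the gating above: the top_p slider is hidden whenever sampling is off, so users cannot set a knob that generate() would ignore. A self-contained sketch of the same pattern (component names are illustrative; app.py uses do_sample_checkbox and top_p_slider):

import gradio as gr

with gr.Blocks() as demo:
    do_sample = gr.Checkbox(label="do_sample", value=True)
    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="top_p")
    # Show the slider only while sampling is enabled.
    do_sample.change(
        fn=lambda s: gr.update(visible=s),
        inputs=[do_sample],
        outputs=[top_p],
    )

demo.launch()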