BryanBradfo committed on
Commit d198e0d · 1 Parent(s): 9b002fb

generate output as it comes

Files changed (1)
  1. app.py +68 -36
app.py CHANGED

@@ -19,7 +19,7 @@ st.title("✨ GemmaTextAppeal")
 st.markdown("""
 ### Interactive Demo of Google's Gemma 2-2B-IT Model
 This app demonstrates the text generation capabilities of Google's Gemma 2-2B-IT model.
-Enter a prompt below and see the model generate text!
+Enter a prompt below and see the model generate text in real-time!
 """)
 
 # Function to load model
@@ -141,7 +141,7 @@ user_input = st.text_area("Enter your prompt:",
                           height=100,
                           placeholder="e.g., Write a short story about a robot discovering emotions")
 
-def generate_text(prompt, max_new_tokens=300, temperature=0.7):
+def generate_text_streaming(prompt, max_new_tokens=300, temperature=0.7):
     if not tokenizer or not model:
         st.session_state.error_message = "Model not properly loaded. Please check your Hugging Face token."
         return None
@@ -150,35 +150,71 @@ def generate_text(prompt, max_new_tokens=300, temperature=0.7):
     # Format the prompt according to Gemma's expected format
     formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
 
-    # Create the status indicator and output area
-    status_text = st.empty()
-    output_area = st.empty()
-    status_text.text("Generating response...")
+    # Create the output area
+    output_container = st.empty()
+    response_area = st.container()
+
+    with response_area:
+        st.markdown("**Generated Response:**")
+        response_text = st.empty()
 
     # Tokenize the input
-    with torch.no_grad():
-        encoding = tokenizer(formatted_prompt, return_tensors="pt")
-
-        # Move to the appropriate device
-        if torch.cuda.is_available():
-            encoding = {k: v.to("cuda") for k, v in encoding.items()}
-
-        # Generate the text - streamlined version
-        output_ids = model.generate(
-            **encoding,
-            max_new_tokens=max_new_tokens,
-            do_sample=True,
-            temperature=temperature,
-            pad_token_id=tokenizer.eos_token_id
-        )
-
-        # Get only the generated part (exclude the prompt)
-        new_tokens = output_ids[0][encoding["input_ids"].shape[1]:]
-        generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
-
-        # Display the result
-        output_area.markdown(f"**Generated Response:**\n\n{generated_text}")
-        status_text.text("Generation complete!")
+    encoding = tokenizer(formatted_prompt, return_tensors="pt")
+
+    # Move to the appropriate device
+    if torch.cuda.is_available():
+        encoding = {k: v.to("cuda") for k, v in encoding.items()}
+
+    # Store the length of the input to track new tokens
+    input_length = encoding["input_ids"].shape[1]
+
+    # Initialize generated text container
+    generated_text = ""
+
+    # Generate tokens with streaming
+    generated_ids = []
+
+    # Set up generation configuration
+    for _ in range(max_new_tokens):
+        with torch.no_grad():
+            if len(generated_ids) == 0:
+                # First token generation
+                outputs = model.generate(
+                    **encoding,
+                    max_new_tokens=1,
+                    do_sample=True,
+                    temperature=temperature,
+                    pad_token_id=tokenizer.eos_token_id,
+                    return_dict_in_generate=True,
+                    output_scores=False
+                )
+                next_token_id = outputs.sequences[0, input_length:input_length+1]
+            else:
+                # Subsequent tokens
+                current_input_ids = torch.cat([encoding["input_ids"], torch.tensor([generated_ids], device=encoding["input_ids"].device)], dim=1)
+                outputs = model.generate(
+                    input_ids=current_input_ids,
+                    max_new_tokens=1,
+                    do_sample=True,
+                    temperature=temperature,
+                    pad_token_id=tokenizer.eos_token_id,
+                    return_dict_in_generate=True,
+                    output_scores=False
+                )
+                next_token_id = outputs.sequences[0, -1].unsqueeze(0)
+
+        # Convert to Python list and append
+        next_token_id_list = next_token_id.tolist()
+        generated_ids.extend(next_token_id_list)
+
+        # Check for EOS token
+        if tokenizer.eos_token_id in next_token_id_list:
+            break
+
+        # Decode the tokens generated so far and update the displayed text
+        current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+        generated_text = current_text
+        response_text.markdown(generated_text)
 
     return generated_text
 
@@ -232,19 +268,15 @@ if st.button("Generate Text"):
         st.error("Hugging Face token is required! Please add your token as described above.")
     elif user_input:
         st.session_state.user_prompt = user_input
-        with st.spinner("Generating text..."):
-            result = generate_text(user_input, max_length, temperature)
-            if result is not None:  # Only set if no error occurred
-                st.session_state.generated_text = result
-                st.session_state.generation_complete = True
+        result = generate_text_streaming(user_input, max_length, temperature)
+        if result is not None:  # Only set if no error occurred
+            st.session_state.generated_text = result
+            st.session_state.generation_complete = True
     else:
         st.error("Please enter a prompt first!")
 
-# Display results
-if st.session_state.generation_complete and not st.session_state.error_message:
-    st.markdown("### Generated Text")
-    st.markdown(st.session_state.generated_text)
-
+# Analysis section (only show after generation is complete)
+if st.session_state.generation_complete and not st.session_state.error_message and st.session_state.generated_text:
     # Analysis section
     with st.expander("Text Analysis"):
         col1, col2 = st.columns(2)
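
A note on the approach: the loop added above streams output by calling model.generate once per token, re-feeding the full prompt plus every previously generated id on each iteration, so the whole prefix is re-encoded for every new token. transformers also ships TextIteratorStreamer for exactly this "output as it comes" pattern: a single generate call runs in a background thread while the main Streamlit thread drains decoded text as it becomes available. The snippet below is a sketch of what generate_text_streaming could look like with that helper; it is not part of this commit, and it assumes the same module-level tokenizer and model objects and the same Gemma prompt format used in the diff.

```python
# Sketch only - an alternative streaming implementation, not the code in this
# commit. Assumes `tokenizer` and `model` are the app's already-loaded Gemma
# tokenizer/model and that Streamlit is imported as `st`.
from threading import Thread

import streamlit as st
import torch
from transformers import TextIteratorStreamer


def generate_text_streaming(prompt, max_new_tokens=300, temperature=0.7):
    # Same Gemma chat format as the commit uses.
    formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

    encoding = tokenizer(formatted_prompt, return_tensors="pt")
    if torch.cuda.is_available():
        encoding = {k: v.to("cuda") for k, v in encoding.items()}

    # The streamer decodes tokens as generate() emits them; skip_prompt hides the
    # input prompt and skip_special_tokens drops markers such as <end_of_turn>.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **encoding,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    # generate() blocks until it finishes, so run it in a worker thread and
    # update the UI from the main thread as text arrives.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    st.markdown("**Generated Response:**")
    response_text = st.empty()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        response_text.markdown(generated_text)

    thread.join()
    return generated_text
```

Because everything happens inside one generate call, the model's key/value cache is reused from token to token, which should make long completions considerably cheaper than re-running generation for every token as the committed loop does.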