Commit d198e0d ("generate output as it comes")
Parent: 9b002fb

app.py CHANGED
@@ -19,7 +19,7 @@ st.title("✨ GemmaTextAppeal")
 st.markdown("""
 ### Interactive Demo of Google's Gemma 2-2B-IT Model
 This app demonstrates the text generation capabilities of Google's Gemma 2-2B-IT model.
-Enter a prompt below and see the model generate text!
+Enter a prompt below and see the model generate text in real-time!
 """)
 
 # Function to load model

@@ -141,7 +141,7 @@ user_input = st.text_area("Enter your prompt:",
     height=100,
     placeholder="e.g., Write a short story about a robot discovering emotions")
 
-def generate_text(prompt, max_new_tokens=300, temperature=0.7):
+def generate_text_streaming(prompt, max_new_tokens=300, temperature=0.7):
     if not tokenizer or not model:
         st.session_state.error_message = "Model not properly loaded. Please check your Hugging Face token."
         return None

@@ -150,35 +150,71 @@ def generate_text(prompt, max_new_tokens=300, temperature=0.7):
     # Format the prompt according to Gemma's expected format
     formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
 
-    # Create the …
-
-
-
+    # Create the output area
+    output_container = st.empty()
+    response_area = st.container()
+
+    with response_area:
+        st.markdown("**Generated Response:**")
+        response_text = st.empty()
 
     # Tokenize the input
-
-
+    encoding = tokenizer(formatted_prompt, return_tensors="pt")
+
+    # Move to the appropriate device
+    if torch.cuda.is_available():
+        encoding = {k: v.to("cuda") for k, v in encoding.items()}
+
+    # Store the length of the input to track new tokens
+    input_length = encoding["input_ids"].shape[1]
+
+    # Initialize generated text container
+    generated_text = ""
+
+    # Generate tokens with streaming
+    generated_ids = []
+
+    # Set up generation configuration
+    for _ in range(max_new_tokens):
+        with torch.no_grad():
+            if len(generated_ids) == 0:
+                # First token generation
+                outputs = model.generate(
+                    **encoding,
+                    max_new_tokens=1,
+                    do_sample=True,
+                    temperature=temperature,
+                    pad_token_id=tokenizer.eos_token_id,
+                    return_dict_in_generate=True,
+                    output_scores=False
+                )
+                next_token_id = outputs.sequences[0, input_length:input_length+1]
+            else:
+                # Subsequent tokens
+                current_input_ids = torch.cat([encoding["input_ids"], torch.tensor([generated_ids], device=encoding["input_ids"].device)], dim=1)
+                outputs = model.generate(
+                    input_ids=current_input_ids,
+                    max_new_tokens=1,
+                    do_sample=True,
+                    temperature=temperature,
+                    pad_token_id=tokenizer.eos_token_id,
+                    return_dict_in_generate=True,
+                    output_scores=False
+                )
+                next_token_id = outputs.sequences[0, -1].unsqueeze(0)
 
-    # …
-
-
+        # Convert to Python list and append
+        next_token_id_list = next_token_id.tolist()
+        generated_ids.extend(next_token_id_list)
 
-    # …
-
-
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        pad_token_id=tokenizer.eos_token_id
-    )
+        # Check for EOS token
+        if tokenizer.eos_token_id in next_token_id_list:
+            break
 
-    # …
-
-    generated_text = …
-
-    # Display the result
-    output_area.markdown(f"**Generated Response:**\n\n{generated_text}")
-    status_text.text("Generation complete!")
+        # Decode the tokens generated so far and update the displayed text
+        current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+        generated_text = current_text
+        response_text.markdown(generated_text)
 
     return generated_text
 

@@ -232,19 +268,15 @@ if st.button("Generate Text"):
         st.error("Hugging Face token is required! Please add your token as described above.")
     elif user_input:
         st.session_state.user_prompt = user_input
-
-
-
-
-        st.session_state.generation_complete = True
+        result = generate_text_streaming(user_input, max_length, temperature)
+        if result is not None:  # Only set if no error occurred
+            st.session_state.generated_text = result
+            st.session_state.generation_complete = True
     else:
         st.error("Please enter a prompt first!")
 
-# …
-if st.session_state.generation_complete and not st.session_state.error_message:
-    st.markdown("### Generated Text")
-    st.markdown(st.session_state.generated_text)
-
+# Analysis section (only show after generation is complete)
+if st.session_state.generation_complete and not st.session_state.error_message and st.session_state.generated_text:
     # Analysis section
     with st.expander("Text Analysis"):
        col1, col2 = st.columns(2)
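The diff above streams output by calling model.generate() one token at a time and rewriting a Streamlit placeholder after each step. For reference only (not part of this commit), a lighter-weight way to get the same effect is transformers' TextIteratorStreamer, which yields decoded text while a single generate() call runs in a background thread. The sketch below is a minimal illustration under assumptions: it reuses the tokenizer, model, and Gemma prompt format from app.py, and the helper name generate_text_streaming_alt and its placeholder are hypothetical.

# Sketch only (not part of this commit): streaming via transformers' TextIteratorStreamer
# instead of a manual one-token-at-a-time loop. Assumes `tokenizer`, `model`, and
# `st` (streamlit) are already set up as in app.py.
from threading import Thread
from transformers import TextIteratorStreamer

def generate_text_streaming_alt(prompt, max_new_tokens=300, temperature=0.7):
    formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
    encoding = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # skip_prompt drops the echoed prompt; extra kwargs are forwarded to tokenizer.decode
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a background thread; the streamer yields text as it is produced
    thread = Thread(target=model.generate, kwargs=dict(
        **encoding,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    ))
    thread.start()

    response_text = st.empty()  # placeholder that is rewritten as chunks arrive
    generated_text = ""
    for chunk in streamer:
        generated_text += chunk
        response_text.markdown(generated_text)

    thread.join()
    return generated_text

Compared with the per-token loop in the diff, which re-runs generate() on the full sequence for every new token, this keeps generation inside a single generate() call, so the model's KV cache is reused across tokens.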