ReallyFloppyPenguin commited on
Commit
f351ad2
·
verified ·
1 Parent(s): 230e3e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -46
app.py CHANGED
@@ -3,7 +3,7 @@ import json
3
  import tempfile
4
  import os
5
  import re # For parsing conversation
6
- from typing import Union, Optional, Dict # Import Dict
7
  # Import the actual functions from synthgen
8
  from synthgen import (
9
  generate_synthetic_text,
@@ -154,7 +154,7 @@ def generate_prompts_ui(
154
 
155
  # --- Modified Generation Wrappers ---
156
 
157
- # Wrapper for text generation + JSON preparation - RETURNS DICT
158
  def run_generation_and_prepare_json(
159
  prompt: str,
160
  model: str,
@@ -162,31 +162,18 @@ def run_generation_and_prepare_json(
162
  temperature: float,
163
  top_p: float,
164
  max_tokens: int
165
- ) -> Dict[gr.Textbox, str]: # Return type hint (optional but good practice)
166
  """Generates text samples and prepares a JSON file for download."""
167
  # Handle optional settings
168
  temp_val = temperature if temperature > 0 else None
169
  top_p_val = top_p if 0 < top_p <= 1 else None
170
  max_tokens_val = max_tokens if max_tokens > 0 else None
171
 
172
- # Define component objects used in return dict keys - MUST MATCH OUTPUTS
173
- # This requires the components to be defined *before* this function,
174
- # which isn't the case. So we cannot use component objects as keys here.
175
- # Gradio handles mapping if the keys are strings matching component labels
176
- # OR if we return gr.update targeting components.
177
- # Let's return explicit gr.update for clarity and robustness.
178
-
179
  if not prompt:
180
- # Return updates for both outputs
181
- return {
182
- output_text: gr.update(value="Error: Please enter a prompt."),
183
- download_file_text: gr.update(value=None) # Clear file output
184
- }
185
  if num_samples <= 0:
186
- return {
187
- output_text: gr.update(value="Error: Number of samples must be positive."),
188
- download_file_text: gr.update(value=None)
189
- }
190
 
191
  output_str = f"Generating {num_samples} samples using model '{model}'...\n"
192
  output_str += f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n"
@@ -205,14 +192,11 @@ def run_generation_and_prepare_json(
205
  output_str += "="*20 + "\nGeneration complete (check results above for errors)."
206
  json_filepath = create_json_file(results_list, "text_samples.json")
207
 
208
- # Return dictionary mapping components to updates
209
- return {
210
- output_text: gr.update(value=output_str),
211
- download_file_text: gr.update(value=json_filepath) # Update file path
212
- }
213
 
214
 
215
- # Wrapper for conversation generation + JSON preparation - RETURNS DICT
216
  def run_conversation_generation_and_prepare_json(
217
  system_prompts_text: str,
218
  model: str,
@@ -220,32 +204,21 @@ def run_conversation_generation_and_prepare_json(
220
  temperature: float,
221
  top_p: float,
222
  max_tokens: int
223
- ) -> Dict[gr.Textbox, str]: # Return type hint (optional)
224
  """Generates conversations and prepares a JSON file for download."""
225
  temp_val = temperature if temperature > 0 else None
226
  top_p_val = top_p if 0 < top_p <= 1 else None
227
  max_tokens_val = max_tokens if max_tokens > 0 else None
228
 
229
- # Define component objects used in return dict keys - requires components defined first.
230
- # Using explicit gr.update instead.
231
-
232
  if not system_prompts_text:
233
- return {
234
- output_conv: gr.update(value="Error: Please enter or generate at least one system prompt/topic."),
235
- download_file_conv: gr.update(value=None)
236
- }
237
  if num_turns <= 0:
238
- return {
239
- output_conv: gr.update(value="Error: Number of turns must be positive."),
240
- download_file_conv: gr.update(value=None)
241
- }
242
 
243
  prompts = [p.strip() for p in system_prompts_text.strip().split('\n') if p.strip()]
244
  if not prompts:
245
- return {
246
- output_conv: gr.update(value="Error: No valid prompts found in the input."),
247
- download_file_conv: gr.update(value=None)
248
- }
249
 
250
  output_str = f"Generating {len(prompts)} conversations ({num_turns} turns each) using model '{model}'...\n"
251
  output_str += f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n"
@@ -275,11 +248,8 @@ def run_conversation_generation_and_prepare_json(
275
  output_str += "="*40 + "\nGeneration complete (check results above for errors)."
276
  json_filepath = create_json_file(results_list_structured, "conversations.json")
277
 
278
- # Return dictionary mapping components to updates
279
- return {
280
- output_conv: gr.update(value=output_str),
281
- download_file_conv: gr.update(value=json_filepath)
282
- }
283
 
284
 
285
  # --- Gradio Interface Definition ---
 
3
  import tempfile
4
  import os
5
  import re # For parsing conversation
6
+ from typing import Union, Optional, Dict, Tuple # Import Dict and Tuple
7
  # Import the actual functions from synthgen
8
  from synthgen import (
9
  generate_synthetic_text,
 
154
 
155
  # --- Modified Generation Wrappers ---
156
 
157
+ # Wrapper for text generation + JSON preparation - RETURNS TUPLE
158
  def run_generation_and_prepare_json(
159
  prompt: str,
160
  model: str,
 
162
  temperature: float,
163
  top_p: float,
164
  max_tokens: int
165
+ ) -> Tuple[gr.update, gr.update]: # Return type hint (optional)
166
  """Generates text samples and prepares a JSON file for download."""
167
  # Handle optional settings
168
  temp_val = temperature if temperature > 0 else None
169
  top_p_val = top_p if 0 < top_p <= 1 else None
170
  max_tokens_val = max_tokens if max_tokens > 0 else None
171
 
172
+ # Handle errors by returning updates for both outputs in a tuple
 
 
 
 
 
 
173
  if not prompt:
174
+ return (gr.update(value="Error: Please enter a prompt."), gr.update(value=None))
 
 
 
 
175
  if num_samples <= 0:
176
+ return (gr.update(value="Error: Number of samples must be positive."), gr.update(value=None))
 
 
 
177
 
178
  output_str = f"Generating {num_samples} samples using model '{model}'...\n"
179
  output_str += f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n"
 
192
  output_str += "="*20 + "\nGeneration complete (check results above for errors)."
193
  json_filepath = create_json_file(results_list, "text_samples.json")
194
 
195
+ # Return tuple of updates in the order of outputs list
196
+ return (gr.update(value=output_str), gr.update(value=json_filepath))
 
 
 
197
 
198
 
199
+ # Wrapper for conversation generation + JSON preparation - RETURNS TUPLE
200
  def run_conversation_generation_and_prepare_json(
201
  system_prompts_text: str,
202
  model: str,
 
204
  temperature: float,
205
  top_p: float,
206
  max_tokens: int
207
+ ) -> Tuple[gr.update, gr.update]: # Return type hint (optional)
208
  """Generates conversations and prepares a JSON file for download."""
209
  temp_val = temperature if temperature > 0 else None
210
  top_p_val = top_p if 0 < top_p <= 1 else None
211
  max_tokens_val = max_tokens if max_tokens > 0 else None
212
 
213
+ # Handle errors by returning updates for both outputs in a tuple
 
 
214
  if not system_prompts_text:
215
+ return (gr.update(value="Error: Please enter or generate at least one system prompt/topic."), gr.update(value=None))
 
 
 
216
  if num_turns <= 0:
217
+ return (gr.update(value="Error: Number of turns must be positive."), gr.update(value=None))
 
 
 
218
 
219
  prompts = [p.strip() for p in system_prompts_text.strip().split('\n') if p.strip()]
220
  if not prompts:
221
+ return (gr.update(value="Error: No valid prompts found in the input."), gr.update(value=None))
 
 
 
222
 
223
  output_str = f"Generating {len(prompts)} conversations ({num_turns} turns each) using model '{model}'...\n"
224
  output_str += f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n"
 
248
  output_str += "="*40 + "\nGeneration complete (check results above for errors)."
249
  json_filepath = create_json_file(results_list_structured, "conversations.json")
250
 
251
+ # Return tuple of updates in the order of outputs list
252
+ return (gr.update(value=output_str), gr.update(value=json_filepath))
 
 
 
253
 
254
 
255
  # --- Gradio Interface Definition ---