wuhp committed on
Commit 07a46f8 · verified · 1 Parent(s): b9d6d53

Update app.py

Files changed (1)
  1. app.py +395 -278
app.py CHANGED
@@ -12,8 +12,6 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 
 # --- Model & Quantization Settings ---
 MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
-
-# Dictionaries to store the loaded model and tokenizer
 models: Dict[str, AutoModelForCausalLM] = {}
 tokenizers: Dict[str, AutoTokenizer] = {}
 
@@ -26,7 +24,6 @@ bnb_config_4bit = BitsAndBytesConfig(
 def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
     """
     Lazy-load the model and tokenizer if not already loaded.
-
     Returns:
         Tuple[model, tokenizer]: The loaded model and tokenizer.
     """
@@ -50,46 +47,176 @@ def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
         raise e
     return models["7B"], tokenizers["7B"]
 
-
-# --- Default Prompt Templates ---
-default_prompt_brainstorm = """**Brainstorming Task (Round 1)**
-As a Senior Code Analyst, provide an initial analysis of the problem below.
 
 **User Request:**
 {user_prompt}
 
 **Guidelines:**
-1. Identify key challenges and constraints.
-2. Suggest multiple potential approaches.
-3. Outline any potential edge cases or critical considerations.
 """
 
-default_prompt_code_generation = """**Advanced Reasoning & Code Generation (Round 2)**
-Based on the initial analysis below:
 
 **Initial Analysis:**
 {brainstorm_response}
 
-**User Request:**
 {user_prompt}
 
 **Task:**
-1. Develop a detailed solution that includes production-ready code.
-2. Explain the reasoning behind the chosen approach.
-3. Incorporate advanced reasoning to handle edge cases.
-4. Provide commented code that is clear and maintainable.
 """
 
-default_prompt_synthesis = """**Synthesis & Final Refinement (Round 3)**
-Review the detailed code generation and reasoning below, and produce a final, refined response that:
-1. Synthesizes the brainstorming insights and advanced reasoning.
-2. Provides a concise summary of the solution.
-3. Highlights any potential improvements or considerations.
 
-**Detailed Response:**
-{code_response}
 """
 
 
 # --- Memory Management ---
 class MemoryManager:
@@ -98,222 +225,163 @@ class MemoryManager:
         self.shared_memory: List[str] = []
 
     def store(self, item: str) -> None:
-        """
-        Store a memory item and log an excerpt.
-
-        Args:
-            item (str): The memory content to store.
-        """
         self.shared_memory.append(item)
         logging.info(f"[Memory Stored]: {item[:50]}...")
 
     def retrieve(self, query: str, top_k: int = 3) -> List[str]:
-        """
-        Retrieve memory items that contain the query text (case-insensitive).
-
-        Args:
-            query (str): The text query to search for.
-            top_k (int): Maximum number of memory items to return.
-
-        Returns:
-            List[str]: A list of up to top_k memory items.
-        """
         query_lower = query.lower()
         relevant = [item for item in self.shared_memory if query_lower in item.lower()]
         if not relevant:
             logging.info("[Memory Retrieval]: No relevant memories found.")
         else:
             logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
-        return relevant[:top_k]
 
-# Create a global memory manager instance for RAG purposes.
 global_memory_manager = MemoryManager()
 
-
-# --- Multi-Round Swarm Agent Function ---
-@spaces.GPU(duration=180)  # Adjust duration as needed
-def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
-                          prompt_brainstorm_text: str, prompt_code_generation_text: str, prompt_synthesis_text: str
-                          ) -> Generator[str, None, None]:
-    """
-    A three-round iterative process that uses the provided prompt templates:
-      - Round 1: Brainstorming.
-      - Round 2: Advanced reasoning & code generation.
-      - Round 3: Synthesis & refinement.
-
-    This generator yields the response from the final round as it is produced.
-
-    Yields:
-        str: Progressive updates of the final response.
-    """
-    model, tokenizer = get_model_and_tokenizer()
-
-    # ----- Round 1: Brainstorming -----
-    logging.info("--- Round 1: Brainstorming ---")
-    prompt_r1 = prompt_brainstorm_text.format(user_prompt=user_prompt)
-    input_ids_r1 = tokenizer.encode(prompt_r1, return_tensors="pt").to(model.device)
-    streamer_r1 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    kwargs_r1 = dict(
-        input_ids=input_ids_r1,
-        streamer=streamer_r1,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temp,
-        top_p=top_p,
-    )
-    try:
-        thread_r1 = Thread(target=model.generate, kwargs=kwargs_r1)
-        with torch.no_grad():
-            thread_r1.start()
-    except Exception as e:
-        logging.error(f"Error starting Round 1 thread: {e}")
-        raise e
-
-    brainstorm_response = ""
-    try:
-        for text in streamer_r1:
-            logging.info(text)
-            brainstorm_response += text
-    except Exception as e:
-        logging.error(f"Error during Round 1 generation: {e}")
-        raise e
-    thread_r1.join()
-    global_memory_manager.store(f"Brainstorm Response: {brainstorm_response[:200]}...")
-
-    # ----- Round 2: Code Generation -----
-    logging.info("--- Round 2: Code Generation ---")
-    prompt_r2 = prompt_code_generation_text.format(
-        brainstorm_response=brainstorm_response,
-        user_prompt=user_prompt
-    )
-    input_ids_r2 = tokenizer.encode(prompt_r2, return_tensors="pt").to(model.device)
-    streamer_r2 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    kwargs_r2 = dict(
-        input_ids=input_ids_r2,
-        streamer=streamer_r2,
-        max_new_tokens=max_new_tokens + 100,  # extra tokens for detail
-        temperature=temp,
         top_p=top_p,
     )
     try:
-        thread_r2 = Thread(target=model.generate, kwargs=kwargs_r2)
-        with torch.no_grad():
-            thread_r2.start()
     except Exception as e:
-        logging.error(f"Error starting Round 2 thread: {e}")
         raise e
 
-    code_response = ""
-    try:
-        for text in streamer_r2:
-            logging.info(text)
-            code_response += text
-    except Exception as e:
-        logging.error(f"Error during Round 2 generation: {e}")
-        raise e
-    thread_r2.join()
-    global_memory_manager.store(f"Code Generation Response: {code_response[:200]}...")
-
-    # ----- Round 3: Synthesis & Refinement -----
-    logging.info("--- Round 3: Synthesis & Refinement ---")
-    prompt_r3 = prompt_synthesis_text.format(code_response=code_response)
-    input_ids_r3 = tokenizer.encode(prompt_r3, return_tensors="pt").to(model.device)
-    streamer_r3 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    kwargs_r3 = dict(
-        input_ids=input_ids_r3,
-        streamer=streamer_r3,
-        max_new_tokens=max_new_tokens // 2,
-        temperature=temp,
-        top_p=top_p,
-    )
-    try:
-        thread_r3 = Thread(target=model.generate, kwargs=kwargs_r3)
         with torch.no_grad():
             thread_r3.start()
-    except Exception as e:
-        logging.error(f"Error starting Round 3 thread: {e}")
-        raise e
-
-    final_response = ""
-    try:
-        for text in streamer_r3:
-            logging.info(text)
-            final_response += text
-            yield final_response  # Yield progressive updates
-    except Exception as e:
-        logging.error(f"Error during Round 3 generation: {e}")
-        raise e
-    thread_r3.join()
-    global_memory_manager.store(f"Final Synthesis Response: {final_response[:200]}...")
 
 
-# --- Explanation Function for Puns ---
-def handle_explanation_request(user_prompt: str) -> str:
     """
-    If the user asks for an explanation of the puns, this function retrieves
-    relevant stored memory items (which are expected to include pun examples) and
-    constructs a new prompt to generate a detailed explanation.
-
-    Args:
-        user_prompt (str): The user request (e.g. "explain the different puns you mentioned")
 
-    Returns:
-        str: The explanation generated by the model.
     """
-    # Retrieve memory items that contain "pun" (assuming previous outputs include puns)
-    retrieved = global_memory_manager.retrieve("pun", top_k=3)
-    if not retrieved:
-        explanation_prompt = "No previous puns found to explain. Please provide the pun examples."
-    else:
-        explanation_prompt = "Please explain the following coding puns in detail:\n\n"
         for item in retrieved:
             explanation_prompt += f"- {item}\n"
-        explanation_prompt += "\nProvide a detailed explanation for each pun."
-
     model, tokenizer = get_model_and_tokenizer()
-    input_ids = tokenizer.encode(explanation_prompt, return_tensors="pt").to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    kwargs = dict(
-        input_ids=input_ids,
-        streamer=streamer,
-        max_new_tokens=300,
-        temperature=0.7,
-        top_p=0.9,
-    )
-    try:
-        thread = Thread(target=model.generate, kwargs=kwargs)
-        with torch.no_grad():
-            thread.start()
-    except Exception as e:
-        logging.error(f"Error starting explanation thread: {e}")
-        raise e
-
-    explanation = ""
-    try:
-        for text in streamer:
-            explanation += text
-    except Exception as e:
-        logging.error(f"Error during explanation generation: {e}")
-        raise e
-    thread.join()
     return explanation
 
-
 # --- Helper to Format History ---
 def format_history(history: List) -> List[Dict[str, str]]:
     """
-    Convert history (which might be a list of [user, assistant] pairs or already formatted dictionaries)
-    into a list of OpenAI-style message dictionaries.
-
-    Args:
-        history (List): List of conversation items.
-
-    Returns:
-        List[Dict[str, str]]: A list of formatted message dictionaries.
     """
     messages = []
     for item in history:
-        # If item is a list or tuple, try to unpack it if it has exactly 2 elements.
         if isinstance(item, (list, tuple)) and len(item) == 2:
             user_msg, assistant_msg = item
             messages.append({"role": "user", "content": user_msg})
@@ -323,27 +391,14 @@ def format_history(history: List) -> List[Dict[str, str]]:
             messages.append(item)
     return messages
 
-
 # --- Gradio Chat Interface Function ---
 def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict) -> Generator[List[Dict[str, str]], None, None]:
     """
-    This function is called by Gradio's ChatInterface.
-    It uses the current saved generation parameters and prompt templates.
-    If the user request appears to ask for an explanation of puns,
-    it routes the request to the explanation function.
-
-    Args:
-        message (str): The user message.
-        history (List): The conversation history.
-        param_state (Dict): Generation parameters.
-        prompt_state (Dict): Prompt templates.
-
-    Yields:
-        Generator[List[Dict[str, str]]]: Updated history in OpenAI-style message dictionaries.
     """
-    # Check if the user is asking to explain puns.
-    if "explain" in message.lower() and "pun" in message.lower():
-        explanation = handle_explanation_request(message)
         history = history + [[message, explanation]]
         yield format_history(history)
         return
@@ -353,42 +408,38 @@ def gradio_interface(message: str, history: List, param_state: Dict, prompt_stat
         top_p = float(param_state.get("top_p", 0.9))
         max_new_tokens = int(param_state.get("max_new_tokens", 300))
         memory_top_k = int(param_state.get("memory_top_k", 2))
     except Exception as e:
         logging.error(f"Parameter conversion error: {e}")
-        temp, top_p, max_new_tokens, memory_top_k = 0.5, 0.9, 300, 2
 
-    prompt_brainstorm_text = prompt_state.get("prompt_brainstorm", default_prompt_brainstorm)
-    prompt_code_generation_text = prompt_state.get("prompt_code_generation", default_prompt_code_generation)
-    prompt_synthesis_text = prompt_state.get("prompt_synthesis", default_prompt_synthesis)
 
-    # Append the new user message with an empty assistant reply (as a two-item list)
     history = history + [[message, ""]]
-
-    # Call the multi-round agent as a generator (for streaming)
     for partial_response in swarm_agent_iterative(
         user_prompt=message,
         temp=temp,
         top_p=top_p,
         max_new_tokens=max_new_tokens,
         memory_top_k=memory_top_k,
-        prompt_brainstorm_text=prompt_brainstorm_text,
-        prompt_code_generation_text=prompt_code_generation_text,
-        prompt_synthesis_text=prompt_synthesis_text
     ):
-        # Update the last assistant message with the new partial response.
         history[-1][1] = partial_response
         yield format_history(history)
 
-
 # --- UI Settings & Styling ---
 ui_description = '''
 <div>
 <h1 style="text-align: center;">DeepSeek Agent Swarm Chat</h1>
 <p style="text-align: center;">
-Multi-round agent:
-<br>- Brainstorming
-<br>- Advanced reasoning & code generation
-<br>- Synthesis & refinement
 </p>
 </div>
 '''
@@ -418,27 +469,24 @@ h1 {
 }
 """
 
-
 # --- Gradio UI ---
 with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
     gr.Markdown(ui_description)
-
-    # Hidden States to hold parameters and prompt configuration
     param_state = gr.State({
         "temperature": 0.5,
         "top_p": 0.9,
         "max_new_tokens": 300,
         "memory_top_k": 2,
     })
     prompt_state = gr.State({
-        "prompt_brainstorm": default_prompt_brainstorm,
-        "prompt_code_generation": default_prompt_code_generation,
-        "prompt_synthesis": default_prompt_synthesis,
     })
 
-    # Create top-level Tabs
     with gr.Tabs():
-        # --- Chat Tab ---
        with gr.Tab("Chat"):
            chatbot = gr.Chatbot(height=450, placeholder=ui_placeholder, label="Agent Swarm Output", type="messages")
            gr.ChatInterface(
@@ -447,59 +495,128 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
                 additional_inputs=[param_state, prompt_state],
                 examples=[
                     ['How can we build a robust web service that scales efficiently under load?'],
-                    ['Explain how to design a fault-tolerant distributed system.'],
-                    ['Develop a streamlit app that visualizes real-time financial data.'],
-                    ['Create a pun-filled birthday message with a coding twist.'],
-                    ['Design a system that uses machine learning to optimize resource allocation.']
                 ],
                 cache_examples=False,
                 type="messages",
             )
-
-        # --- Parameters Tab ---
         with gr.Tab("Parameters"):
             gr.Markdown("### Generation Parameters")
             temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
             top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
             max_tokens_num = gr.Number(value=300, label="Max new tokens", precision=0)
             memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
             save_params_btn = gr.Button("Save Parameters")
             save_params_btn.click(
-                lambda t, p, m, k: {"temperature": t, "top_p": p, "max_new_tokens": m, "memory_top_k": k},
-                inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider],
                 outputs=param_state,
             )
-
-        # --- Prompt Config Tab ---
         with gr.Tab("Prompt Config"):
-            gr.Markdown("### Configure Prompt Templates")
-            prompt_brainstorm_box = gr.Textbox(
-                value=default_prompt_brainstorm,
-                label="Brainstorm Prompt",
-                lines=8,
-            )
-            prompt_code_generation_box = gr.Textbox(
-                value=default_prompt_code_generation,
-                label="Code Generation Prompt",
-                lines=8,
-            )
-            prompt_synthesis_box = gr.Textbox(
-                value=default_prompt_synthesis,
-                label="Synthesis Prompt",
-                lines=8,
-            )
             save_prompts_btn = gr.Button("Save Prompts")
             save_prompts_btn.click(
-                lambda b, c, s: {
-                    "prompt_brainstorm": b,
-                    "prompt_code_generation": c,
-                    "prompt_synthesis": s,
-                },
-                inputs=[prompt_brainstorm_box, prompt_code_generation_box, prompt_synthesis_box],
                 outputs=prompt_state,
             )
-
    gr.Markdown(ui_license)
 
 if __name__ == "__main__":
-    demo.launch()
 
 
 # --- Model & Quantization Settings ---
 MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
 models: Dict[str, AutoModelForCausalLM] = {}
 tokenizers: Dict[str, AutoTokenizer] = {}
 
 def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
     """
     Lazy-load the model and tokenizer if not already loaded.
     Returns:
         Tuple[model, tokenizer]: The loaded model and tokenizer.
     """
         raise e
     return models["7B"], tokenizers["7B"]
 
+# --- Default Prompt Templates for Multiple Presets ---
+default_prompts = {
+    "coding": {
+        "brainstorm": """**Coding Brainstorm (Round 1)**
+As a Senior Code Analyst, analyze the following problem and list key challenges and potential approaches.
 
 **User Request:**
 {user_prompt}
 
 **Guidelines:**
+1. Identify coding challenges.
+2. Suggest potential methods and approaches.
+3. Highlight any critical edge cases.
+""",
+        "round2": """**Advanced Reasoning & Code Generation (Round 2)**
+Based on your initial analysis:
+
+**Initial Analysis:**
+{brainstorm_response}
+
+**User Request:**
+{user_prompt}
+
+**Task:**
+1. Generate production-ready code with advanced reasoning.
+2. Include a pun-filled birthday message with a coding twist within your output.
+3. Comment the code clearly.
+""",
+        "synthesis": """**Synthesis & Final Refinement (Round 3)**
+Review the detailed code and reasoning below, and synthesize a final, refined response that:
+1. Combines the brainstorming insights and advanced code generation.
+2. Summarizes the solution succinctly.
+3. Provides any additional improvements.
+
+**Detailed Code & Reasoning:**
+{round2_response}
+""",
+        "rationale": """**Pun Generation and Rationale (Round 4)**
+Based on the final refined response below, generate a clear, stand-alone pun-filled birthday message with a coding twist, then explain in detail why that pun was chosen.
+
+Final Refined Response:
+{final_response}
+
+Your answer should:
+1. Clearly output the pun as a separate line.
+2. Explain the pun’s connection to birthdays and coding concepts (e.g., binary, syntax).
+3. Describe any creative insights behind the choice.
 """
+    },
+    "math": {
+        "brainstorm": """**Math Problem Brainstorm (Round 1)**
+As an expert mathematician, analyze the following problem and outline key concepts and strategies.
+
+**Problem:**
+{user_prompt}
 
+**Guidelines:**
+1. Identify the mathematical concepts involved.
+2. List potential strategies or methods.
+3. Note any assumptions or conditions.
+""",
+        "round2": """**Solution Strategy Development (Round 2)**
+Based on the initial analysis:
 
 **Initial Analysis:**
 {brainstorm_response}
 
+**Problem:**
 {user_prompt}
 
 **Task:**
+1. Develop a detailed strategy to solve the problem.
+2. Include potential methods and intermediate steps.
+""",
+        "synthesis": """**Solution Synthesis (Round 3)**
+Review the strategy and previous analysis below, and produce a refined, step-by-step solution that:
+1. Clearly explains the solution path.
+2. Highlights key steps and justifications.
+3. Summarizes the final answer.
+
+**Detailed Strategy:**
+{round2_response}
+""",
+        "rationale": """**Solution Rationale (Round 4)**
+Based on the final refined solution below, provide a detailed explanation of the key steps and mathematical insights.
+
+Final Refined Solution:
+{final_response}
+
+Your response should:
+1. Clearly explain why each step was taken.
+2. Detail any assumptions and mathematical principles used.
+3. Summarize the creative reasoning behind the solution.
 """
+    },
+    "writing": {
+        "brainstorm": """**Creative Brainstorm (Round 1)**
+As a seasoned writer, brainstorm creative ideas for the following writing prompt.
 
+**Writing Prompt:**
+{user_prompt}
 
+**Guidelines:**
+1. List key themes and creative directions.
+2. Suggest multiple approaches to the narrative.
+3. Highlight any unique stylistic ideas.
+""",
+        "round2": """**Outline Generation (Round 2)**
+Based on the brainstorming below:
+
+**Brainstormed Ideas:**
+{brainstorm_response}
+
+**Writing Prompt:**
+{user_prompt}
+
+**Task:**
+1. Generate a detailed outline for a creative piece.
+2. Organize the ideas into a coherent structure.
+3. Provide bullet points or sections for the narrative.
+""",
+        "synthesis": """**Draft Writing (Round 3)**
+Review the outline below and produce a refined draft of the creative piece that:
+1. Synthesizes the brainstorming insights and the outline.
+2. Provides a coherent and engaging narrative.
+3. Includes stylistic and thematic elements.
+
+**Outline:**
+{round2_response}
+""",
+        "rationale": """**Final Editing and Rationale (Round 4)**
+Based on the final draft below, refine the piece further and provide a detailed explanation of your creative choices.
+
+Final Draft:
+{final_response}
+
+Your answer should:
+1. Present the final refined text.
+2. Explain the narrative choices, stylistic decisions, and thematic connections.
+3. Detail any creative insights that influenced the final version.
 """
+    }
+}
 
+ # --- Domain Detection ---
195
+ def detect_domain(user_prompt: str) -> str:
196
+ """
197
+ Detect the domain based on keywords.
198
+ Args:
199
+ user_prompt (str): The user query.
200
+ Returns:
201
+ str: One of 'math', 'writing', or 'coding' (defaulting to coding).
202
+ """
203
+ prompt_lower = user_prompt.lower()
204
+ math_keywords = ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"]
205
+ writing_keywords = ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"]
206
+ coding_keywords = ["code", "program", "debug", "compile", "algorithm", "function"]
207
+
208
+ if any(kw in prompt_lower for kw in math_keywords):
209
+ logging.info("Domain detected as: math")
210
+ return "math"
211
+ elif any(kw in prompt_lower for kw in writing_keywords):
212
+ logging.info("Domain detected as: writing")
213
+ return "writing"
214
+ elif any(kw in prompt_lower for kw in coding_keywords):
215
+ logging.info("Domain detected as: coding")
216
+ return "coding"
217
+ else:
218
+ logging.info("No specific domain detected; defaulting to coding")
219
+ return "coding"
220
 
221
  # --- Memory Management ---
222
  class MemoryManager:
 
225
  self.shared_memory: List[str] = []
226
 
227
  def store(self, item: str) -> None:
228
+ """Store a memory item and log an excerpt."""
 
 
 
 
 
229
  self.shared_memory.append(item)
230
  logging.info(f"[Memory Stored]: {item[:50]}...")
231
 
232
  def retrieve(self, query: str, top_k: int = 3) -> List[str]:
233
+ """Retrieve recent memory items containing the query text."""
 
 
 
 
 
 
 
 
 
234
  query_lower = query.lower()
235
  relevant = [item for item in self.shared_memory if query_lower in item.lower()]
236
  if not relevant:
237
  logging.info("[Memory Retrieval]: No relevant memories found.")
238
  else:
239
  logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
240
+ return relevant[-top_k:]
241
 
 
242
  global_memory_manager = MemoryManager()
243
 
244
+ # --- Unified Generation Function ---
245
+ def generate_response(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float) -> str:
246
+ """Generate a response for a given prompt."""
247
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
248
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
249
+ kwargs = dict(
250
+ input_ids=input_ids,
251
+ streamer=streamer,
252
+ max_new_tokens=max_tokens,
253
+ temperature=temperature,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  top_p=top_p,
255
+ do_sample=True,
256
  )
257
+ thread = Thread(target=model.generate, kwargs=kwargs)
258
+ with torch.no_grad():
259
+ thread.start()
260
+ response = ""
261
  try:
262
+ for text in streamer:
263
+ response += text
 
264
  except Exception as e:
265
+ logging.error(f"Error during generation: {e}")
266
  raise e
267
+ thread.join()
268
+ return response
269
 
270
+ # --- Multi-Round Agent Class ---
271
+ class MultiRoundAgent:
272
+ """
273
+ Encapsulate the multi-round prompt chaining and response generation.
274
+ This class runs a 4-round pipeline based on the given preset.
275
+ """
276
+ def __init__(self, model, tokenizer, prompt_templates: Dict[str, str], memory_manager: MemoryManager):
277
+ self.model = model
278
+ self.tokenizer = tokenizer
279
+ self.prompt_templates = prompt_templates
280
+ self.memory_manager = memory_manager
281
+
282
+ def run_pipeline(self, user_prompt: str, params: Dict, show_raw: bool = False) -> Generator[str, None, None]:
283
+ # Round 1: Brainstorming / Analysis
284
+ logging.info("--- Round 1 ---")
285
+ prompt_r1 = self.prompt_templates["brainstorm"].format(user_prompt=user_prompt)
286
+ r1 = generate_response(self.model, self.tokenizer, prompt_r1, params.get("max_new_tokens"), params.get("temp"), params.get("top_p"))
287
+ self.memory_manager.store(f"Round 1 Response: {r1}")
288
+
289
+ # Round 2: Secondary Generation (strategy/outline/code)
290
+ logging.info("--- Round 2 ---")
291
+ prompt_r2 = self.prompt_templates["round2"].format(brainstorm_response=r1, user_prompt=user_prompt)
292
+ r2 = generate_response(self.model, self.tokenizer, prompt_r2, params.get("max_new_tokens") + 100, params.get("temp"), params.get("top_p"))
293
+ self.memory_manager.store(f"Round 2 Response: {r2}")
294
+
295
+ # Round 3: Synthesis & Refinement (streaming updates)
296
+ logging.info("--- Round 3 ---")
297
+ prompt_r3 = self.prompt_templates["synthesis"].format(round2_response=r2)
298
+ input_ids_r3 = self.tokenizer.encode(prompt_r3, return_tensors="pt").to(self.model.device)
299
+ streamer_r3 = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
300
+ kwargs_r3 = dict(
301
+ input_ids=input_ids_r3,
302
+ streamer=streamer_r3,
303
+ max_new_tokens=params.get("max_new_tokens") // 2,
304
+ temperature=params.get("temp"),
305
+ top_p=params.get("top_p")
306
+ )
307
+ thread_r3 = Thread(target=self.model.generate, kwargs=kwargs_r3)
308
  with torch.no_grad():
309
  thread_r3.start()
310
+ r3 = ""
311
+ try:
312
+ for text in streamer_r3:
313
+ r3 += text
314
+ yield r3 # Yield progressive updates from Round 3
315
+ except Exception as e:
316
+ logging.error(f"Error during Round 3 streaming: {e}")
317
+ raise e
318
+ thread_r3.join()
319
+ self.memory_manager.store(f"Final Synthesis Response: {r3}")
320
+
321
+ # Round 4: Rationale / Final Output
322
+ logging.info("--- Round 4 ---")
323
+ prompt_r4 = self.prompt_templates["rationale"].format(final_response=r3)
324
+ r4 = generate_response(self.model, self.tokenizer, prompt_r4, 300, params.get("temp"), params.get("top_p"))
325
+ self.memory_manager.store(f"Round 4 Response: {r4}")
326
+
327
+ # Construct final output based on the show_raw flag.
328
+ if show_raw:
329
+ final_output = (
330
+ f"{r4}\n\n[Raw Outputs]\n"
331
+ f"Round 1:\n{r1}\n\n"
332
+ f"Round 2:\n{r2}\n\n"
333
+ f"Round 3:\n{r3}\n\n"
334
+ f"Round 4:\n{r4}\n"
335
+ )
336
+ else:
337
+ final_output = r4
338
 
339
+ yield final_output
340
 
341
+ # --- Swarm Agent Iterative Function ---
342
+ @spaces.GPU(duration=180) # Adjust duration as needed
343
+ def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
344
+ prompt_templates: Dict[str, str], domain: str, show_raw: bool) -> Generator[str, None, None]:
345
  """
346
+ Wraps the multi-round agent functionality. Depending on the detected domain,
347
+ it runs the 4-round pipeline.
348
+ """
349
+ model, tokenizer = get_model_and_tokenizer()
350
+ agent = MultiRoundAgent(model, tokenizer, prompt_templates, global_memory_manager)
351
+ params = {"temp": temp, "top_p": top_p, "max_new_tokens": max_new_tokens}
352
+ return agent.run_pipeline(user_prompt, params, show_raw)
353
 
354
+ # --- Explanation Function for Additional Requests ---
355
+ def handle_explanation_request(user_prompt: str, history: List) -> str:
356
  """
357
+ Retrieve stored rationale and additional context from conversation history,
358
+ then generate an explanation.
359
+ """
360
+ retrieved = global_memory_manager.retrieve("Round 4 Response:", top_k=3)
361
+ explanation_prompt = "Below are previous final outputs and related context from our conversation:\n"
362
+ if retrieved:
363
  for item in retrieved:
364
  explanation_prompt += f"- {item}\n"
365
+ else:
366
+ explanation_prompt += "No stored final output found.\n"
367
+
368
+ explanation_prompt += "\nRecent related exchanges:\n"
369
+ for chat in history:
370
+ if ("explain" in chat[0].lower()) or (chat[1] and "explain" in chat[1].lower()):
371
+ explanation_prompt += f"User: {chat[0]}\nAssistant: {chat[1]}\n"
372
+
373
+ explanation_prompt += "\nBased on the above context, please provide a detailed explanation of the creative choices."
374
  model, tokenizer = get_model_and_tokenizer()
375
+ explanation = generate_response(model, tokenizer, explanation_prompt, 300, 0.7, 0.9)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  return explanation
377
 
 
378
  # --- Helper to Format History ---
379
  def format_history(history: List) -> List[Dict[str, str]]:
380
  """
381
+ Convert history (list of [user, assistant] pairs) into a list of message dictionaries.
 
 
 
 
 
 
 
382
  """
383
  messages = []
384
  for item in history:
 
385
  if isinstance(item, (list, tuple)) and len(item) == 2:
386
  user_msg, assistant_msg = item
387
  messages.append({"role": "user", "content": user_msg})
 
391
  messages.append(item)
392
  return messages
393
 
 
394
  # --- Gradio Chat Interface Function ---
395
  def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict) -> Generator[List[Dict[str, str]], None, None]:
396
  """
397
+ Called by Gradio's ChatInterface. Uses current generation parameters and preset prompt templates.
398
+ If the user asks for an explanation, routes the request accordingly.
 
 
 
 
 
 
 
 
 
 
 
399
  """
400
+ if "explain" in message.lower():
401
+ explanation = handle_explanation_request(message, history)
 
402
  history = history + [[message, explanation]]
403
  yield format_history(history)
404
  return
 
408
  top_p = float(param_state.get("top_p", 0.9))
409
  max_new_tokens = int(param_state.get("max_new_tokens", 300))
410
  memory_top_k = int(param_state.get("memory_top_k", 2))
411
+ show_raw = bool(param_state.get("show_raw_output", False))
412
  except Exception as e:
413
  logging.error(f"Parameter conversion error: {e}")
414
+ temp, top_p, max_new_tokens, memory_top_k, show_raw = 0.5, 0.9, 300, 2, False
415
 
416
+ domain = detect_domain(message)
417
+ # Get the prompt templates for the detected domain; default to coding if not set.
418
+ prompt_templates = prompt_state.get(domain, default_prompts.get(domain, default_prompts["coding"]))
419
 
 
420
  history = history + [[message, ""]]
 
 
421
  for partial_response in swarm_agent_iterative(
422
  user_prompt=message,
423
  temp=temp,
424
  top_p=top_p,
425
  max_new_tokens=max_new_tokens,
426
  memory_top_k=memory_top_k,
427
+ prompt_templates=prompt_templates,
428
+ domain=domain,
429
+ show_raw=show_raw
430
  ):
 
431
  history[-1][1] = partial_response
432
  yield format_history(history)
433
 
 
434
  # --- UI Settings & Styling ---
435
  ui_description = '''
436
  <div>
437
  <h1 style="text-align: center;">DeepSeek Agent Swarm Chat</h1>
438
  <p style="text-align: center;">
439
+ Multi-round agent with 4-round prompt chaining for three presets:
440
+ <br>- Coding
441
+ <br>- Math
442
+ <br>- Writing
443
  </p>
444
  </div>
445
  '''
 
469
  }
470
  """
471
 
 
472
  # --- Gradio UI ---
473
  with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
474
  gr.Markdown(ui_description)
475
+ # Hidden states for parameters and prompt configurations.
 
476
  param_state = gr.State({
477
  "temperature": 0.5,
478
  "top_p": 0.9,
479
  "max_new_tokens": 300,
480
  "memory_top_k": 2,
481
+ "show_raw_output": False, # New parameter for raw output
482
  })
483
  prompt_state = gr.State({
484
+ "coding": default_prompts["coding"],
485
+ "math": default_prompts["math"],
486
+ "writing": default_prompts["writing"],
487
  })
488
 
 
489
  with gr.Tabs():
 
490
  with gr.Tab("Chat"):
491
  chatbot = gr.Chatbot(height=450, placeholder=ui_placeholder, label="Agent Swarm Output", type="messages")
492
  gr.ChatInterface(
 
495
  additional_inputs=[param_state, prompt_state],
496
  examples=[
497
  ['How can we build a robust web service that scales efficiently under load?'],
498
+ ['Solve the integral of x^2 from 0 to 1.'],
499
+ ['Write a short story about a mysterious writer in a busy city.'],
500
+ ['Create a pun-filled birthday message with a coding twist.']
 
501
  ],
502
  cache_examples=False,
503
  type="messages",
504
  )
 
 
505
  with gr.Tab("Parameters"):
506
  gr.Markdown("### Generation Parameters")
507
  temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
508
  top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
509
  max_tokens_num = gr.Number(value=300, label="Max new tokens", precision=0)
510
  memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
511
+ show_raw_checkbox = gr.Checkbox(value=False, label="Show Raw Output") # New checkbox for raw output
512
  save_params_btn = gr.Button("Save Parameters")
513
  save_params_btn.click(
514
+ lambda t, p, m, k, s: {
515
+ "temperature": t,
516
+ "top_p": p,
517
+ "max_new_tokens": m,
518
+ "memory_top_k": k,
519
+ "show_raw_output": s
520
+ },
521
+ inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider, show_raw_checkbox],
522
  outputs=param_state,
523
  )
 
 
524
  with gr.Tab("Prompt Config"):
525
+ gr.Markdown("### Configure Prompt Templates for Each Preset")
526
+ with gr.Tabs():
527
+ with gr.Tab("Coding"):
528
+ prompt_brainstorm_box_code = gr.Textbox(
529
+ value=default_prompts["coding"]["brainstorm"],
530
+ label="Brainstorm Prompt (Coding)",
531
+ lines=8,
532
+ )
533
+ prompt_round2_box_code = gr.Textbox(
534
+ value=default_prompts["coding"]["round2"],
535
+ label="Round 2 Prompt (Coding)",
536
+ lines=8,
537
+ )
538
+ prompt_synthesis_box_code = gr.Textbox(
539
+ value=default_prompts["coding"]["synthesis"],
540
+ label="Synthesis Prompt (Coding)",
541
+ lines=8,
542
+ )
543
+ prompt_rationale_box_code = gr.Textbox(
544
+ value=default_prompts["coding"]["rationale"],
545
+ label="Rationale Prompt (Coding)",
546
+ lines=8,
547
+ )
548
+ with gr.Tab("Math"):
549
+ prompt_brainstorm_box_math = gr.Textbox(
550
+ value=default_prompts["math"]["brainstorm"],
551
+ label="Brainstorm Prompt (Math)",
552
+ lines=8,
553
+ )
554
+ prompt_round2_box_math = gr.Textbox(
555
+ value=default_prompts["math"]["round2"],
556
+ label="Round 2 Prompt (Math)",
557
+ lines=8,
558
+ )
559
+ prompt_synthesis_box_math = gr.Textbox(
560
+ value=default_prompts["math"]["synthesis"],
561
+ label="Synthesis Prompt (Math)",
562
+ lines=8,
563
+ )
564
+ prompt_rationale_box_math = gr.Textbox(
565
+ value=default_prompts["math"]["rationale"],
566
+ label="Rationale Prompt (Math)",
567
+ lines=8,
568
+ )
569
+ with gr.Tab("Writing"):
570
+ prompt_brainstorm_box_writing = gr.Textbox(
571
+ value=default_prompts["writing"]["brainstorm"],
572
+ label="Brainstorm Prompt (Writing)",
573
+ lines=8,
574
+ )
575
+ prompt_round2_box_writing = gr.Textbox(
576
+ value=default_prompts["writing"]["round2"],
577
+ label="Round 2 Prompt (Writing)",
578
+ lines=8,
579
+ )
580
+ prompt_synthesis_box_writing = gr.Textbox(
581
+ value=default_prompts["writing"]["synthesis"],
582
+ label="Synthesis Prompt (Writing)",
583
+ lines=8,
584
+ )
585
+ prompt_rationale_box_writing = gr.Textbox(
586
+ value=default_prompts["writing"]["rationale"],
587
+ label="Rationale Prompt (Writing)",
588
+ lines=8,
589
+ )
590
  save_prompts_btn = gr.Button("Save Prompts")
591
+ def save_prompts(code_brain, code_r2, code_syn, code_rat, math_brain, math_r2, math_syn, math_rat, writing_brain, writing_r2, writing_syn, writing_rat):
592
+ return {
593
+ "coding": {
594
+ "brainstorm": code_brain,
595
+ "round2": code_r2,
596
+ "synthesis": code_syn,
597
+ "rationale": code_rat,
598
+ },
599
+ "math": {
600
+ "brainstorm": math_brain,
601
+ "round2": math_r2,
602
+ "synthesis": math_syn,
603
+ "rationale": math_rat,
604
+ },
605
+ "writing": {
606
+ "brainstorm": writing_brain,
607
+ "round2": writing_r2,
608
+ "synthesis": writing_syn,
609
+ "rationale": writing_rat,
610
+ }
611
+ }
612
  save_prompts_btn.click(
613
+ save_prompts,
614
+ inputs=[prompt_brainstorm_box_code, prompt_round2_box_code, prompt_synthesis_box_code, prompt_rationale_box_code,
615
+ prompt_brainstorm_box_math, prompt_round2_box_math, prompt_synthesis_box_math, prompt_rationale_box_math,
616
+ prompt_brainstorm_box_writing, prompt_round2_box_writing, prompt_synthesis_box_writing, prompt_rationale_box_writing],
 
 
617
  outputs=prompt_state,
618
  )
 
619
  gr.Markdown(ui_license)
620
 
621
  if __name__ == "__main__":
622
+ demo.launch()
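
For orientation, the snippet below is a minimal, standalone sketch (not taken from the commit) of how the preset routing introduced here behaves: detect_domain picks a preset key from keyword lists, and that key selects the round templates in default_prompts. The keyword lists and template titles mirror the diff; the one-entry template dictionaries are trimmed placeholders for illustration only.

# Hypothetical, trimmed illustration of this commit's domain routing (not the full app.py).
from typing import Dict

default_prompts: Dict[str, Dict[str, str]] = {
    # One-line stand-ins for the real four-round templates in the diff.
    "coding": {"brainstorm": "**Coding Brainstorm (Round 1)**\n{user_prompt}"},
    "math": {"brainstorm": "**Math Problem Brainstorm (Round 1)**\n{user_prompt}"},
    "writing": {"brainstorm": "**Creative Brainstorm (Round 1)**\n{user_prompt}"},
}

def detect_domain(user_prompt: str) -> str:
    # Keyword lists copied from the diff; math and writing are checked before coding,
    # and anything unmatched falls back to the coding preset.
    prompt_lower = user_prompt.lower()
    math_keywords = ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"]
    writing_keywords = ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"]
    coding_keywords = ["code", "program", "debug", "compile", "algorithm", "function"]
    if any(kw in prompt_lower for kw in math_keywords):
        return "math"
    if any(kw in prompt_lower for kw in writing_keywords):
        return "writing"
    if any(kw in prompt_lower for kw in coding_keywords):
        return "coding"
    return "coding"

if __name__ == "__main__":
    for query in [
        "Solve the integral of x^2 from 0 to 1.",
        "Write a short story about a mysterious writer in a busy city.",
        "Create a pun-filled birthday message with a coding twist.",
    ]:
        domain = detect_domain(query)
        # The Round 1 prompt the pipeline would start from for this query.
        round1_prompt = default_prompts[domain]["brainstorm"].format(user_prompt=query)
        print(f"{domain}: {round1_prompt.splitlines()[0]}")

Note that the third example contains none of the keywords, so it falls back to the coding preset, which is consistent with the pun-themed "rationale" round kept under "coding" in the diff.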