Update app.py
app.py
CHANGED
Old version (left side of the diff; lines removed by this commit are marked with "-"; removed lines whose content was not captured by the page are not shown):

@@ -4,13 +4,18 @@ import spaces # Import the spaces library
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
 import torch
 from threading import Thread

 # --- Model & Quantization Settings ---
 MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"

 # Dictionaries to store the loaded model and tokenizer
-models = {}
-tokenizers = {}

 bnb_config_4bit = BitsAndBytesConfig(
     load_in_4bit=True,

@@ -18,23 +23,34 @@ bnb_config_4bit = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16, # Or torch.float16 if needed
 )

-def get_model_and_tokenizer():
-    """
     if "7B" not in models:
     return models["7B"], tokenizers["7B"]

 # --- Default Prompt Templates ---
 default_prompt_brainstorm = """**Brainstorming Task (Round 1)**
 As a Senior Code Analyst, provide an initial analysis of the problem below.

@@ -74,50 +90,68 @@ Review the detailed code generation and reasoning below, and produce a final, re
 {code_response}
 """

-# --- Shared Memory for Rounds ---
-shared_memory = []

-def retrieve_from_memory(query, top_k=2):
-    """
-    Retrieve memory items that contain the query text (case-insensitive).
-    Returns up to top_k items.
-    """
-    relevant_memories = []
-    query_lower = query.lower()
-    for memory_item in shared_memory:
-        if query_lower in memory_item.lower():
-            relevant_memories.append(memory_item)
-    if not relevant_memories:
-        print("\n[Memory Retrieval]: No relevant memories found.")
-        return []
-    print(f"\n[Memory Retrieval]: Found {len(relevant_memories)} relevant memories.")
-    return relevant_memories[:top_k]

 # --- Multi-Round Swarm Agent Function ---
 @spaces.GPU(duration=180) # Adjust duration as needed
-def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k,
-                          prompt_brainstorm_text, prompt_code_generation_text, prompt_synthesis_text
     """
     A three-round iterative process that uses the provided prompt templates:
       - Round 1: Brainstorming.
       - Round 2: Advanced reasoning & code generation.
       - Round 3: Synthesis & refinement.
     This generator yields the response from the final round as it is produced.
-    """
-    global shared_memory
-    shared_memory = [] # Clear shared memory for each new request

     model, tokenizer = get_model_and_tokenizer()

     # ----- Round 1: Brainstorming -----
-    input_ids_r1 = tokenizer.encode(
     streamer_r1 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     kwargs_r1 = dict(
         input_ids=input_ids_r1,

@@ -127,22 +161,32 @@ def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k
         temperature=temp,
         top_p=top_p,
     )

     brainstorm_response = ""

     # ----- Round 2: Code Generation -----
         brainstorm_response=brainstorm_response,
         user_prompt=user_prompt
     )
-    input_ids_r2 = tokenizer.encode(
     streamer_r2 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     kwargs_r2 = dict(
         input_ids=input_ids_r2,

@@ -151,19 +195,29 @@ def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k
         temperature=temp,
         top_p=top_p,
     )

     code_response = ""

     # ----- Round 3: Synthesis & Refinement -----
-    input_ids_r3 = tokenizer.encode(
     streamer_r3 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     kwargs_r3 = dict(
         input_ids=input_ids_r3,

@@ -172,58 +226,137 @@ def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k
         temperature=temp,
         top_p=top_p,
     )

     final_response = ""

-    store_in_memory(f"Final Synthesis Response: {final_response[:200]}...")

 # --- Helper to Format History ---
-def format_history(history):
     """
     Convert history (which might be a list of [user, assistant] pairs or already formatted dictionaries)
     into a list of OpenAI-style message dictionaries.
     """
     messages = []
     for item in history:
         # If item is a list or tuple, try to unpack it if it has exactly 2 elements.
-        if isinstance(item, (list, tuple)):
-            messages.append({"role": "assistant", "content": assistant_msg})
-        else:
-            # If it doesn't have exactly two items, skip it.
-            continue
         elif isinstance(item, dict):
-            # Already formatted message dictionary.
             messages.append(item)
-        else:
-            continue
     return messages

 # --- Gradio Chat Interface Function ---
-def gradio_interface(message, history, param_state, prompt_state):
     """
     This function is called by Gradio's ChatInterface.
     It uses the current saved generation parameters and prompt templates.
     """
-    #
     try:
         temp = float(param_state.get("temperature", 0.5))
         top_p = float(param_state.get("top_p", 0.9))
         max_new_tokens = int(param_state.get("max_new_tokens", 300))
         memory_top_k = int(param_state.get("memory_top_k", 2))
-    except Exception:
         temp, top_p, max_new_tokens, memory_top_k = 0.5, 0.9, 300, 2

-    # Unpack prompt state (with fallback defaults)
     prompt_brainstorm_text = prompt_state.get("prompt_brainstorm", default_prompt_brainstorm)
     prompt_code_generation_text = prompt_state.get("prompt_code_generation", default_prompt_code_generation)
     prompt_synthesis_text = prompt_state.get("prompt_synthesis", default_prompt_synthesis)

@@ -244,9 +377,9 @@ def gradio_interface(message, history, param_state, prompt_state):
     ):
         # Update the last assistant message with the new partial response.
         history[-1][1] = partial_response
-        # Yield the history formatted as OpenAI-style messages.
         yield format_history(history)

 # --- UI Settings & Styling ---
 ui_description = '''
 <div>

@@ -285,10 +418,11 @@ h1 {
 }
 """

 # --- Gradio UI ---
 with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
     gr.Markdown(ui_description)
     # Hidden States to hold parameters and prompt configuration
     param_state = gr.State({
         "temperature": 0.5,

@@ -301,14 +435,12 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
         "prompt_code_generation": default_prompt_code_generation,
         "prompt_synthesis": default_prompt_synthesis,
     })
     # Create top-level Tabs
     with gr.Tabs():
         # --- Chat Tab ---
         with gr.Tab("Chat"):
-            # Set type="messages" for OpenAI-style message dictionaries
             chatbot = gr.Chatbot(height=450, placeholder=ui_placeholder, label="Agent Swarm Output", type="messages")
-            # Use ChatInterface and pass the hidden states as additional inputs.
             gr.ChatInterface(
                 fn=gradio_interface,
                 chatbot=chatbot,

@@ -323,7 +455,7 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
                 cache_examples=False,
                 type="messages",
             )
         # --- Parameters Tab ---
         with gr.Tab("Parameters"):
             gr.Markdown("### Generation Parameters")

@@ -332,13 +464,12 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
             max_tokens_num = gr.Number(value=300, label="Max new tokens", precision=0)
             memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
             save_params_btn = gr.Button("Save Parameters")
-            # When the user clicks Save, update the param_state
             save_params_btn.click(
                 lambda t, p, m, k: {"temperature": t, "top_p": p, "max_new_tokens": m, "memory_top_k": k},
                 inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider],
                 outputs=param_state,
             )
         # --- Prompt Config Tab ---
         with gr.Tab("Prompt Config"):
             gr.Markdown("### Configure Prompt Templates")

@@ -358,7 +489,6 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
                 lines=8,
             )
             save_prompts_btn = gr.Button("Save Prompts")
-            # When clicked, update the prompt_state with new values
             save_prompts_btn.click(
                 lambda b, c, s: {
                     "prompt_brainstorm": b,

@@ -368,8 +498,8 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
                 inputs=[prompt_brainstorm_box, prompt_code_generation_box, prompt_synthesis_box],
                 outputs=prompt_state,
             )
     gr.Markdown(ui_license)

 if __name__ == "__main__":
-    demo.launch()
New version (right side of the diff; lines added by this commit are marked with "+"; unchanged regions between hunks are collapsed):

@@ -4,13 +4,18 @@ import spaces # Import the spaces library
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
 import torch
 from threading import Thread
+import logging
+from typing import Tuple, List, Dict, Generator
+
+# --- Logging Configuration ---
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

 # --- Model & Quantization Settings ---
 MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"

 # Dictionaries to store the loaded model and tokenizer
+models: Dict[str, AutoModelForCausalLM] = {}
+tokenizers: Dict[str, AutoTokenizer] = {}

 bnb_config_4bit = BitsAndBytesConfig(
     load_in_4bit=True,

@@ -18,23 +23,34 @@ bnb_config_4bit = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16, # Or torch.float16 if needed
 )

+def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
+    """
+    Lazy-load the model and tokenizer if not already loaded.
+
+    Returns:
+        Tuple[model, tokenizer]: The loaded model and tokenizer.
+    """
     if "7B" not in models:
+        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+            model = AutoModelForCausalLM.from_pretrained(
+                MODEL_ID,
+                quantization_config=bnb_config_4bit,
+                torch_dtype=torch.bfloat16, # Or torch.float16 if needed
+                device_map='auto',
+                trust_remote_code=True,
+            )
+            model.eval() # Set the model to evaluation mode
+            models["7B"] = model
+            tokenizers["7B"] = tokenizer
+            logging.info("Loaded 7B model on demand.")
+        except Exception as e:
+            logging.error(f"Failed to load model and tokenizer: {e}")
+            raise e
     return models["7B"], tokenizers["7B"]

+
 # --- Default Prompt Templates ---
 default_prompt_brainstorm = """**Brainstorming Task (Round 1)**
 As a Senior Code Analyst, provide an initial analysis of the problem below.

@@ -74,50 +90,68 @@ Review the detailed code generation and reasoning below, and produce a final, re
 {code_response}
 """


+# --- Memory Management ---
+class MemoryManager:
+    """Encapsulate shared memory for storing and retrieving conversation items."""
+    def __init__(self) -> None:
+        self.shared_memory: List[str] = []
+
+    def store(self, item: str) -> None:
+        """
+        Store a memory item and log an excerpt.
+
+        Args:
+            item (str): The memory content to store.
+        """
+        self.shared_memory.append(item)
+        logging.info(f"[Memory Stored]: {item[:50]}...")
+
+    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
+        """
+        Retrieve memory items that contain the query text (case-insensitive).
+
+        Args:
+            query (str): The text query to search for.
+            top_k (int): Maximum number of memory items to return.
+
+        Returns:
+            List[str]: A list of up to top_k memory items.
+        """
+        query_lower = query.lower()
+        relevant = [item for item in self.shared_memory if query_lower in item.lower()]
+        if not relevant:
+            logging.info("[Memory Retrieval]: No relevant memories found.")
+        else:
+            logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
+        return relevant[:top_k]
+
+# Create a global memory manager instance for RAG purposes.
+global_memory_manager = MemoryManager()


 # --- Multi-Round Swarm Agent Function ---
 @spaces.GPU(duration=180) # Adjust duration as needed
+def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
+                          prompt_brainstorm_text: str, prompt_code_generation_text: str, prompt_synthesis_text: str
+                          ) -> Generator[str, None, None]:
     """
     A three-round iterative process that uses the provided prompt templates:
       - Round 1: Brainstorming.
      - Round 2: Advanced reasoning & code generation.
      - Round 3: Synthesis & refinement.
+
    This generator yields the response from the final round as it is produced.

+    Yields:
+        str: Progressive updates of the final response.
+    """
     model, tokenizer = get_model_and_tokenizer()

     # ----- Round 1: Brainstorming -----
+    logging.info("--- Round 1: Brainstorming ---")
+    prompt_r1 = prompt_brainstorm_text.format(user_prompt=user_prompt)
+    input_ids_r1 = tokenizer.encode(prompt_r1, return_tensors="pt").to(model.device)
     streamer_r1 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     kwargs_r1 = dict(
         input_ids=input_ids_r1,

@@ -127,22 +161,32 @@ def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k
         temperature=temp,
         top_p=top_p,
     )
+    try:
+        thread_r1 = Thread(target=model.generate, kwargs=kwargs_r1)
+        with torch.no_grad():
+            thread_r1.start()
+    except Exception as e:
+        logging.error(f"Error starting Round 1 thread: {e}")
+        raise e

     brainstorm_response = ""
+    try:
+        for text in streamer_r1:
+            logging.info(text)
+            brainstorm_response += text
+    except Exception as e:
+        logging.error(f"Error during Round 1 generation: {e}")
+        raise e
+    thread_r1.join()
+    global_memory_manager.store(f"Brainstorm Response: {brainstorm_response[:200]}...")

     # ----- Round 2: Code Generation -----
+    logging.info("--- Round 2: Code Generation ---")
+    prompt_r2 = prompt_code_generation_text.format(
         brainstorm_response=brainstorm_response,
         user_prompt=user_prompt
     )
+    input_ids_r2 = tokenizer.encode(prompt_r2, return_tensors="pt").to(model.device)
     streamer_r2 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     kwargs_r2 = dict(
         input_ids=input_ids_r2,

@@ -151,19 +195,29 @@ def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k
         temperature=temp,
         top_p=top_p,
     )
+    try:
+        thread_r2 = Thread(target=model.generate, kwargs=kwargs_r2)
+        with torch.no_grad():
+            thread_r2.start()
+    except Exception as e:
+        logging.error(f"Error starting Round 2 thread: {e}")
+        raise e

     code_response = ""
+    try:
+        for text in streamer_r2:
+            logging.info(text)
+            code_response += text
+    except Exception as e:
+        logging.error(f"Error during Round 2 generation: {e}")
+        raise e
+    thread_r2.join()
+    global_memory_manager.store(f"Code Generation Response: {code_response[:200]}...")

     # ----- Round 3: Synthesis & Refinement -----
+    logging.info("--- Round 3: Synthesis & Refinement ---")
+    prompt_r3 = prompt_synthesis_text.format(code_response=code_response)
+    input_ids_r3 = tokenizer.encode(prompt_r3, return_tensors="pt").to(model.device)
     streamer_r3 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     kwargs_r3 = dict(
         input_ids=input_ids_r3,

@@ -172,58 +226,137 @@ def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k
         temperature=temp,
         top_p=top_p,
     )
+    try:
+        thread_r3 = Thread(target=model.generate, kwargs=kwargs_r3)
+        with torch.no_grad():
+            thread_r3.start()
+    except Exception as e:
+        logging.error(f"Error starting Round 3 thread: {e}")
+        raise e

     final_response = ""
+    try:
+        for text in streamer_r3:
+            logging.info(text)
+            final_response += text
+            yield final_response # Yield progressive updates
+    except Exception as e:
+        logging.error(f"Error during Round 3 generation: {e}")
+        raise e
+    thread_r3.join()
+    global_memory_manager.store(f"Final Synthesis Response: {final_response[:200]}...")
+
+
+# --- Explanation Function for Puns ---
+def handle_explanation_request(user_prompt: str) -> str:
+    """
+    If the user asks for an explanation of the puns, this function retrieves
+    relevant stored memory items (which are expected to include pun examples) and
+    constructs a new prompt to generate a detailed explanation.
+
+    Args:
+        user_prompt (str): The user request (e.g. "explain the different puns you mentioned")
+
+    Returns:
+        str: The explanation generated by the model.
+    """
+    # Retrieve memory items that contain "pun" (assuming previous outputs include puns)
+    retrieved = global_memory_manager.retrieve("pun", top_k=3)
+    if not retrieved:
+        explanation_prompt = "No previous puns found to explain. Please provide the pun examples."
+    else:
+        explanation_prompt = "Please explain the following coding puns in detail:\n\n"
+        for item in retrieved:
+            explanation_prompt += f"- {item}\n"
+        explanation_prompt += "\nProvide a detailed explanation for each pun."
+
+    model, tokenizer = get_model_and_tokenizer()
+    input_ids = tokenizer.encode(explanation_prompt, return_tensors="pt").to(model.device)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        max_new_tokens=300,
+        temperature=0.7,
+        top_p=0.9,
+    )
+    try:
+        thread = Thread(target=model.generate, kwargs=kwargs)
+        with torch.no_grad():
+            thread.start()
+    except Exception as e:
+        logging.error(f"Error starting explanation thread: {e}")
+        raise e
+
+    explanation = ""
+    try:
+        for text in streamer:
+            explanation += text
+    except Exception as e:
+        logging.error(f"Error during explanation generation: {e}")
+        raise e
+    thread.join()
+    return explanation


 # --- Helper to Format History ---
+def format_history(history: List) -> List[Dict[str, str]]:
     """
     Convert history (which might be a list of [user, assistant] pairs or already formatted dictionaries)
     into a list of OpenAI-style message dictionaries.
+
+    Args:
+        history (List): List of conversation items.
+
+    Returns:
+        List[Dict[str, str]]: A list of formatted message dictionaries.
     """
     messages = []
     for item in history:
         # If item is a list or tuple, try to unpack it if it has exactly 2 elements.
+        if isinstance(item, (list, tuple)) and len(item) == 2:
+            user_msg, assistant_msg = item
+            messages.append({"role": "user", "content": user_msg})
+            if assistant_msg:
+                messages.append({"role": "assistant", "content": assistant_msg})
         elif isinstance(item, dict):
             messages.append(item)
     return messages

+
 # --- Gradio Chat Interface Function ---
+def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict) -> Generator[List[Dict[str, str]], None, None]:
     """
     This function is called by Gradio's ChatInterface.
     It uses the current saved generation parameters and prompt templates.
+    If the user request appears to ask for an explanation of puns,
+    it routes the request to the explanation function.
+
+    Args:
+        message (str): The user message.
+        history (List): The conversation history.
+        param_state (Dict): Generation parameters.
+        prompt_state (Dict): Prompt templates.
+
+    Yields:
+        Generator[List[Dict[str, str]]]: Updated history in OpenAI-style message dictionaries.
     """
+    # Check if the user is asking to explain puns.
+    if "explain" in message.lower() and "pun" in message.lower():
+        explanation = handle_explanation_request(message)
+        history = history + [[message, explanation]]
+        yield format_history(history)
+        return
+
     try:
         temp = float(param_state.get("temperature", 0.5))
         top_p = float(param_state.get("top_p", 0.9))
         max_new_tokens = int(param_state.get("max_new_tokens", 300))
         memory_top_k = int(param_state.get("memory_top_k", 2))
+    except Exception as e:
+        logging.error(f"Parameter conversion error: {e}")
        temp, top_p, max_new_tokens, memory_top_k = 0.5, 0.9, 300, 2

     prompt_brainstorm_text = prompt_state.get("prompt_brainstorm", default_prompt_brainstorm)
     prompt_code_generation_text = prompt_state.get("prompt_code_generation", default_prompt_code_generation)
     prompt_synthesis_text = prompt_state.get("prompt_synthesis", default_prompt_synthesis)

@@ -244,9 +377,9 @@ def gradio_interface(message, history, param_state, prompt_state):
     ):
         # Update the last assistant message with the new partial response.
         history[-1][1] = partial_response
         yield format_history(history)

+
 # --- UI Settings & Styling ---
 ui_description = '''
 <div>

@@ -285,10 +418,11 @@ h1 {
 }
 """

+
 # --- Gradio UI ---
 with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
     gr.Markdown(ui_description)
+
     # Hidden States to hold parameters and prompt configuration
     param_state = gr.State({
         "temperature": 0.5,

@@ -301,14 +435,12 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
         "prompt_code_generation": default_prompt_code_generation,
         "prompt_synthesis": default_prompt_synthesis,
     })
+
     # Create top-level Tabs
     with gr.Tabs():
         # --- Chat Tab ---
         with gr.Tab("Chat"):
             chatbot = gr.Chatbot(height=450, placeholder=ui_placeholder, label="Agent Swarm Output", type="messages")
             gr.ChatInterface(
                 fn=gradio_interface,
                 chatbot=chatbot,

@@ -323,7 +455,7 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
                 cache_examples=False,
                 type="messages",
             )
+
         # --- Parameters Tab ---
         with gr.Tab("Parameters"):
             gr.Markdown("### Generation Parameters")

@@ -332,13 +464,12 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
             max_tokens_num = gr.Number(value=300, label="Max new tokens", precision=0)
             memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
             save_params_btn = gr.Button("Save Parameters")
             save_params_btn.click(
                 lambda t, p, m, k: {"temperature": t, "top_p": p, "max_new_tokens": m, "memory_top_k": k},
                 inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider],
                 outputs=param_state,
             )
+
         # --- Prompt Config Tab ---
         with gr.Tab("Prompt Config"):
             gr.Markdown("### Configure Prompt Templates")

@@ -358,7 +489,6 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
                 lines=8,
             )
             save_prompts_btn = gr.Button("Save Prompts")
             save_prompts_btn.click(
                 lambda b, c, s: {
                     "prompt_brainstorm": b,

@@ -368,8 +498,8 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
                 inputs=[prompt_brainstorm_box, prompt_code_generation_box, prompt_synthesis_box],
                 outputs=prompt_state,
             )
+
     gr.Markdown(ui_license)

 if __name__ == "__main__":
+    demo.launch()