import os
import gradio as gr
import torch
from TTS.api import TTS
import spaces  # Hugging Face Spaces helper; provides the `spaces.GPU` decorator used below
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from threading import Thread
import logging
import tempfile
from typing import Tuple, List, Dict, Generator
import time

# NEW: Import whisper for speech-to-text.
import whisper

# ===========================
# Global Environment Settings
# ===========================
os.environ["COQUI_TOS_AGREED"] = "1"

# Global device override (may be updated from the UI later). The models below are
# loaded once at import time; changing this value afterwards does not move them.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Whisper model (this may take a moment at startup)
whisper_model = whisper.load_model("base")

# Saved voice clones: display name -> path of the reference audio file.
voice_bank: Dict[str, str] = {}

# ---------------------------
# Simple Response Cache
# ---------------------------
# Maps a prompt + generation-settings key to the generated text.
response_cache: Dict[str, str] = {}

# ===========================
# Voice Cloning Setup
# ===========================
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

# Private directory for synthesized audio. Every request gets its own file so that
# concurrent requests (e.g. "Clone Voice" and "Test TTS") cannot overwrite each other.
_TTS_OUTPUT_DIR = tempfile.mkdtemp(prefix="voice_clone_")


@spaces.GPU(enable_queue=True)
def clone(text, audio, language: str = "en"):
    """
    Synthesize ``text`` in the voice of the reference recording ``audio``.

    Args:
        text: Text to speak.
        audio: File path of the reference speaker audio.
        language: Language code passed to XTTS; defaults to English.

    Returns:
        Path to a newly written .wav file, or None if synthesis failed
        (the error is logged).
    """
    out_path = None
    try:
        fd, out_path = tempfile.mkstemp(suffix=".wav", dir=_TTS_OUTPUT_DIR)
        os.close(fd)  # only a unique name is needed; TTS writes the file itself
        tts.tts_to_file(text=text, speaker_wav=audio, language=language, file_path=out_path)
        return out_path
    except Exception as e:
        logging.error(f"TTS cloning failed: {e}")
        # Best-effort cleanup so failed requests do not leave empty/partial files.
        if out_path is not None:
            try:
                os.remove(out_path)
            except OSError:
                pass
        return None


def save_voice(voice_name: str, voice_audio: str) -> None:
    """
    Save a cloned voice under the given name.

    Surrounding whitespace is stripped from the name. Nothing is stored when the
    name is blank or no audio path is given; an existing entry with the same name
    is overwritten.
    """
    name = (voice_name or "").strip()
    if name and voice_audio:
        voice_bank[name] = voice_audio


def get_voice_options() -> List[str]:
    """
    Returns a list of saved voice names, in the order they were first saved.
    """
    return list(voice_bank.keys())


def refresh_voice_list() -> dict:
    """
    Returns a Gradio update with the latest voice list, selecting the first saved
    voice (or "" when no voice has been saved yet).
    """
    options = get_voice_options()
    new_val = options[0] if options else ""
    return gr.update(choices=options, value=new_val)


# ===========================
# Deep Agent Chat Setup
# ===========================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
models: Dict[str, AutoModelForCausalLM] = {}
tokenizers: Dict[str, AutoTokenizer] = {}

bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)


def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Return the chat model and its tokenizer, loading them on first use.

    Both are cached in the module-level ``models`` / ``tokenizers`` dicts under
    the key "7B", so later calls return the already-loaded objects.

    Raises:
        Exception: whatever ``from_pretrained`` raised, after it has been logged.
    """
    # Warm-up: if the model isn't loaded, load it now.
    if "7B" not in models:
        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                quantization_config=bnb_config_4bit,
                torch_dtype=torch.bfloat16,
                device_map='auto',
                trust_remote_code=True,
            )
            model.eval()
            models["7B"] = model
            tokenizers["7B"] = tokenizer
            logging.info("Loaded 7B model on demand.")
        except Exception as e:
            logging.error(f"Failed to load model and tokenizer: {e}")
            raise  # re-raise with the original traceback
    return models["7B"], tokenizers["7B"]


# ---------------------------
# Prompt Templates
# ---------------------------
# Per-domain templates for the fixed 4-round pipeline. Placeholders:
#   brainstorm -> {user_prompt}
#   round2     -> {brainstorm_response}, {user_prompt}
#   synthesis  -> {round2_response}
#   rationale  -> {final_response}
default_prompts = {
    "coding": {
        "brainstorm": (
            "**Round 1: Brainstorm & Analysis**\n"
            "Please analyze the following coding challenge or question. Consider the overall problem, "
            "potential edge cases, and any assumptions you might need to make. Explain your reasoning as you think aloud.\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Strategy**\n"
            "Based on your initial analysis, please break down the problem into logical steps. "
            "Outline a plan or strategy that could be used to solve the challenge, highlighting key algorithms, structures, or design considerations.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Implementation**\n"
            "Taking into account the steps outlined previously, synthesize a coherent solution. "
            "Provide a detailed explanation of how the code addresses the problem while encouraging best practices and clear logic.\n\n"
            "**Detailed Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Output**\n"
            "Review your solution and provide a final, well-rounded response that summarizes your reasoning and the implementation strategy. "
            "Explain any key decisions made during the process and how they contribute to an effective solution.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    },
    "math": {
        "brainstorm": (
            "**Round 1: Problem Analysis & Exploration**\n"
            "Carefully analyze the mathematical problem provided. Describe the underlying concepts and any assumptions you are making. "
            "Detail your initial reasoning and potential methods to tackle the problem.\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Methodology**\n"
            "Based on your initial exploration, break down the problem into sequential steps or methodologies. "
            "Explain the reasoning behind each step and how they connect to solve the problem.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Step-by-Step Solution**\n"
            "Integrate your previous reasoning into a structured solution. Clearly explain each step of your calculation or proof, "
            "ensuring that your logical progression is easy to follow.\n\n"
            "**Detailed Methodology:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Explanation**\n"
            "Present your final solution along with a detailed explanation of the reasoning behind each step. "
            "Discuss any assumptions and insights that helped you arrive at the final answer.\n\n"
            "**Final Solution:**\n{final_response}\n"
        )
    },
    "writing": {
        "brainstorm": (
            "**Round 1: Creative Exploration & Conceptualization**\n"
            "Read the following writing prompt and explore its themes, tone, and potential narrative directions. "
            "Outline your initial thoughts and reasoning behind various creative choices.\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Outline & Narrative Structure**\n"
            "Based on your brainstorming, create a detailed outline that organizes the narrative or essay. "
            "Explain the reasoning behind your structure, the flow of ideas, and how you plan to incorporate creative elements.\n\n"
            "**Initial Brainstorming:**\n{brainstorm_response}\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Draft Synthesis & Refinement**\n"
            "Integrate your outline and creative ideas into a coherent draft. Provide a well-rounded narrative that is both engaging and logically structured. "
            "Explain your thought process as you refine the narrative.\n\n"
            "**Outline & Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Editing**\n"
            "Review your draft and provide a final version that reflects thoughtful editing and creative reasoning. "
            "Explain the choices made in refining the text, from structure to stylistic decisions.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    }
}

# The prompt state now contains both default and custom modes.
initial_prompt_state = {
    "default": default_prompts,
    "custom": {}  # custom modes will be added here as {mode_name: [round_prompt1, round_prompt2, ...]}
}

# Upper bound on cached generations. The oldest entry is evicted first (dicts keep
# insertion order), so the in-process cache cannot grow without limit.
_RESPONSE_CACHE_MAX_ENTRIES = 256


def detect_domain(user_prompt: str) -> str:
    """
    Pick a default prompt family for ``user_prompt`` by simple keyword matching.

    Keyword groups are checked in the order math -> writing -> coding, so a prompt
    matching several groups gets the first match. Falls back to "coding".
    """
    prompt_lower = user_prompt.lower()
    math_keywords = ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"]
    writing_keywords = ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"]
    coding_keywords = ["code", "program", "debug", "compile", "algorithm", "function"]

    if any(kw in prompt_lower for kw in math_keywords):
        logging.info("Domain detected as: math")
        return "math"
    elif any(kw in prompt_lower for kw in writing_keywords):
        logging.info("Domain detected as: writing")
        return "writing"
    elif any(kw in prompt_lower for kw in coding_keywords):
        logging.info("Domain detected as: coding")
        return "coding"
    else:
        logging.info("No specific domain detected; defaulting to coding")
        return "coding"


class MemoryManager:
    """In-process, append-only list of text snippets with substring-based retrieval."""

    def __init__(self) -> None:
        self.shared_memory: List[str] = []

    def store(self, item: str) -> None:
        """Append ``item`` to memory."""
        self.shared_memory.append(item)
        logging.info(f"[Memory Stored]: {item[:50]}...")

    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        """
        Return up to ``top_k`` of the most recently stored items that contain
        ``query`` (case-insensitive), oldest first. ``top_k <= 0`` returns [].
        """
        if top_k <= 0:
            # Guard: relevant[-0:] would otherwise return *every* match.
            return []
        query_lower = query.lower()
        relevant = [item for item in self.shared_memory if query_lower in item.lower()]
        if not relevant:
            logging.info("[Memory Retrieval]: No relevant memories found.")
        else:
            logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
        return relevant[-top_k:]


global_memory_manager = MemoryManager()


def _stream_generation(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float,
                       repetition_penalty: float = 1.0, num_beams: int = 1) -> Generator[str, None, None]:
    """
    Yield the cumulative decoded text of one sampled generation for ``prompt``.

    Sampling (``do_sample=True``) is always enabled so ``temperature``/``top_p``
    take effect. ``TextIteratorStreamer`` cannot be combined with beam search, so
    when ``num_beams > 1`` the full text is generated first and yielded once.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
    )

    if (num_beams or 1) > 1:
        with torch.no_grad():
            output_ids = model.generate(**gen_kwargs)
        # generate() returns prompt + continuation; decode only the new tokens.
        yield tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
        return

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # Grad mode is thread-local, so a torch.no_grad() around thread.start() would not
    # reach the worker thread; transformers' generate() disables gradients itself.
    thread = Thread(target=model.generate, kwargs={**gen_kwargs, "streamer": streamer})
    thread.start()
    text = ""
    try:
        for chunk in streamer:
            text += chunk
            yield text
    finally:
        # Always reap the worker, including when streaming fails or times out.
        thread.join()


def generate_response(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float,
                      repetition_penalty: float = 1.0, num_beams: int = 1) -> str:
    """
    Generate a complete response for ``prompt``, using ``response_cache``.

    Identical prompt/settings combinations return the cached text instead of
    sampling again.

    Raises:
        Exception: any generation/streaming error, after it has been logged.
    """
    # Check cache first
    cache_key = f"{prompt}-{max_tokens}-{temperature}-{top_p}-{repetition_penalty}-{num_beams}"
    if cache_key in response_cache:
        logging.info("Returning cached response.")
        return response_cache[cache_key]

    response = ""
    try:
        for partial in _stream_generation(model, tokenizer, prompt, max_tokens, temperature, top_p,
                                          repetition_penalty, num_beams):
            response = partial
    except Exception as e:
        logging.error(f"Error during generation: {e}")
        raise

    # Cache the response, evicting the oldest entry when full.
    if len(response_cache) >= _RESPONSE_CACHE_MAX_ENTRIES:
        response_cache.pop(next(iter(response_cache)))
    response_cache[cache_key] = response
    return response


class MultiRoundAgent:
    """Runs a chain of prompts in which each round's output feeds the next round."""

    def __init__(self, model, tokenizer, prompt_templates, memory_manager: MemoryManager):
        """
        prompt_templates can be a dict (for default modes) or a list (for custom modes)
        """
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_templates = prompt_templates
        self.memory_manager = memory_manager

    def run_pipeline(self, user_prompt: str, params: Dict, show_raw: bool = False) -> Generator[str, None, None]:
        """
        Run the pipeline for ``user_prompt`` and yield progressively longer outputs.

        ``params`` holds "temp", "top_p", "max_new_tokens", "repetition_penalty" and
        "num_beams". The last yielded value is the final answer.
        """
        if isinstance(self.prompt_templates, dict):
            yield from self._run_default_pipeline(user_prompt, params, show_raw)
        elif isinstance(self.prompt_templates, list):
            yield from self._run_custom_pipeline(user_prompt, params)
        else:
            yield "Invalid prompt template format."

    def _generate(self, prompt: str, max_tokens: int, params: Dict) -> str:
        """Run one cached, non-streaming generation with the user's sampling settings."""
        return generate_response(self.model, self.tokenizer, prompt, max_tokens,
                                 params.get("temp"), params.get("top_p"),
                                 params.get("repetition_penalty"), params.get("num_beams"))

    def _run_default_pipeline(self, user_prompt: str, params: Dict, show_raw: bool) -> Generator[str, None, None]:
        """Fixed 4-round pipeline; round 3 is streamed, the rest are generated whole."""
        max_tokens = params.get("max_new_tokens")

        logging.info("--- Round 1 ---")
        prompt_r1 = self.prompt_templates["brainstorm"].format(user_prompt=user_prompt)
        r1 = self._generate(prompt_r1, max_tokens, params)
        self.memory_manager.store(f"Round 1 Response: {r1}")

        logging.info("--- Round 2 ---")
        prompt_r2 = self.prompt_templates["round2"].format(brainstorm_response=r1, user_prompt=user_prompt)
        # Round 2 gets a slightly larger token budget than round 1.
        r2 = self._generate(prompt_r2, max_tokens + 100, params)
        self.memory_manager.store(f"Round 2 Response: {r2}")

        logging.info("--- Round 3 ---")
        prompt_r3 = self.prompt_templates["synthesis"].format(round2_response=r2)
        r3 = ""
        try:
            for partial in _stream_generation(self.model, self.tokenizer, prompt_r3, max_tokens // 2,
                                              params.get("temp"), params.get("top_p"),
                                              params.get("repetition_penalty"), params.get("num_beams")):
                r3 = partial
                yield r3  # Progressive updates
        except Exception as e:
            logging.error(f"Error during Round 3 streaming: {e}")
            raise
        self.memory_manager.store(f"Final Synthesis Response: {r3}")

        logging.info("--- Round 4 ---")
        prompt_r4 = self.prompt_templates["rationale"].format(final_response=r3)
        r4 = self._generate(prompt_r4, 300, params)
        self.memory_manager.store(f"Round 4 Response: {r4}")

        if show_raw:
            yield (f"{r4}\n\n[Raw Outputs]\nRound 1:\n{r1}\n\nRound 2:\n{r2}\n\nRound 3:\n{r3}\n\nRound 4:\n{r4}\n")
        else:
            yield r4

    def _run_custom_pipeline(self, user_prompt: str, params: Dict) -> Generator[str, None, None]:
        """Custom mode: run each round prompt in order, yielding the accumulated output."""
        prev_response = ""
        full_output = ""
        total_rounds = len(self.prompt_templates)
        for round_num, round_template in enumerate(self.prompt_templates, start=1):
            logging.info(f"--- Custom Mode: Round {round_num} of {total_rounds} ---")
            # Both placeholders are always supplied so any round may use either one;
            # in round 1 {prev_response} is simply empty.
            prompt = round_template.format(user_prompt=user_prompt, prev_response=prev_response)
            response = self._generate(prompt, params.get("max_new_tokens"), params)
            self.memory_manager.store(f"Custom Mode Round {round_num} Response: {response}")
            full_output += f"\n--- Round {round_num} ---\n{response}"
            prev_response = response
            yield full_output


@spaces.GPU(duration=180)
def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
                          prompt_templates, domain: str, show_raw: bool, repetition_penalty: float,
                          num_beams: int) -> Generator[str, None, None]:
    """
    Run the multi-round agent and yield its progressive output.

    This is a generator function (``yield from``) on purpose. Returning the
    pipeline generator instead would let the decorated call finish before any
    round ran, so the actual generation would happen outside the ``spaces.GPU``
    wrapper.

    ``memory_top_k`` and ``domain`` are currently unused; they are kept for
    interface compatibility.
    """
    model, tokenizer = get_model_and_tokenizer()
    agent = MultiRoundAgent(model, tokenizer, prompt_templates, global_memory_manager)
    params = {
        "temp": temp,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "num_beams": num_beams
    }
    yield from agent.run_pipeline(user_prompt, params, show_raw)


def handle_explanation_request(user_prompt: str, history: List) -> str:
    """
    Answer an "explain" request using stored final outputs and related chat turns.

    ``history`` may hold ``[user, assistant]`` pairs or ``{"role", "content"}``
    message dicts (the Chatbot uses the messages format); both are handled.
    """
    retrieved = global_memory_manager.retrieve("Round 4 Response:", top_k=3)
    parts = ["Below are previous final outputs and related context from our conversation:\n"]
    if retrieved:
        parts.extend(f"- {item}\n" for item in retrieved)
    else:
        parts.append("No stored final output found.\n")

    parts.append("\nRecent related exchanges:\n")
    for chat in history or []:
        if isinstance(chat, dict):
            content = chat.get("content")
            if isinstance(content, str) and "explain" in content.lower():
                speaker = "User" if chat.get("role") == "user" else "Assistant"
                parts.append(f"{speaker}: {content}\n")
        elif isinstance(chat, (list, tuple)) and len(chat) == 2:
            user_msg, assistant_msg = chat
            user_hit = isinstance(user_msg, str) and "explain" in user_msg.lower()
            assistant_hit = isinstance(assistant_msg, str) and "explain" in assistant_msg.lower()
            if user_hit or assistant_hit:
                parts.append(f"User: {user_msg}\nAssistant: {assistant_msg}\n")

    parts.append("\nBased on the above context, please provide a detailed explanation of the creative choices.")
    explanation_prompt = "".join(parts)

    model, tokenizer = get_model_and_tokenizer()
    return generate_response(model, tokenizer, explanation_prompt, 300, 0.7, 0.9)


def format_history(history: List) -> List[Dict[str, str]]:
    """
    Convert mixed history (``[user, assistant]`` pairs and message dicts) into
    Chatbot messages. Pairs whose user text is "__final_agent_response__" are
    skipped, and empty assistant replies are omitted.
    """
    messages = []
    for item in history:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, assistant_msg = item
            if user_msg == "__final_agent_response__":
                continue
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        elif isinstance(item, dict):
            messages.append(item)
    return messages


def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict,
                     mode: str) -> Generator[List[Dict[str, str]], None, None]:
    """
    Chat handler: yield successive Chatbot message lists for ``message``.

    Messages containing "explain" are answered from memory; everything else runs
    the multi-round agent in ``mode`` (a default or custom mode name, otherwise
    auto-detected from the message).
    """
    history = list(history or [])

    if "explain" in message.lower():
        explanation = handle_explanation_request(message, history)
        history = history + [[message, explanation]]
        yield format_history(history)
        return

    try:
        temp = float(param_state.get("temperature", 0.5))
        top_p = float(param_state.get("top_p", 0.9))
        max_new_tokens = int(param_state.get("max_new_tokens", 300))
        repetition_penalty = float(param_state.get("repetition_penalty", 1.0))
        num_beams = int(param_state.get("num_beams", 1))
        memory_top_k = int(param_state.get("memory_top_k", 2))
        show_raw = bool(param_state.get("show_raw_output", False))
    except Exception as e:
        logging.error(f"Parameter conversion error: {e}")
        temp, top_p, max_new_tokens, repetition_penalty, num_beams, memory_top_k, show_raw = 0.5, 0.9, 300, 1.0, 1, 2, False

    if mode in prompt_state.get("default", {}):
        prompt_templates = prompt_state["default"][mode]
    elif mode in prompt_state.get("custom", {}):
        prompt_templates = prompt_state["custom"][mode]
    else:
        detected = detect_domain(message)
        prompt_templates = prompt_state["default"].get(detected, prompt_state["default"]["coding"])
        mode = detected

    history = history + [[message, ""]]
    # Show a loading status
    yield format_history(history)

    for partial_response in swarm_agent_iterative(
        user_prompt=message,
        temp=temp,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        memory_top_k=memory_top_k,
        prompt_templates=prompt_templates,
        domain=mode,
        show_raw=show_raw,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams
    ):
        history[-1][1] = partial_response
        yield format_history(history)


def generate_agent_audio(latest_text: str, voice_reference: str) -> str:
    """
    Generate an audio response using the cloned voice.

    ``voice_reference`` may be a name saved in ``voice_bank`` (its stored file
    path is then used) or a direct path to a reference audio file.

    Returns:
        Path to the generated audio, or None when ``latest_text`` is empty or
        synthesis failed.
    """
    if not latest_text:
        return None
    reference = voice_bank.get(voice_reference, voice_reference)
    return clone(latest_text, reference)


# NEW: Speech-to-Text Function using Whisper.
def transcribe_audio(audio_file: str) -> str:
    """
    Transcribe the provided audio file to text using the Whisper model.

    Returns "" when no audio was provided and "Transcription failed." when
    Whisper raised (the error is logged).
    """
    if not audio_file:
        logging.warning("No audio provided for transcription.")
        return ""
    try:
        result = whisper_model.transcribe(audio_file)
        transcription = result.get("text", "").strip()
        logging.info(f"Transcription result: {transcription}")
        return transcription
    except Exception as e:
        logging.error(f"Transcription error: {e}")
        return "Transcription failed."
# ---------------------------
# Warm-Up Model Function
# ---------------------------
def warmup_model():
    """
    Load the chat model eagerly so the first chat request does not pay the load cost.

    Failures are logged rather than raised so the UI can still start; the model
    is then loaded on demand by get_model_and_tokenizer().
    """
    try:
        get_model_and_tokenizer()
        logging.info("Model warm-up complete.")
    except Exception as e:
        logging.error(f"Model warm-up failed: {e}")


warmup_model()

# ===========================
# Custom Gradio Theme
# ===========================
theme = gr.themes.Soft(
    primary_hue="pink",
    secondary_hue="pink",
    neutral_hue="purple",
    font=['IBM Plex Sans', 'ui-sans-serif', 'system-ui', 'sans-serif'],
).set(
    background_fill_primary='white',
    shadow_drop='rgba(0,0,0,0.05) 0px 1px 2px 0px',
    shadow_drop_lg='0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1)',
    shadow_spread='3px',
    block_background_fill='*background_fill_primary',
    block_border_width='1px',
    block_border_width_dark='1px',
    block_label_background_fill='*background_fill_primary',
    block_label_background_fill_dark='*background_fill_secondary',
    block_label_text_color='*neutral_500',
    block_label_text_color_dark='*neutral_200',
    block_label_margin='0',
    block_label_padding='*spacing_sm *spacing_lg',
    block_label_radius='calc(*radius_sm - 1px) 0 calc(*radius_sm - 1px) 0',
    block_label_text_size='*text_sm',
    block_label_text_weight='400',
    block_title_background_fill='none',
    block_title_background_fill_dark='none',
    block_title_text_color='*neutral_500',
    block_title_text_color_dark='*neutral_200',
    block_title_padding='0',
    block_title_radius='none',
    block_title_text_weight='400',
    panel_border_width='0',
    panel_border_width_dark='0',
    checkbox_background_color_selected='*color_accent',
    checkbox_background_color_selected_dark='*color_accent',
    checkbox_border_color='*neutral_300',
    checkbox_border_color_dark='*neutral_700',
    checkbox_border_color_focus='*color_accent',
    checkbox_border_color_focus_dark='*color_accent',
    checkbox_border_color_selected='*color_accent',
    checkbox_border_color_selected_dark='*color_accent',
    checkbox_border_width='*input_border_width',
    checkbox_shadow='*input_shadow',
    checkbox_label_background_fill_selected='*checkbox_label_background_fill',
    checkbox_label_background_fill_selected_dark='*checkbox_label_background_fill',
    checkbox_label_shadow='none',
    checkbox_label_text_color_selected='*checkbox_label_text_color',
    input_background_fill='*neutral_100',
    input_border_color='*border_color_primary',
    input_shadow='none',
    input_shadow_dark='none',
    input_shadow_focus='*input_shadow',
    input_shadow_focus_dark='*input_shadow',
    slider_color='*color_accent',
    slider_color_dark='*color_accent',
    button_primary_background_fill_hover='*primary_600',
    button_primary_background_fill_hover_dark='*primary_700',
    button_primary_shadow='none',
    button_primary_shadow_hover='*button_primary_shadow',
    button_primary_shadow_active='*button_primary_shadow',
    button_primary_shadow_dark='none',
    button_secondary_background_fill='*neutral_200',
    button_secondary_background_fill_hover='*neutral_300',
    button_secondary_background_fill_hover_dark='*neutral_700',
    button_secondary_text_color='black',
    button_secondary_shadow='*button_primary_shadow',
    button_secondary_shadow_hover='*button_secondary_shadow',
    button_secondary_shadow_active='*button_secondary_shadow',
    button_secondary_shadow_dark='*button_primary_shadow'
)

# ===========================
# Combined Gradio Interface
# ===========================
# NOTE(review): several gr.Markdown snippets below appear to have lost their original
# markup (probably HTML headings/separators) when this file was flattened; only their
# text survived and is reproduced here as plain strings. The exact nesting of the
# trailing notes is also a best-effort reconstruction — confirm against the intended layout.
with gr.Blocks(theme=theme, title="Combined Voice Clone & Agent Chat") as demo:
    # Shared states for project settings, prompt configuration, and voice selection.
    param_state = gr.State({
        "temperature": 0.5,
        "top_p": 0.9,
        "max_new_tokens": 300,
        "memory_top_k": 2,
        "show_raw_output": False,
        "repetition_penalty": 1.0,
        "num_beams": 1,
        "use_cpu": False  # Toggle for device override
    })
    prompt_state = gr.State(initial_prompt_state)
    selected_voice = gr.State(value="")  # holds the currently selected voice

    # A status display to show device info.
    device_status = gr.Markdown(f"**Running on:** {device.upper()}")

    with gr.Tabs():
        # ----- Tab 1: Voice Setup -----
        with gr.Tab("Voice Setup"):
            gr.Markdown("Voice Setup")
            with gr.Column(variant="panel"):
                gr.Markdown("Clone a voice and save it with a custom name. Test TTS using your cloned voices.")
                with gr.Row():
                    text_input = gr.Textbox(label='Text to Clone', placeholder="Enter the text to speak...",
                                            elem_classes="full-width")
                with gr.Row():
                    audio_input = gr.Audio(label='Voice Reference Audio', type='filepath')
                with gr.Row():
                    clone_btn = gr.Button("Clone Voice")
                with gr.Row():
                    output_audio = gr.Audio(label='Cloned Voice Output', type='filepath')
                clone_btn.click(fn=clone, inputs=[text_input, audio_input], outputs=output_audio)

                with gr.Row():
                    voice_name_input = gr.Textbox(label="Voice Name", placeholder="Enter a name for this voice clone")
                with gr.Row():
                    save_voice_btn = gr.Button("Save Voice")
                save_voice_btn.click(fn=save_voice, inputs=[voice_name_input, output_audio], outputs=[])

                with gr.Row():
                    refresh_voice_btn_setup = gr.Button("Refresh Voice List")
                    voice_dropdown_setup = gr.Dropdown(choices=get_voice_options(), label="Select Saved Voice",
                                                       interactive=True)
                    set_voice_btn = gr.Button("Set Selected Voice")
                refresh_voice_btn_setup.click(fn=refresh_voice_list, outputs=voice_dropdown_setup)
                set_voice_btn.click(fn=lambda x: x, inputs=[voice_dropdown_setup], outputs=selected_voice)
                gr.Markdown("(The selected voice will be used for TTS responses in Chat.)")

                gr.Markdown("\n")  # spacer; its original markup was lost
                gr.Markdown("TTS Test")
                with gr.Row():
                    tts_test_input = gr.Textbox(label="Test Text", placeholder="Enter text to test TTS...",
                                                elem_classes="full-width")
                with gr.Row():
                    tts_test_btn = gr.Button("Test TTS")
                tts_test_output = gr.Audio(label="TTS Output", type="filepath")
                # An uploaded reference clip takes precedence over the saved selected voice.
                tts_test_btn.click(
                    fn=lambda txt, override, sel: generate_agent_audio(txt, override if override else sel),
                    inputs=[tts_test_input, audio_input, selected_voice],
                    outputs=tts_test_output,
                )

        # ----- Tab 2: Chat -----
        with gr.Tab("Chat"):
            gr.Markdown("DeepSeek Agent Swarm Chat\n\nMulti-round agent with prompt chaining. Ask me anything!")
            with gr.Column():
                with gr.Row():
                    mode_selector = gr.Radio(choices=["coding", "math", "writing"], value="coding", label="Select Mode")
                with gr.Row():
                    chat_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2,
                                            elem_id="msg_input")
                with gr.Row():
                    chat_audio_input = gr.Audio(label="Or record/upload your message", type="filepath")
                    transcribe_btn = gr.Button("Transcribe Audio")
                transcribe_btn.click(fn=transcribe_audio, inputs=chat_audio_input, outputs=chat_input)
                with gr.Row():
                    send_btn = gr.Button("Send", variant="primary")
                    export_btn = gr.Button("Generate Chat Transcript")
                chatbot = gr.Chatbot(height=450, label="Agent Swarm Output", type="messages")
                # The transcript gets its own box instead of overwriting the chat.
                transcript_output = gr.Textbox(label="Chat Transcript", lines=8, interactive=False)
                with gr.Row():
                    use_tts_checkbox = gr.Checkbox(label="Generate Audio Response using TTS", value=False)
                    chat_voice_dropdown = gr.Dropdown(choices=get_voice_options(), label="Select Voice for TTS",
                                                      interactive=True)
                    refresh_voice_btn_chat = gr.Button("Refresh Voice List")
                refresh_voice_btn_chat.click(fn=refresh_voice_list, outputs=chat_voice_dropdown)
                agent_audio = gr.Audio(label="Agent Audio Response", type="filepath")

                def chat_wrapper(message, history, params, prompts, mode):
                    """
                    Stream the agent's reply into the chatbot.

                    Every intermediate history produced by gradio_interface is forwarded,
                    so the user sees progress while the rounds run.
                    """
                    yield from gradio_interface(message, list(history or []), params, prompts, mode)

                send_btn.click(fn=chat_wrapper,
                               inputs=[chat_input, chatbot, param_state, prompt_state, mode_selector],
                               outputs=[chatbot])

                def conditional_tts(latest_text, use_tts, selected_voice_val, fallback_voice=""):
                    """
                    Synthesize ``latest_text`` when TTS is enabled; otherwise return None.

                    Uses the voice chosen in the Chat tab, falling back to the voice set via
                    "Set Selected Voice" in the Voice Setup tab.
                    """
                    if not use_tts:
                        return None
                    return generate_agent_audio(latest_text, selected_voice_val or fallback_voice)

                def get_latest_text(chat_history):
                    """Return the content of the most recent non-empty assistant message, or ""."""
                    for msg in reversed(chat_history or []):
                        if isinstance(msg, dict) and msg.get("role") == "assistant" and msg.get("content"):
                            return msg["content"]
                    return ""

                latest_text_state = gr.State(value="")
                gen_audio_btn = gr.Button("Generate Audio from Agent Response")
                # Chain the two steps so TTS always sees the freshly extracted text;
                # two independent .click handlers would race each other.
                gen_audio_btn.click(
                    fn=get_latest_text, inputs=[chatbot], outputs=latest_text_state
                ).then(
                    fn=conditional_tts,
                    inputs=[latest_text_state, use_tts_checkbox, chat_voice_dropdown, selected_voice],
                    outputs=agent_audio,
                )

                def export_transcript(history):
                    """
                    Render the chat as plain "Role: content" text.

                    Accepts message dicts (the Chatbot's messages format) as well as
                    legacy ``[user, assistant]`` pairs.
                    """
                    parts = []
                    for item in history or []:
                        if isinstance(item, dict):
                            role = str(item.get("role", "")).capitalize()
                            # Blank line after each assistant turn, like the pair format below.
                            ending = "\n\n" if role == "Assistant" else "\n"
                            parts.append(f"{role}: {item.get('content', '')}{ending}")
                        elif isinstance(item, (list, tuple)) and len(item) == 2:
                            parts.append(f"User: {item[0]}\nAssistant: {item[1]}\n\n")
                    return "".join(parts)

                export_btn.click(fn=export_transcript, inputs=[chatbot], outputs=transcript_output)

        # ----- Tab 3: Project Settings -----
        with gr.Tab("Project Settings"):
            gr.Markdown("Project Settings")
            with gr.Tabs():
                with gr.Tab("Generation Parameters"):
                    gr.Markdown("Generation Parameters")
                    with gr.Row():
                        temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
                        top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
                    with gr.Row():
                        max_tokens_num = gr.Number(value=300, label="Max New Tokens", precision=0)
                        memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2,
                                                       label="Memory Retrieval Top K")
                    with gr.Row():
                        rep_penalty_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0,
                                                       label="Repetition Penalty")
                        num_beams_slider = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Beams")
                    with gr.Row():
                        show_raw_checkbox = gr.Checkbox(value=False, label="Show Raw Output")
                        use_cpu_checkbox = gr.Checkbox(value=False, label="Force Use CPU")
                    save_params_btn = gr.Button("Save Generation Parameters")

                    def save_params(t, p, m, k, rp, nb, s, use_cpu):
                        """
                        Build the new generation-parameter dict and update the global ``device``.

                        NOTE: only the global ``device`` string changes; models already loaded
                        at startup are not moved to the new device.
                        """
                        global device
                        if use_cpu:
                            device = "cpu"
                        else:
                            device = "cuda" if torch.cuda.is_available() else "cpu"
                        return {
                            "temperature": t,
                            "top_p": p,
                            "max_new_tokens": m,
                            "memory_top_k": k,
                            "repetition_penalty": rp,
                            "num_beams": nb,
                            "show_raw_output": s,
                            "use_cpu": use_cpu
                        }

                    # Chained so the status line is rendered only after `device` was updated.
                    save_params_btn.click(
                        save_params,
                        inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider,
                                rep_penalty_slider, num_beams_slider, show_raw_checkbox, use_cpu_checkbox],
                        outputs=param_state,
                    ).then(
                        fn=lambda params: f"**Running on:** {device.upper()}",
                        inputs=param_state,
                        outputs=device_status,
                    )
                    gr.Markdown("Note: Repetition penalty and number of beams affect generation diversity and quality.")

                with gr.Tab("Prompt Config (Default Modes)"):
                    gr.Markdown("Prompt Configurations for Default Modes")
                    with gr.Tabs():
                        with gr.Tab("Coding"):
                            prompt_brainstorm_box_code = gr.Textbox(
                                value=default_prompts["coding"]["brainstorm"],
                                label="Brainstorm Prompt (Coding)",
                                lines=8,
                            )
                            prompt_round2_box_code = gr.Textbox(
                                value=default_prompts["coding"]["round2"],
                                label="Round 2 Prompt (Coding)",
                                lines=8,
                            )
                            prompt_synthesis_box_code = gr.Textbox(
                                value=default_prompts["coding"]["synthesis"],
                                label="Synthesis Prompt (Coding)",
                                lines=8,
                            )
                            prompt_rationale_box_code = gr.Textbox(
                                value=default_prompts["coding"]["rationale"],
                                label="Rationale Prompt (Coding)",
                                lines=8,
                            )
                        with gr.Tab("Math"):
                            prompt_brainstorm_box_math = gr.Textbox(
                                value=default_prompts["math"]["brainstorm"],
                                label="Brainstorm Prompt (Math)",
                                lines=8,
                            )
                            prompt_round2_box_math = gr.Textbox(
                                value=default_prompts["math"]["round2"],
                                label="Round 2 Prompt (Math)",
                                lines=8,
                            )
                            prompt_synthesis_box_math = gr.Textbox(
                                value=default_prompts["math"]["synthesis"],
                                label="Synthesis Prompt (Math)",
                                lines=8,
                            )
                            prompt_rationale_box_math = gr.Textbox(
                                value=default_prompts["math"]["rationale"],
                                label="Rationale Prompt (Math)",
                                lines=8,
                            )
                        with gr.Tab("Writing"):
                            prompt_brainstorm_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["brainstorm"],
                                label="Brainstorm Prompt (Writing)",
                                lines=8,
                            )
                            prompt_round2_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["round2"],
                                label="Round 2 Prompt (Writing)",
                                lines=8,
                            )
                            prompt_synthesis_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["synthesis"],
                                label="Synthesis Prompt (Writing)",
                                lines=8,
                            )
                            prompt_rationale_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["rationale"],
                                label="Rationale Prompt (Writing)",
                                lines=8,
                            )
                    save_prompts_btn = gr.Button("Save Default Prompt Configurations")

                    def save_default_prompts(code_brain, code_r2, code_syn, code_rat,
                                             math_brain, math_r2, math_syn, math_rat,
                                             writing_brain, writing_r2, writing_syn, writing_rat,
                                             current_prompt_state=None):
                        """
                        Build a new prompt state from the edited default-mode prompts.

                        Custom modes are carried over from the session's *current* prompt
                        state (passed in as an input), not from the component's initial value,
                        so saving defaults no longer discards custom modes.
                        """
                        custom_modes = (current_prompt_state or {}).get("custom", {})
                        return {
                            "default": {
                                "coding": {
                                    "brainstorm": code_brain,
                                    "round2": code_r2,
                                    "synthesis": code_syn,
                                    "rationale": code_rat,
                                },
                                "math": {
                                    "brainstorm": math_brain,
                                    "round2": math_r2,
                                    "synthesis": math_syn,
                                    "rationale": math_rat,
                                },
                                "writing": {
                                    "brainstorm": writing_brain,
                                    "round2": writing_r2,
                                    "synthesis": writing_syn,
                                    "rationale": writing_rat,
                                }
                            },
                            "custom": custom_modes
                        }

                    save_prompts_btn.click(
                        save_default_prompts,
                        inputs=[prompt_brainstorm_box_code, prompt_round2_box_code, prompt_synthesis_box_code,
                                prompt_rationale_box_code,
                                prompt_brainstorm_box_math, prompt_round2_box_math, prompt_synthesis_box_math,
                                prompt_rationale_box_math,
                                prompt_brainstorm_box_writing, prompt_round2_box_writing,
                                prompt_synthesis_box_writing, prompt_rationale_box_writing,
                                prompt_state],
                        outputs=prompt_state,
                    )

                with gr.Tab("Custom Modes"):
                    gr.Markdown("Create / Edit Custom Modes")
                    gr.Markdown(
                        "Define a custom mode by providing a unique mode name, selecting the number of rounds (up to 10), "
                        "and editing the prompt for each round. In custom mode prompts, you can use the placeholders `{user_prompt}` "
                        "(for the first round) and `{prev_response}` (for subsequent rounds)."
                    )
                    with gr.Row():
                        custom_mode_name = gr.Textbox(label="Custom Mode Name", placeholder="Enter a unique mode name")
                        custom_round_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Rounds")
                    custom_round1 = gr.Textbox(label="Round 1 Prompt", lines=4, placeholder="e.g., Use {user_prompt} here")
                    custom_round2 = gr.Textbox(label="Round 2 Prompt", lines=4,
                                               placeholder="e.g., Use {user_prompt} and {prev_response}")
                    custom_round3 = gr.Textbox(label="Round 3 Prompt", lines=4,
                                               placeholder="e.g., Use {user_prompt} and {prev_response}")
                    custom_round4 = gr.Textbox(label="Round 4 Prompt", lines=4, placeholder="Optional")
                    custom_round5 = gr.Textbox(label="Round 5 Prompt", lines=4, placeholder="Optional")
                    custom_round6 = gr.Textbox(label="Round 6 Prompt", lines=4, placeholder="Optional")
                    custom_round7 = gr.Textbox(label="Round 7 Prompt", lines=4, placeholder="Optional")
                    custom_round8 = gr.Textbox(label="Round 8 Prompt", lines=4, placeholder="Optional")
                    custom_round9 = gr.Textbox(label="Round 9 Prompt", lines=4, placeholder="Optional")
                    custom_round10 = gr.Textbox(label="Round 10 Prompt", lines=4, placeholder="Optional")

                    def save_custom_mode(name, round_count, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10,
                                         current_prompt_state):
                        """
                        Store the first ``round_count`` non-blank round prompts as custom mode ``name``.

                        Returns an update that clears the name box plus the new prompt state;
                        when ``name`` is empty nothing changes.
                        """
                        if not name:
                            return gr.update(), current_prompt_state
                        round_prompts = [r1, r2, r3, r4, r5, r6, r7, r8, r9, r10]
                        # Slider values may arrive as floats; slicing needs an int.
                        selected = round_prompts[:int(round_count)]
                        rounds = [prompt for prompt in selected if prompt and prompt.strip()]
                        # Copy instead of mutating the incoming state in place.
                        custom_modes = dict(current_prompt_state.get("custom", {}))
                        custom_modes[name] = rounds
                        new_prompt_state = {
                            "default": current_prompt_state.get("default", {}),
                            "custom": custom_modes
                        }
                        return gr.update(value=""), new_prompt_state

                    save_custom_mode_btn = gr.Button("Save Custom Mode")
                    save_custom_mode_btn.click(
                        save_custom_mode,
                        inputs=[custom_mode_name, custom_round_count, custom_round1, custom_round2, custom_round3,
                                custom_round4, custom_round5, custom_round6, custom_round7, custom_round8,
                                custom_round9, custom_round10, prompt_state],
                        outputs=[custom_mode_name, prompt_state]
                    )

                    def update_mode_choices(current_prompt_state):
                        """List default modes then custom modes, selecting the first default mode."""
                        default_modes = list(current_prompt_state.get("default", {}).keys())
                        custom_modes = list(current_prompt_state.get("custom", {}).keys())
                        all_modes = default_modes + custom_modes
                        default_choice = default_modes[0] if default_modes else (custom_modes[0] if custom_modes else "")
                        return gr.update(choices=all_modes, value=default_choice)

                    refresh_mode_selector_btn = gr.Button("Refresh Mode List")
                    refresh_mode_selector_btn.click(fn=update_mode_choices, inputs=prompt_state, outputs=mode_selector)

            gr.Markdown("\n")  # spacer; its original markup was lost
            gr.Markdown("These settings affect the entire project.")

    gr.Markdown("Agent Chat using DeepSeek Agent Swarm")

if __name__ == "__main__":
    demo.launch(share=True)