import os
import gradio as gr
import torch
from TTS.api import TTS
import spaces  # assumed custom module providing GPU decorators
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from threading import Thread
import logging
from typing import Tuple, List, Dict, Generator
import time

# NEW: Import whisper for speech-to-text.
import whisper

# ===========================
# Global Environment Settings
# ===========================
os.environ["COQUI_TOS_AGREED"] = "1"

# Global device override (will be updated from UI later)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Whisper model (this may take a moment at startup)
whisper_model = whisper.load_model("base")

# Global dictionary for storing saved voice clones.
voice_bank: Dict[str, str] = {}

# ---------------------------
# Simple Response Cache
# ---------------------------
response_cache: Dict[str, str] = {}

# ===========================
# Voice Cloning Setup
# ===========================
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)


@spaces.GPU(enable_queue=True)
def clone(text, audio):
    """
    Generate a voice-cloned audio file given text and a reference audio file.
    Returns the path to the output audio file, or None if synthesis fails.
    """
    try:
        tts.tts_to_file(text=text, speaker_wav=audio, language="en", file_path="./output.wav")
        return "./output.wav"
    except Exception as e:
        logging.error(f"TTS cloning failed: {e}")
        return None


def save_voice(voice_name: str, voice_audio: str) -> None:
    """
    Save a cloned voice under the given name.
    """
    global voice_bank
    if voice_name and voice_audio:
        voice_bank[voice_name] = voice_audio


def get_voice_options() -> List[str]:
    """
    Return a list of saved voice names.
    """
    return list(voice_bank.keys())


def refresh_voice_list() -> gr.update:
    """
    Return a dropdown update with the latest voice list.
    """
    options = get_voice_options()
    new_val = options[0] if options else ""
    return gr.update(choices=options, value=new_val)


# ===========================
# Deep Agent Chat Setup
# ===========================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
models: Dict[str, AutoModelForCausalLM] = {}
tokenizers: Dict[str, AutoTokenizer] = {}

bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)


def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    # Warm-up: if the model isn't loaded yet, load it now.
    if "7B" not in models:
        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                quantization_config=bnb_config_4bit,
                torch_dtype=torch.bfloat16,
                device_map='auto',
                trust_remote_code=True,
            )
            model.eval()
            models["7B"] = model
            tokenizers["7B"] = tokenizer
            logging.info("Loaded 7B model on demand.")
        except Exception as e:
            logging.error(f"Failed to load model and tokenizer: {e}")
            raise e
    return models["7B"], tokenizers["7B"]


# ---------------------------
# Prompt Templates
# ---------------------------
default_prompts = {
    "coding": {
        "brainstorm": (
            "**Round 1: Brainstorm & Analysis**\n"
            "Please analyze the following coding challenge or question. Consider the overall problem, "
            "potential edge cases, and any assumptions you might need to make. "
            "Explain your reasoning as you think aloud.\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Strategy**\n"
            "Based on your initial analysis, please break down the problem into logical steps. "
            "Outline a plan or strategy that could be used to solve the challenge, highlighting key algorithms, structures, or design considerations.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Implementation**\n"
            "Taking into account the steps outlined previously, synthesize a coherent solution. "
            "Provide a detailed explanation of how the code addresses the problem while encouraging best practices and clear logic.\n\n"
            "**Detailed Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Output**\n"
            "Review your solution and provide a final, well-rounded response that summarizes your reasoning and the implementation strategy. "
            "Explain any key decisions made during the process and how they contribute to an effective solution.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    },
    "math": {
        "brainstorm": (
            "**Round 1: Problem Analysis & Exploration**\n"
            "Carefully analyze the mathematical problem provided. Describe the underlying concepts and any assumptions you are making. "
            "Detail your initial reasoning and potential methods to tackle the problem.\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Methodology**\n"
            "Based on your initial exploration, break down the problem into sequential steps or methodologies. "
            "Explain the reasoning behind each step and how they connect to solve the problem.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Step-by-Step Solution**\n"
            "Integrate your previous reasoning into a structured solution. Clearly explain each step of your calculation or proof, "
            "ensuring that your logical progression is easy to follow.\n\n"
            "**Detailed Methodology:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Explanation**\n"
            "Present your final solution along with a detailed explanation of the reasoning behind each step. "
            "Discuss any assumptions and insights that helped you arrive at the final answer.\n\n"
            "**Final Solution:**\n{final_response}\n"
        )
    },
    "writing": {
        "brainstorm": (
            "**Round 1: Creative Exploration & Conceptualization**\n"
            "Read the following writing prompt and explore its themes, tone, and potential narrative directions. "
            "Outline your initial thoughts and reasoning behind various creative choices.\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Outline & Narrative Structure**\n"
            "Based on your brainstorming, create a detailed outline that organizes the narrative or essay. "
            "Explain the reasoning behind your structure, the flow of ideas, and how you plan to incorporate creative elements.\n\n"
            "**Initial Brainstorming:**\n{brainstorm_response}\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Draft Synthesis & Refinement**\n"
            "Integrate your outline and creative ideas into a coherent draft. Provide a well-rounded narrative that is both engaging and logically structured. "
            "Explain your thought process as you refine the narrative.\n\n"
            "**Outline & Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Editing**\n"
            "Review your draft and provide a final version that reflects thoughtful editing and creative reasoning. "
            "Explain the choices made in refining the text, from structure to stylistic decisions.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    }
}

# The prompt state now contains both default and custom modes.
initial_prompt_state = {
    "default": default_prompts,
    "custom": {}  # custom modes will be added here as {mode_name: [round_prompt1, round_prompt2, ...]}
}


def detect_domain(user_prompt: str) -> str:
    """
    Infer the task domain (math, writing, or coding) from keywords in the user prompt.
    Defaults to coding when no keywords match.
    """
    prompt_lower = user_prompt.lower()
    math_keywords = ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"]
    writing_keywords = ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"]
    coding_keywords = ["code", "program", "debug", "compile", "algorithm", "function"]
    if any(kw in prompt_lower for kw in math_keywords):
        logging.info("Domain detected as: math")
        return "math"
    elif any(kw in prompt_lower for kw in writing_keywords):
        logging.info("Domain detected as: writing")
        return "writing"
    elif any(kw in prompt_lower for kw in coding_keywords):
        logging.info("Domain detected as: coding")
        return "coding"
    else:
        logging.info("No specific domain detected; defaulting to coding")
        return "coding"


class MemoryManager:
    """Simple in-memory store for sharing intermediate agent responses across rounds."""

    def __init__(self) -> None:
        self.shared_memory: List[str] = []

    def store(self, item: str) -> None:
        self.shared_memory.append(item)
        logging.info(f"[Memory Stored]: {item[:50]}...")

    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        query_lower = query.lower()
        relevant = [item for item in self.shared_memory if query_lower in item.lower()]
        if not relevant:
            logging.info("[Memory Retrieval]: No relevant memories found.")
        else:
            logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
        return relevant[-top_k:]


global_memory_manager = MemoryManager()


def generate_response(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float,
                      repetition_penalty: float = 1.0, num_beams: int = 1) -> str:
    """
    Run one generation pass, streaming tokens from a background thread.
    Responses are memoized in response_cache, keyed by the prompt and sampling parameters.
    """
    # Check cache first
    cache_key = f"{prompt}-{max_tokens}-{temperature}-{top_p}-{repetition_penalty}-{num_beams}"
    if cache_key in response_cache:
        logging.info("Returning cached response.")
        return response_cache[cache_key]

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
    )
    # Run generation in a background thread; this thread consumes the token stream.
    thread = Thread(target=model.generate, kwargs=kwargs)
    thread.start()
    response = ""
    try:
        for text in streamer:
            response += text
    except Exception as e:
        logging.error(f"Error during generation: {e}")
        raise e
    thread.join()
    # Cache the response
    response_cache[cache_key] = response
    return response


class MultiRoundAgent:
    def __init__(self, model, tokenizer, prompt_templates, memory_manager: MemoryManager):
        """
        prompt_templates can be a dict (for default modes) or a list (for custom modes).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_templates = prompt_templates
        self.memory_manager = memory_manager

    def run_pipeline(self, user_prompt: str, params: Dict, show_raw: bool = False) -> Generator[str, None, None]:
        if isinstance(self.prompt_templates, dict):
            # Default fixed 4-round pipeline
            logging.info("--- Round 1 ---")
            prompt_r1 = self.prompt_templates["brainstorm"].format(user_prompt=user_prompt)
            r1 = generate_response(self.model, self.tokenizer, prompt_r1, params.get("max_new_tokens"),
                                   params.get("temp"), params.get("top_p"),
                                   params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 1 Response: {r1}")

            logging.info("--- Round 2 ---")
            prompt_r2 = self.prompt_templates["round2"].format(brainstorm_response=r1, user_prompt=user_prompt)
            r2 = generate_response(self.model, self.tokenizer, prompt_r2, params.get("max_new_tokens") + 100,
                                   params.get("temp"), params.get("top_p"),
                                   params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 2 Response: {r2}")

            logging.info("--- Round 3 ---")
            prompt_r3 = self.prompt_templates["synthesis"].format(round2_response=r2)
            input_ids_r3 = self.tokenizer.encode(prompt_r3, return_tensors="pt").to(self.model.device)
            streamer_r3 = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
            kwargs_r3 = dict(
                input_ids=input_ids_r3,
                streamer=streamer_r3,
                max_new_tokens=params.get("max_new_tokens") // 2,
                temperature=params.get("temp"),
                top_p=params.get("top_p"),
                do_sample=True,  # sample here as well so temperature/top_p take effect
                repetition_penalty=params.get("repetition_penalty"),
                num_beams=params.get("num_beams")
            )
            thread_r3 = Thread(target=self.model.generate, kwargs=kwargs_r3)
            thread_r3.start()
            r3 = ""
            try:
                for text in streamer_r3:
                    r3 += text
                    yield r3  # Progressive updates
            except Exception as e:
                logging.error(f"Error during Round 3 streaming: {e}")
                raise e
            thread_r3.join()
            self.memory_manager.store(f"Final Synthesis Response: {r3}")

            logging.info("--- Round 4 ---")
            prompt_r4 = self.prompt_templates["rationale"].format(final_response=r3)
            r4 = generate_response(self.model, self.tokenizer, prompt_r4, 300,
                                   params.get("temp"), params.get("top_p"),
                                   params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 4 Response: {r4}")

            final_output = (
                f"{r4}\n\n[Raw Outputs]\nRound 1:\n{r1}\n\nRound 2:\n{r2}\n\nRound 3:\n{r3}\n\nRound 4:\n{r4}\n"
            ) if show_raw else r4
            yield final_output

        elif isinstance(self.prompt_templates, list):
            # Custom mode: iterate over rounds.
            prev_response = ""
            full_output = ""
            total_rounds = len(self.prompt_templates)
            for idx, round_template in enumerate(self.prompt_templates):
                round_num = idx + 1
                logging.info(f"--- Custom Mode: Round {round_num} of {total_rounds} ---")
                if idx == 0:
                    prompt = round_template.format(user_prompt=user_prompt)
                else:
                    prompt = round_template.format(user_prompt=user_prompt, prev_response=prev_response)
                response = generate_response(self.model, self.tokenizer, prompt, params.get("max_new_tokens"),
                                             params.get("temp"), params.get("top_p"),
                                             params.get("repetition_penalty"), params.get("num_beams"))
                self.memory_manager.store(f"Custom Mode Round {round_num} Response: {response}")
                full_output += f"\n--- Round {round_num} ---\n{response}"
                prev_response = response
                yield full_output
        else:
            yield "Invalid prompt template format."
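
# A minimal, commented usage sketch of the pipeline above (kept as comments so importing
# this module stays side-effect free). It uses only names defined in this file; the
# example prompt text and parameter values are illustrative, not prescriptive.
#
#   model, tokenizer = get_model_and_tokenizer()
#   agent = MultiRoundAgent(model, tokenizer, default_prompts["coding"], global_memory_manager)
#   params = {"temp": 0.5, "top_p": 0.9, "max_new_tokens": 300,
#             "repetition_penalty": 1.0, "num_beams": 1}
#   for partial in agent.run_pipeline("Write a function that reverses a string", params):
#       print(partial)  # each yield is the full text so far (streaming-friendly)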
@spaces.GPU(duration=180)
def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
                          prompt_templates, domain: str, show_raw: bool, repetition_penalty: float,
                          num_beams: int) -> Generator[str, None, None]:
    model, tokenizer = get_model_and_tokenizer()
    agent = MultiRoundAgent(model, tokenizer, prompt_templates, global_memory_manager)
    params = {
        "temp": temp,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "num_beams": num_beams
    }
    return agent.run_pipeline(user_prompt, params, show_raw)


def handle_explanation_request(user_prompt: str, history: List) -> str:
    retrieved = global_memory_manager.retrieve("Round 4 Response:", top_k=3)
    explanation_prompt = "Below are previous final outputs and related context from our conversation:\n"
    if retrieved:
        for item in retrieved:
            explanation_prompt += f"- {item}\n"
    else:
        explanation_prompt += "No stored final output found.\n"
    explanation_prompt += "\nRecent related exchanges:\n"
    for chat in history:
        if ("explain" in chat[0].lower()) or (chat[1] and "explain" in chat[1].lower()):
            explanation_prompt += f"User: {chat[0]}\nAssistant: {chat[1]}\n"
    explanation_prompt += "\nBased on the above context, please provide a detailed explanation of the creative choices."
    model, tokenizer = get_model_and_tokenizer()
    explanation = generate_response(model, tokenizer, explanation_prompt, 300, 0.7, 0.9)
    return explanation


def format_history(history: List) -> List[Dict[str, str]]:
    messages = []
    for item in history:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, assistant_msg = item
            if user_msg == "__final_agent_response__":
                continue
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        elif isinstance(item, dict):
            messages.append(item)
    return messages


def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict,
                     mode: str) -> Generator[List[Dict[str, str]], None, None]:
    if "explain" in message.lower():
        explanation = handle_explanation_request(message, history)
        history = history + [[message, explanation]]
        yield format_history(history)
        return

    try:
        temp = float(param_state.get("temperature", 0.5))
        top_p = float(param_state.get("top_p", 0.9))
        max_new_tokens = int(param_state.get("max_new_tokens", 300))
        repetition_penalty = float(param_state.get("repetition_penalty", 1.0))
        num_beams = int(param_state.get("num_beams", 1))
        memory_top_k = int(param_state.get("memory_top_k", 2))
        show_raw = bool(param_state.get("show_raw_output", False))
    except Exception as e:
        logging.error(f"Parameter conversion error: {e}")
        temp, top_p, max_new_tokens, repetition_penalty, num_beams, memory_top_k, show_raw = 0.5, 0.9, 300, 1.0, 1, 2, False

    if mode in prompt_state.get("default", {}):
        prompt_templates = prompt_state["default"][mode]
    elif mode in prompt_state.get("custom", {}):
        prompt_templates = prompt_state["custom"][mode]
    else:
        detected = detect_domain(message)
        prompt_templates = prompt_state["default"].get(detected, prompt_state["default"]["coding"])
        mode = detected

    history = history + [[message, ""]]  # Show a loading status
    yield format_history(history)

    for partial_response in swarm_agent_iterative(
        user_prompt=message,
        temp=temp,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        memory_top_k=memory_top_k,
        prompt_templates=prompt_templates,
        domain=mode,
        show_raw=show_raw,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams
    ):
        history[-1][1] = partial_response
        yield format_history(history)
    yield format_history(history)


def generate_agent_audio(latest_text: str, voice_reference: str) -> str:
    """
    Generate an audio response using the cloned voice.
    If the provided voice_reference is a key in the voice bank, its stored file path is used.
    """
    if latest_text:
        if voice_reference in voice_bank:
            audio_path = clone(latest_text, voice_bank[voice_reference])
        else:
            audio_path = clone(latest_text, voice_reference)
        return audio_path
    return None


# NEW: Speech-to-Text Function using Whisper.
def transcribe_audio(audio_file: str) -> str:
    """
    Transcribe the provided audio file to text using the Whisper model.
    """
    try:
        result = whisper_model.transcribe(audio_file)
        transcription = result.get("text", "").strip()
        logging.info(f"Transcription result: {transcription}")
        return transcription
    except Exception as e:
        logging.error(f"Transcription error: {e}")
        return "Transcription failed."


# ---------------------------
# Warm-Up Model Function
# ---------------------------
def warmup_model():
    try:
        get_model_and_tokenizer()
        logging.info("Model warm-up complete.")
    except Exception as e:
        logging.error(f"Model warm-up failed: {e}")


warmup_model()

# ===========================
# Custom Gradio Theme
# ===========================
theme = gr.themes.Soft(
    primary_hue="pink",
    secondary_hue="pink",
    neutral_hue="purple",
    font=['IBM Plex Sans', 'ui-sans-serif', 'system-ui', 'sans-serif'],
).set(
    background_fill_primary='white',
    shadow_drop='rgba(0,0,0,0.05) 0px 1px 2px 0px',
    shadow_drop_lg='0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1)',
    shadow_spread='3px',
    block_background_fill='*background_fill_primary',
    block_border_width='1px',
    block_border_width_dark='1px',
    block_label_background_fill='*background_fill_primary',
    block_label_background_fill_dark='*background_fill_secondary',
    block_label_text_color='*neutral_500',
    block_label_text_color_dark='*neutral_200',
    block_label_margin='0',
    block_label_padding='*spacing_sm *spacing_lg',
    block_label_radius='calc(*radius_sm - 1px) 0 calc(*radius_sm - 1px) 0',
    block_label_text_size='*text_sm',
    block_label_text_weight='400',
    block_title_background_fill='none',
    block_title_background_fill_dark='none',
    block_title_text_color='*neutral_500',
    block_title_text_color_dark='*neutral_200',
    block_title_padding='0',
    block_title_radius='none',
    block_title_text_weight='400',
    panel_border_width='0',
    panel_border_width_dark='0',
    checkbox_background_color_selected='*color_accent',
    checkbox_background_color_selected_dark='*color_accent',
    checkbox_border_color='*neutral_300',
    checkbox_border_color_dark='*neutral_700',
    checkbox_border_color_focus='*color_accent',
    checkbox_border_color_focus_dark='*color_accent',
    checkbox_border_color_selected='*color_accent',
    checkbox_border_color_selected_dark='*color_accent',
    checkbox_border_width='*input_border_width',
    checkbox_shadow='*input_shadow',
    checkbox_label_background_fill_selected='*checkbox_label_background_fill',
    checkbox_label_background_fill_selected_dark='*checkbox_label_background_fill',
    checkbox_label_shadow='none',
    checkbox_label_text_color_selected='*checkbox_label_text_color',
    input_background_fill='*neutral_100',
    input_border_color='*border_color_primary',
    input_shadow='none',
    input_shadow_dark='none',
    input_shadow_focus='*input_shadow',
    input_shadow_focus_dark='*input_shadow',
    slider_color='*color_accent',
    slider_color_dark='*color_accent',
    button_primary_background_fill_hover='*primary_600',
    button_primary_background_fill_hover_dark='*primary_700',
    button_primary_shadow='none',
    button_primary_shadow_hover='*button_primary_shadow',
    button_primary_shadow_active='*button_primary_shadow',
    button_primary_shadow_dark='none',
    button_secondary_background_fill='*neutral_200',
    button_secondary_background_fill_hover='*neutral_300',
    button_secondary_background_fill_hover_dark='*neutral_700',
    button_secondary_text_color='black',
    button_secondary_shadow='*button_primary_shadow',
    button_secondary_shadow_hover='*button_secondary_shadow',
    button_secondary_shadow_active='*button_secondary_shadow',
    button_secondary_shadow_dark='*button_primary_shadow'
)

# ===========================
# Combined Gradio Interface
# ===========================
with gr.Blocks(theme=theme, title="Combined Voice Clone & Agent Chat") as demo:
    # Shared states for project settings, prompt configuration, and voice selection.
    param_state = gr.State({
        "temperature": 0.5,
        "top_p": 0.9,
        "max_new_tokens": 300,
        "memory_top_k": 2,
        "show_raw_output": False,
        "repetition_penalty": 1.0,
        "num_beams": 1,
        "use_cpu": False  # Toggle for device override
    })
    prompt_state = gr.State(initial_prompt_state)
    selected_voice = gr.State(value="")  # holds the currently selected voice

    # A status display to show device info.
    device_status = gr.Markdown(f"**Running on:** {device.upper()}")

    with gr.Tabs():
        # ----- Tab 1: Voice Setup -----
        with gr.Tab("Voice Setup"):
            gr.Markdown("Clone a voice and save it with a custom name. Test TTS using your cloned voices.")
            with gr.Row():
                text_input = gr.Textbox(label='Text to Clone', placeholder="Enter the text to speak...", elem_classes="full-width")
            with gr.Row():
                audio_input = gr.Audio(label='Voice Reference Audio', type='filepath')
            with gr.Row():
                clone_btn = gr.Button("Clone Voice")
            with gr.Row():
                output_audio = gr.Audio(label='Cloned Voice Output', type='filepath')
            clone_btn.click(fn=clone, inputs=[text_input, audio_input], outputs=output_audio)
            with gr.Row():
                voice_name_input = gr.Textbox(label="Voice Name", placeholder="Enter a name for this voice clone")
            with gr.Row():
                save_voice_btn = gr.Button("Save Voice")
            save_voice_btn.click(fn=save_voice, inputs=[voice_name_input, output_audio], outputs=[])
            with gr.Row():
                refresh_voice_btn_setup = gr.Button("Refresh Voice List")
                voice_dropdown_setup = gr.Dropdown(choices=get_voice_options(), label="Select Saved Voice", interactive=True)
                set_voice_btn = gr.Button("Set Selected Voice")
            refresh_voice_btn_setup.click(fn=refresh_voice_list, outputs=voice_dropdown_setup)
            set_voice_btn.click(fn=lambda x: x, inputs=[voice_dropdown_setup], outputs=selected_voice)
            gr.Markdown("(The selected voice will be used for TTS responses in Chat.)")

    gr.Markdown("Multi-round agent with prompt chaining. Ask me anything!")
    gr.Markdown("These settings affect the entire project.")
    gr.Markdown("Agent Chat using DeepSeek Agent Swarm")

if __name__ == "__main__":
    demo.launch(share=True)