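"""DeepSeek Agent Swarm Chat.

Gradio app that lazily loads a 4-bit quantized DeepSeek-R1-Distill-Qwen-7B model and
answers each request with a four-round prompt-chaining pipeline (brainstorm, detailed
reasoning, synthesis, reflection) in coding, math, or writing mode.
"""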
import logging
import os
from threading import Thread
from typing import Dict, Generator, List, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"

# Caches for lazy, one-time loading of the model and tokenizer.
models: Dict[str, AutoModelForCausalLM] = {}
tokenizers: Dict[str, AutoTokenizer] = {}

# 4-bit NF4 quantization configuration used when loading the model, to reduce GPU memory use.
bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)


def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Lazy-load the model and tokenizer if not already loaded.

    Returns:
        Tuple[model, tokenizer]: The loaded model and tokenizer.
    """
    if "7B" not in models:
        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                quantization_config=bnb_config_4bit,
                torch_dtype=torch.bfloat16,
                device_map='auto',
                trust_remote_code=True,
            )
            model.eval()
            models["7B"] = model
            tokenizers["7B"] = tokenizer
            logging.info("Loaded 7B model on demand.")
        except Exception as e:
            logging.error(f"Failed to load model and tokenizer: {e}")
            raise
    return models["7B"], tokenizers["7B"]


default_prompts = {
    "coding": {
        "brainstorm": (
            "**Round 1: Brainstorm & Analysis**\n"
            "Please analyze the following coding challenge or question. Consider the overall problem, "
            "potential edge cases, and any assumptions you might need to make. Explain your reasoning as you think aloud.\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Strategy**\n"
            "Based on your initial analysis, please break down the problem into logical steps. "
            "Outline a plan or strategy that could be used to solve the challenge, highlighting key algorithms, structures, or design considerations.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Implementation**\n"
            "Taking into account the steps outlined previously, synthesize a coherent solution. "
            "Provide a detailed explanation of how the code addresses the problem while encouraging best practices and clear logic.\n\n"
            "**Detailed Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Output**\n"
            "Review your solution and provide a final, well-rounded response that summarizes your reasoning and the implementation strategy. "
            "Explain any key decisions made during the process and how they contribute to an effective solution.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    },
    "math": {
        "brainstorm": (
            "**Round 1: Problem Analysis & Exploration**\n"
            "Carefully analyze the mathematical problem provided. Describe the underlying concepts and any assumptions you are making. "
            "Detail your initial reasoning and potential methods to tackle the problem.\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Methodology**\n"
            "Based on your initial exploration, break down the problem into sequential steps or methodologies. "
            "Explain the reasoning behind each step and how they connect to solve the problem.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Step-by-Step Solution**\n"
            "Integrate your previous reasoning into a structured solution. Clearly explain each step of your calculation or proof, "
            "ensuring that your logical progression is easy to follow.\n\n"
            "**Detailed Methodology:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Explanation**\n"
            "Present your final solution along with a detailed explanation of the reasoning behind each step. "
            "Discuss any assumptions and insights that helped you arrive at the final answer.\n\n"
            "**Final Solution:**\n{final_response}\n"
        )
    },
    "writing": {
        "brainstorm": (
            "**Round 1: Creative Exploration & Conceptualization**\n"
            "Read the following writing prompt and explore its themes, tone, and potential narrative directions. "
            "Outline your initial thoughts and reasoning behind various creative choices.\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Outline & Narrative Structure**\n"
            "Based on your brainstorming, create a detailed outline that organizes the narrative or essay. "
            "Explain the reasoning behind your structure, the flow of ideas, and how you plan to incorporate creative elements.\n\n"
            "**Initial Brainstorming:**\n{brainstorm_response}\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Draft Synthesis & Refinement**\n"
            "Integrate your outline and creative ideas into a coherent draft. Provide a well-rounded narrative that is both engaging and logically structured. "
            "Explain your thought process as you refine the narrative.\n\n"
            "**Outline & Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Editing**\n"
            "Review your draft and provide a final version that reflects thoughtful editing and creative reasoning. "
            "Explain the choices made in refining the text, from structure to stylistic decisions.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    }
}


def detect_domain(user_prompt: str) -> str:
    """
    Detect the domain based on keywords.

    Args:
        user_prompt (str): The user query.

    Returns:
        str: One of 'math', 'writing', or 'coding' (defaulting to coding).
    """
    prompt_lower = user_prompt.lower()
    math_keywords = ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"]
    writing_keywords = ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"]
    coding_keywords = ["code", "program", "debug", "compile", "algorithm", "function"]

    if any(kw in prompt_lower for kw in math_keywords):
        logging.info("Domain detected as: math")
        return "math"
    elif any(kw in prompt_lower for kw in writing_keywords):
        logging.info("Domain detected as: writing")
        return "writing"
    elif any(kw in prompt_lower for kw in coding_keywords):
        logging.info("Domain detected as: coding")
        return "coding"
    else:
        logging.info("No specific domain detected; defaulting to coding")
        return "coding"


class MemoryManager:
    """Encapsulate shared memory for storing and retrieving conversation items."""

    def __init__(self) -> None:
        self.shared_memory: List[str] = []

    def store(self, item: str) -> None:
        """Store a memory item and log an excerpt."""
        self.shared_memory.append(item)
        logging.info(f"[Memory Stored]: {item[:50]}...")

    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        """Retrieve the most recent memory items containing the query text."""
        query_lower = query.lower()
        relevant = [item for item in self.shared_memory if query_lower in item.lower()]
        if not relevant:
            logging.info("[Memory Retrieval]: No relevant memories found.")
        else:
            logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
        return relevant[-top_k:]


global_memory_manager = MemoryManager()
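# Shared across requests: MultiRoundAgent.run_pipeline stores each round's output here,
# and handle_explanation_request later retrieves the stored "Round 4 Response:" entries.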


def generate_response(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float) -> str:
    """Generate a response for a given prompt, streaming tokens from a background thread."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    def _generate() -> None:
        # Grad mode is thread-local, so gradients must be disabled inside the worker
        # thread rather than around thread.start().
        with torch.no_grad():
            model.generate(**kwargs)

    thread = Thread(target=_generate)
    thread.start()
    response = ""
    try:
        # Consume the streamer on this thread while generation runs in the background.
        for text in streamer:
            response += text
    except Exception as e:
        logging.error(f"Error during generation: {e}")
        raise
    thread.join()
    return response


class MultiRoundAgent:
    """
    Encapsulate the multi-round prompt chaining and response generation.
    This class runs a 4-round pipeline based on the given preset.
    """

    def __init__(self, model, tokenizer, prompt_templates: Dict[str, str], memory_manager: MemoryManager):
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_templates = prompt_templates
        self.memory_manager = memory_manager

    def run_pipeline(self, user_prompt: str, params: Dict, show_raw: bool = False) -> Generator[str, None, None]:
        # Round 1: brainstorm and analyze the request.
        logging.info("--- Round 1 ---")
        prompt_r1 = self.prompt_templates["brainstorm"].format(user_prompt=user_prompt)
        r1 = generate_response(self.model, self.tokenizer, prompt_r1, params.get("max_new_tokens"), params.get("temp"), params.get("top_p"))
        self.memory_manager.store(f"Round 1 Response: {r1}")

        # Round 2: detailed reasoning, conditioned on the Round 1 output.
        logging.info("--- Round 2 ---")
        prompt_r2 = self.prompt_templates["round2"].format(brainstorm_response=r1, user_prompt=user_prompt)
        r2 = generate_response(self.model, self.tokenizer, prompt_r2, params.get("max_new_tokens") + 100, params.get("temp"), params.get("top_p"))
        self.memory_manager.store(f"Round 2 Response: {r2}")

        # Round 3: synthesis, streamed to the caller as it is generated.
        logging.info("--- Round 3 ---")
        prompt_r3 = self.prompt_templates["synthesis"].format(round2_response=r2)
        input_ids_r3 = self.tokenizer.encode(prompt_r3, return_tensors="pt").to(self.model.device)
        streamer_r3 = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        kwargs_r3 = dict(
            input_ids=input_ids_r3,
            streamer=streamer_r3,
            max_new_tokens=params.get("max_new_tokens") // 2,
            temperature=params.get("temp"),
            top_p=params.get("top_p"),
            do_sample=True,  # so the temperature/top_p settings take effect, matching generate_response
        )

        def _generate_round3() -> None:
            # Grad mode is thread-local, so disable it inside the worker thread.
            with torch.no_grad():
                self.model.generate(**kwargs_r3)

        thread_r3 = Thread(target=_generate_round3)
        thread_r3.start()
        r3 = ""
        try:
            for text in streamer_r3:
                r3 += text
                yield r3
        except Exception as e:
            logging.error(f"Error during Round 3 streaming: {e}")
            raise
        thread_r3.join()
        self.memory_manager.store(f"Final Synthesis Response: {r3}")

        # Round 4: reflection and final rationale.
        logging.info("--- Round 4 ---")
        prompt_r4 = self.prompt_templates["rationale"].format(final_response=r3)
        r4 = generate_response(self.model, self.tokenizer, prompt_r4, 300, params.get("temp"), params.get("top_p"))
        self.memory_manager.store(f"Round 4 Response: {r4}")

        if show_raw:
            final_output = (
                f"{r4}\n\n[Raw Outputs]\n"
                f"Round 1:\n{r1}\n\n"
                f"Round 2:\n{r2}\n\n"
                f"Round 3:\n{r3}\n\n"
                f"Round 4:\n{r4}\n"
            )
        else:
            final_output = r4

        yield final_output


@spaces.GPU(duration=180)
def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
                          prompt_templates: Dict[str, str], domain: str, show_raw: bool) -> Generator[str, None, None]:
    """
    Wraps the multi-round agent functionality. Depending on the detected or selected domain,
    it runs the 4-round pipeline.

    Note: memory_top_k and domain are accepted for interface compatibility but are not used
    directly here; the prompt_templates passed in already reflect the chosen domain.
    """
    model, tokenizer = get_model_and_tokenizer()
    agent = MultiRoundAgent(model, tokenizer, prompt_templates, global_memory_manager)
    params = {"temp": temp, "top_p": top_p, "max_new_tokens": max_new_tokens}
    return agent.run_pipeline(user_prompt, params, show_raw)


def handle_explanation_request(user_prompt: str, history: List) -> str:
    """
    Retrieve stored rationale and related context from the conversation history,
    then generate an explanation.
    """
    retrieved = global_memory_manager.retrieve("Round 4 Response:", top_k=3)
    explanation_prompt = "Below are previous final outputs and related context from our conversation:\n"
    if retrieved:
        for item in retrieved:
            explanation_prompt += f"- {item}\n"
    else:
        explanation_prompt += "No stored final output found.\n"

    explanation_prompt += "\nRecent related exchanges:\n"
    for chat in history:
        # History items may be [user, assistant] pairs or {"role", "content"} dicts,
        # depending on the chatbot's message format.
        if isinstance(chat, dict):
            content = str(chat.get("content", "") or "")
            if "explain" in content.lower():
                explanation_prompt += f"{chat.get('role', 'user').capitalize()}: {content}\n"
        elif isinstance(chat, (list, tuple)) and len(chat) == 2:
            user_msg, assistant_msg = chat
            if "explain" in str(user_msg).lower() or (assistant_msg and "explain" in str(assistant_msg).lower()):
                explanation_prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"

    explanation_prompt += "\nBased on the above context, please provide a detailed explanation of the creative choices."
    model, tokenizer = get_model_and_tokenizer()
    explanation = generate_response(model, tokenizer, explanation_prompt, 300, 0.7, 0.9)
    return explanation


def format_history(history: List) -> List[Dict[str, str]]:
    """
    Convert history entries ([user, assistant] pairs or message dicts) into a flat
    list of {"role", "content"} message dictionaries.
    """
    messages = []
    for item in history:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, assistant_msg = item
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        elif isinstance(item, dict):
            messages.append(item)
    return messages
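
# For illustration (not executed): format_history([["hi", "hello"]]) returns
# [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}];
# dict entries are passed through unchanged.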


def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict, mode: str) -> Generator[List[Dict[str, str]], None, None]:
    """
    Called by Gradio's ChatInterface. Uses the current generation parameters and preset
    prompt templates. If the user asks for an explanation, the request is routed accordingly.
    The selected mode (coding, math, or writing) overrides automatic domain detection.
    """
    if "explain" in message.lower():
        explanation = handle_explanation_request(message, history)
        history = history + [[message, explanation]]
        yield format_history(history)
        return

    try:
        temp = float(param_state.get("temperature", 0.5))
        top_p = float(param_state.get("top_p", 0.9))
        max_new_tokens = int(param_state.get("max_new_tokens", 300))
        memory_top_k = int(param_state.get("memory_top_k", 2))
        show_raw = bool(param_state.get("show_raw_output", False))
    except Exception as e:
        logging.error(f"Parameter conversion error: {e}")
        temp, top_p, max_new_tokens, memory_top_k, show_raw = 0.5, 0.9, 300, 2, False

    domain = mode if mode in default_prompts else detect_domain(message)
    prompt_templates = prompt_state.get(domain, default_prompts.get(domain, default_prompts["coding"]))

    # Append a placeholder turn that is filled in as the pipeline streams partial output.
    history = history + [[message, ""]]
    for partial_response in swarm_agent_iterative(
        user_prompt=message,
        temp=temp,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        memory_top_k=memory_top_k,
        prompt_templates=prompt_templates,
        domain=domain,
        show_raw=show_raw,
    ):
        history[-1][1] = partial_response
        yield format_history(history)


ui_description = '''
<div>
  <h1 style="text-align: center;">DeepSeek Agent Swarm Chat</h1>
  <p style="text-align: center;">
    Multi-round agent with 4-round prompt chaining, supporting three modes:
    <br>- Coding
    <br>- Math
    <br>- Writing
  </p>
</div>
'''

ui_license = """
<p/>

---
"""

ui_placeholder = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
  <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">DeepSeek Agent Swarm</h1>
  <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""

css = """
h1 {
  text-align: center;
  display: block;
}
#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}
"""


with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
    gr.Markdown(ui_description)

    param_state = gr.State({
        "temperature": 0.5,
        "top_p": 0.9,
        "max_new_tokens": 300,
        "memory_top_k": 2,
        "show_raw_output": False,
    })
    prompt_state = gr.State({
        "coding": default_prompts["coding"],
        "math": default_prompts["math"],
        "writing": default_prompts["writing"],
    })

    with gr.Tabs():
        with gr.Tab("Chat"):
            mode_selector = gr.Radio(choices=["coding", "math", "writing"], value="coding", label="Select Mode")
            chatbot = gr.Chatbot(height=450, placeholder=ui_placeholder, label="Agent Swarm Output", type="messages")
            gr.ChatInterface(
                fn=gradio_interface,
                chatbot=chatbot,
                additional_inputs=[param_state, prompt_state, mode_selector],
                examples=[
                    ['How can we build a robust web service that scales efficiently under load?'],
                    ['Solve the integral of x^2 from 0 to 1.'],
                    ['Write a short story about a mysterious writer in a busy city.'],
                    ['Create a creative and reflective solution for a coding challenge.'],
                ],
                cache_examples=False,
                type="messages",
            )
        with gr.Tab("Parameters"):
            gr.Markdown("### Generation Parameters")
            temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
            top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
            max_tokens_num = gr.Number(value=300, label="Max new tokens", precision=0)
            memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
            show_raw_checkbox = gr.Checkbox(value=False, label="Show Raw Output")
            save_params_btn = gr.Button("Save Parameters")
            save_params_btn.click(
                lambda t, p, m, k, s: {
                    "temperature": t,
                    "top_p": p,
                    "max_new_tokens": m,
                    "memory_top_k": k,
                    "show_raw_output": s,
                },
                inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider, show_raw_checkbox],
                outputs=param_state,
            )
        with gr.Tab("Prompt Config"):
            gr.Markdown("### Configure Prompt Templates for Each Preset")
            with gr.Tabs():
                with gr.Tab("Coding"):
                    prompt_brainstorm_box_code = gr.Textbox(
                        value=default_prompts["coding"]["brainstorm"],
                        label="Brainstorm Prompt (Coding)",
                        lines=8,
                    )
                    prompt_round2_box_code = gr.Textbox(
                        value=default_prompts["coding"]["round2"],
                        label="Round 2 Prompt (Coding)",
                        lines=8,
                    )
                    prompt_synthesis_box_code = gr.Textbox(
                        value=default_prompts["coding"]["synthesis"],
                        label="Synthesis Prompt (Coding)",
                        lines=8,
                    )
                    prompt_rationale_box_code = gr.Textbox(
                        value=default_prompts["coding"]["rationale"],
                        label="Rationale Prompt (Coding)",
                        lines=8,
                    )
                with gr.Tab("Math"):
                    prompt_brainstorm_box_math = gr.Textbox(
                        value=default_prompts["math"]["brainstorm"],
                        label="Brainstorm Prompt (Math)",
                        lines=8,
                    )
                    prompt_round2_box_math = gr.Textbox(
                        value=default_prompts["math"]["round2"],
                        label="Round 2 Prompt (Math)",
                        lines=8,
                    )
                    prompt_synthesis_box_math = gr.Textbox(
                        value=default_prompts["math"]["synthesis"],
                        label="Synthesis Prompt (Math)",
                        lines=8,
                    )
                    prompt_rationale_box_math = gr.Textbox(
                        value=default_prompts["math"]["rationale"],
                        label="Rationale Prompt (Math)",
                        lines=8,
                    )
                with gr.Tab("Writing"):
                    prompt_brainstorm_box_writing = gr.Textbox(
                        value=default_prompts["writing"]["brainstorm"],
                        label="Brainstorm Prompt (Writing)",
                        lines=8,
                    )
                    prompt_round2_box_writing = gr.Textbox(
                        value=default_prompts["writing"]["round2"],
                        label="Round 2 Prompt (Writing)",
                        lines=8,
                    )
                    prompt_synthesis_box_writing = gr.Textbox(
                        value=default_prompts["writing"]["synthesis"],
                        label="Synthesis Prompt (Writing)",
                        lines=8,
                    )
                    prompt_rationale_box_writing = gr.Textbox(
                        value=default_prompts["writing"]["rationale"],
                        label="Rationale Prompt (Writing)",
                        lines=8,
                    )
            save_prompts_btn = gr.Button("Save Prompts")

            def save_prompts(code_brain, code_r2, code_syn, code_rat,
                             math_brain, math_r2, math_syn, math_rat,
                             writing_brain, writing_r2, writing_syn, writing_rat):
                return {
                    "coding": {
                        "brainstorm": code_brain,
                        "round2": code_r2,
                        "synthesis": code_syn,
                        "rationale": code_rat,
                    },
                    "math": {
                        "brainstorm": math_brain,
                        "round2": math_r2,
                        "synthesis": math_syn,
                        "rationale": math_rat,
                    },
                    "writing": {
                        "brainstorm": writing_brain,
                        "round2": writing_r2,
                        "synthesis": writing_syn,
                        "rationale": writing_rat,
                    },
                }

            save_prompts_btn.click(
                save_prompts,
                inputs=[prompt_brainstorm_box_code, prompt_round2_box_code, prompt_synthesis_box_code, prompt_rationale_box_code,
                        prompt_brainstorm_box_math, prompt_round2_box_math, prompt_synthesis_box_math, prompt_rationale_box_math,
                        prompt_brainstorm_box_writing, prompt_round2_box_writing, prompt_synthesis_box_writing, prompt_rationale_box_writing],
                outputs=prompt_state,
            )
    gr.Markdown(ui_license)


if __name__ == "__main__":
    demo.launch(share=True)
|