"""Gradio chat UI for a streaming OpenAI-compatible chat-completions endpoint.

Streams assistant tokens over SSE with aiohttp, supports stop/regenerate,
and imports/exports conversations in a simple
``<|system|> / <|user|> / <|assistant|>`` plain-text format.
"""

import asyncio
import datetime
import json
import os

import aiohttp  # noqa: F401  (kept for parity with original imports)
from aiohttp import ClientSession

import gradio as gr

API_URL = os.environ.get('API_URL')
API_KEY = os.environ.get('API_KEY')

headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}

# Defaults mirrored by the UI sliders below; used so that only values the
# user actually changed get logged alongside a request.
DEFAULT_PARAMS = {
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 40,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "repetition_penalty": 1.1,
    "max_tokens": 512,
}

# id(task) -> task for generations in flight, so Stop/Regenerate can cancel.
active_tasks = {}


def get_timestamp():
    """Return the current local time formatted as HH:MM:SS."""
    return datetime.datetime.now().strftime("%H:%M:%S")


async def predict(message, history, system_prompt, temperature, top_p, top_k,
                  frequency_penalty, presence_penalty, repetition_penalty,
                  max_tokens):
    """Stream a chat completion for *message*, yielding cumulative text.

    Args:
        message: the newest user turn.
        history: list of (user, assistant) pairs preceding *message*;
            an assistant entry of None/"" is skipped.
        system_prompt: system message placed first in the request.
        temperature..max_tokens: sampling parameters forwarded verbatim.

    Yields:
        The assistant reply accumulated so far (grows with each SSE chunk).
        On a request failure, yields a single "An error occurred: ..." string.
    """
    history_format = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        history_format.append({"role": "user", "content": human})
        if assistant:
            history_format.append({"role": "assistant", "content": assistant})
    history_format.append({"role": "user", "content": message})

    # Messages starting with '*' or '"' are treated as private: no console log.
    loggable = not message.startswith(('*', '"'))
    if loggable:
        print(f"<|system|> {system_prompt}")
        print(f"{get_timestamp()} <|user|> {message}")

    current_params = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens,
    }

    # Log only the parameters that differ from the slider defaults.
    non_default_params = {k: v for k, v in current_params.items()
                          if v != DEFAULT_PARAMS[k]}
    if non_default_params and loggable:
        for param, value in non_default_params.items():
            print(f"{param}={value}")

    data = {
        "model": "meta-llama/Meta-Llama-3.1-405B-Instruct",
        "messages": history_format,
        "stream": True,
        **current_params,
    }

    try:
        async with ClientSession() as session:
            async with session.post(API_URL, headers=headers,
                                    json=data) as response:
                partial_message = ""
                # NOTE: no explicit cancellation poll here — the original's
                # `asyncio.current_task().cancelled()` check was dead code
                # (always False for the running task). Cancellation raises
                # CancelledError at this await point, which deliberately is
                # NOT caught by the `except Exception` below (it derives
                # from BaseException) and so propagates to the caller.
                async for line in response.content:
                    if not line:
                        continue
                    line = line.decode('utf-8')
                    if not line.startswith("data: "):
                        continue
                    if line.strip() == "data: [DONE]":
                        break
                    try:
                        json_data = json.loads(line[6:])
                    except json.JSONDecodeError:
                        # Partial/keep-alive SSE line; wait for the next one.
                        continue
                    if 'choices' in json_data and json_data['choices']:
                        content = json_data['choices'][0]['delta'].get(
                            'content', '')
                        if content:
                            partial_message += content
                            yield partial_message
                if partial_message:
                    yield partial_message
    except Exception as e:
        print(f"Request error: {e}")
        yield f"An error occurred: {str(e)}"


def import_chat(custom_format_string):
    """Parse a ``<|system|>/<|user|>/<|assistant|>`` export back into state.

    Returns:
        (history, system_prompt) where history is a list of
        [user, assistant-or-None] pairs, or (None, None) on failure.
    """
    try:
        sections = custom_format_string.split('<|')
        imported_history = []
        system_prompt = ""
        for section in sections:
            if section.startswith('system|>'):
                system_prompt = section.replace('system|>', '').strip()
            elif section.startswith('user|>'):
                user_message = section.replace('user|>', '').strip()
                imported_history.append([user_message, None])
            elif section.startswith('assistant|>'):
                assistant_message = section.replace('assistant|>', '').strip()
                if imported_history:
                    imported_history[-1][1] = assistant_message
                else:
                    # Assistant turn with no preceding user turn: pair it
                    # with an empty user message so nothing is dropped.
                    imported_history.append(["", assistant_message])
        return imported_history, system_prompt
    except Exception as e:
        print(f"Error importing chat: {e}")
        return None, None


def export_chat(history, system_prompt):
    """Serialize the chat into the ``<|role|>`` text format used by import."""
    export_data = f"<|system|> {system_prompt}\n\n"
    if history is not None:
        for user_msg, assistant_msg in history:
            export_data += f"<|user|> {user_msg}\n\n"
            if assistant_msg:
                export_data += f"<|assistant|> {assistant_msg}\n\n"
    return export_data


def sanitize_chatbot_history(history):
    """Ensure each entry in the chatbot history is a tuple of two items."""
    return [tuple(entry[:2]) for entry in history]


with gr.Blocks(theme='gradio/monochrome') as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(value=[])
            msg = gr.Textbox(label="Message")
            with gr.Row():
                clear = gr.Button("Clear")
                regenerate = gr.Button("Regenerate")
                stop_btn = gr.Button("Stop")
            with gr.Row():
                with gr.Column(scale=4):
                    import_textbox = gr.Textbox(label="Import textbox",
                                                lines=5)
                with gr.Column(scale=1):
                    export_button = gr.Button("Export Chat")
                    import_button = gr.Button("Import Chat")
        with gr.Column(scale=1):
            system_prompt = gr.Textbox("", label="System Prompt", lines=5)
            temperature = gr.Slider(0, 2, value=0.8, step=0.01,
                                    label="Temperature")
            top_p = gr.Slider(0, 1, value=0.95, step=0.01, label="Top P")
            top_k = gr.Slider(1, 500, value=40, step=1, label="Top K")
            frequency_penalty = gr.Slider(-2, 2, value=0, step=0.1,
                                          label="Frequency Penalty")
            presence_penalty = gr.Slider(-2, 2, value=0, step=0.1,
                                         label="Presence Penalty")
            repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01,
                                           label="Repetition Penalty")
            max_tokens = gr.Slider(1, 4096, value=512, step=1,
                                   label="Max Output (max_tokens)")

    async def user(user_message, history):
        """Append the user's turn (assistant slot None) and clear the box."""
        history = sanitize_chatbot_history(history or [])
        return "", history + [(user_message, None)]

    async def bot(history, system_prompt, temperature, top_p, top_k,
                  frequency_penalty, presence_penalty, repetition_penalty,
                  max_tokens):
        """Stream the assistant reply into the last history entry.

        Registers itself in ``active_tasks`` so Stop/Regenerate can cancel
        it; if cancelled before any text arrived, marks the turn as
        "[Generation stopped]".
        """
        history = sanitize_chatbot_history(history or [])
        if not history:
            yield history
            return

        user_message = history[-1][0]
        bot_message = predict(user_message, history[:-1], system_prompt,
                              temperature, top_p, top_k, frequency_penalty,
                              presence_penalty, repetition_penalty,
                              max_tokens)
        history[-1] = (history[-1][0], "")  # keep entries as tuples

        task_id = id(asyncio.current_task())
        active_tasks[task_id] = asyncio.current_task()
        try:
            async for chunk in bot_message:
                # Stop button removes us from active_tasks as its signal.
                if task_id not in active_tasks:
                    break
                history[-1] = (history[-1][0], chunk)
                yield history
        except asyncio.CancelledError:
            pass
        finally:
            active_tasks.pop(task_id, None)
            if history[-1][1] == "":
                history[-1] = (history[-1][0], " [Generation stopped]")
            yield history

    async def regenerate_response(history, system_prompt, temperature, top_p,
                                  top_k, frequency_penalty, presence_penalty,
                                  repetition_penalty, max_tokens):
        """Cancel any in-flight generation and redo the last assistant turn."""
        for task in list(active_tasks.values()):
            task.cancel()
        # Give the cancelled task a moment to unwind before restarting.
        await asyncio.sleep(0.1)

        history = sanitize_chatbot_history(history or [])
        if history:
            history[-1] = (history[-1][0], None)  # reset last response
            async for new_history in bot(history, system_prompt, temperature,
                                         top_p, top_k, frequency_penalty,
                                         presence_penalty,
                                         repetition_penalty, max_tokens):
                yield new_history
        else:
            yield []

    def import_chat_wrapper(custom_format_string):
        """Parse the import textbox; keep UI state valid on parse failure."""
        imported_history, imported_system_prompt = import_chat(
            custom_format_string)
        # BUGFIX: import_chat returns (None, None) on error; the original
        # passed None into sanitize_chatbot_history and raised TypeError.
        if imported_history is None:
            return [], ""
        return sanitize_chatbot_history(imported_history), \
            imported_system_prompt

    submit_event = msg.submit(
        user, [msg, chatbot], [msg, chatbot], queue=False,
    ).then(
        bot,
        [chatbot, system_prompt, temperature, top_p, top_k,
         frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
        chatbot,
        concurrency_limit=5,
    )

    clear.click(lambda: [], None, chatbot, queue=False)

    regenerate_event = regenerate.click(
        regenerate_response,
        [chatbot, system_prompt, temperature, top_p, top_k,
         frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
        chatbot,
        concurrency_limit=5,
    )

    import_button.click(import_chat_wrapper, inputs=[import_textbox],
                        outputs=[chatbot, system_prompt], concurrency_limit=5)

    export_button.click(export_chat, inputs=[chatbot, system_prompt],
                        outputs=[import_textbox], concurrency_limit=5)

    stop_btn.click(
        lambda: [task.cancel() for task in list(active_tasks.values())],
        None, None,
        cancels=[submit_event, regenerate_event],
        queue=False,
    )

if __name__ == "__main__":
    demo.launch(debug=True, max_threads=20)