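"""Gradio chat frontend that streams completions from an OpenAI-compatible
chat endpoint, with conversation import/export and cancellable generation."""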
import gradio as gr
import json
import os
import datetime
import asyncio
from aiohttp import ClientSession

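# Endpoint and credentials are read from the environment and must be set before launch.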
API_URL = os.environ.get('API_URL')
API_KEY = os.environ.get('API_KEY')
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}

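# Defaults used to detect (and log) which sampling parameters the user has overridden.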
DEFAULT_PARAMS = {
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 40,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "repetition_penalty": 1.1,
    "max_tokens": 512
}

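# In-flight generation tasks, keyed by task id, so Stop/Regenerate can cancel them.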
active_tasks = {}
def get_timestamp():
    return datetime.datetime.now().strftime("%H:%M:%S")

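# Build the message list, POST it to the streaming endpoint, and yield the
# accumulated assistant reply as chunks arrive.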
async def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
    history_format = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        history_format.append({"role": "user", "content": human})
        if assistant:
            history_format.append({"role": "assistant", "content": assistant})
    history_format.append({"role": "user", "content": message})

    # Log the prompt unless the message starts with a roleplay marker.
    if not message.startswith(('*', '"')):
        print(f"<|system|> {system_prompt}")
        print(f"{get_timestamp()} <|user|> {message}")

    current_params = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens
    }
    # Log only the parameters that differ from the defaults.
    non_default_params = {k: v for k, v in current_params.items() if v != DEFAULT_PARAMS[k]}
    if non_default_params and not message.startswith(('*', '"')):
        for param, value in non_default_params.items():
            print(f"{param}={value}")
    data = {
        "model": "meta-llama/Meta-Llama-3.1-405B-Instruct",
        "messages": history_format,
        "stream": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens
    }

    try:
        async with ClientSession() as session:
            async with session.post(API_URL, headers=headers, json=data) as response:
                partial_message = ""
                # The endpoint streams OpenAI-style server-sent events: each
                # payload line is "data: {json}", terminated by "data: [DONE]".
                async for line in response.content:
                    if asyncio.current_task().cancelled():
                        break
                    if line:
                        line = line.decode('utf-8')
                        if line.startswith("data: "):
                            if line.strip() == "data: [DONE]":
                                break
                            try:
                                json_data = json.loads(line[6:])
                                if 'choices' in json_data and json_data['choices']:
                                    content = json_data['choices'][0]['delta'].get('content', '')
                                    if content:
                                        partial_message += content
                                        yield partial_message
                            except json.JSONDecodeError:
                                continue
                if partial_message:
                    yield partial_message
    except Exception as e:
        print(f"Request error: {e}")
        yield f"An error occurred: {str(e)}"

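# Parse a transcript in the <|system|>/<|user|>/<|assistant|> export format
# back into (history, system_prompt); returns (None, None) on failure.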
def import_chat(custom_format_string):
    try:
        sections = custom_format_string.split('<|')
        imported_history = []
        system_prompt = ""
        for section in sections:
            if section.startswith('system|>'):
                system_prompt = section.replace('system|>', '').strip()
            elif section.startswith('user|>'):
                user_message = section.replace('user|>', '').strip()
                imported_history.append([user_message, None])
            elif section.startswith('assistant|>'):
                assistant_message = section.replace('assistant|>', '').strip()
                if imported_history:
                    imported_history[-1][1] = assistant_message
                else:
                    imported_history.append(["", assistant_message])
        return imported_history, system_prompt
    except Exception as e:
        print(f"Error importing chat: {e}")
        return None, None

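# Serialize the current conversation into the same <|role|> export format.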
def export_chat(history, system_prompt):
    export_data = f"<|system|> {system_prompt}\n\n"
    if history is not None:
        for user_msg, assistant_msg in history:
            export_data += f"<|user|> {user_msg}\n\n"
            if assistant_msg:
                export_data += f"<|assistant|> {assistant_msg}\n\n"
    return export_data

def sanitize_chatbot_history(history):
    """Ensure each entry in the chatbot history is a tuple of two items."""
    return [tuple(entry[:2]) for entry in history]

with gr.Blocks(theme='gradio/monochrome') as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(value=[])
            msg = gr.Textbox(label="Message")
            with gr.Row():
                clear = gr.Button("Clear")
                regenerate = gr.Button("Regenerate")
                stop_btn = gr.Button("Stop")
            with gr.Row():
                with gr.Column(scale=4):
                    import_textbox = gr.Textbox(label="Import textbox", lines=5)
                with gr.Column(scale=1):
                    export_button = gr.Button("Export Chat")
                    import_button = gr.Button("Import Chat")
        with gr.Column(scale=1):
            system_prompt = gr.Textbox("", label="System Prompt", lines=5)
            temperature = gr.Slider(0, 2, value=0.8, step=0.01, label="Temperature")
            top_p = gr.Slider(0, 1, value=0.95, step=0.01, label="Top P")
            top_k = gr.Slider(1, 500, value=40, step=1, label="Top K")
            frequency_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Frequency Penalty")
            presence_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Presence Penalty")
            repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
            max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")
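    # Append the submitted message to the history and clear the input box.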
    async def user(user_message, history):
        history = sanitize_chatbot_history(history or [])
        return "", history + [(user_message, None)]

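    # Stream the model's reply into the last history entry, registering the
    # task so Stop/Regenerate can cancel it.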
    async def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
        history = sanitize_chatbot_history(history or [])
        if not history:
            yield history
            return
        user_message = history[-1][0]
        bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
        history[-1] = (history[-1][0], "")  # Ensure it's a tuple
        task_id = id(asyncio.current_task())
        active_tasks[task_id] = asyncio.current_task()
        try:
            async for chunk in bot_message:
                if task_id not in active_tasks:
                    break
                history[-1] = (history[-1][0], chunk)  # Update as a tuple
                yield history
        except asyncio.CancelledError:
            pass
        finally:
            if task_id in active_tasks:
                del active_tasks[task_id]
            if history[-1][1] == "":
                history[-1] = (history[-1][0], " [Generation stopped]")
            yield history

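    # Re-run the last user message with the current settings, cancelling any
    # generation that is still in flight.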
    async def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
        # Cancel any ongoing generation
        for task in list(active_tasks.values()):
            task.cancel()
        # Wait for a short time to ensure cancellation is processed
        await asyncio.sleep(0.1)
        history = sanitize_chatbot_history(history or [])
        if history:
            history[-1] = (history[-1][0], None)  # Reset last response
            async for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
                yield new_history
        else:
            yield []

    def import_chat_wrapper(custom_format_string):
        imported_history, imported_system_prompt = import_chat(custom_format_string)
        # import_chat returns (None, None) on failure; fall back to an empty chat.
        return sanitize_chatbot_history(imported_history or []), imported_system_prompt

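    # Wire up events. Stop cancels both the queued Gradio events and the
    # underlying asyncio tasks tracked in active_tasks.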
    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens], chatbot,
        concurrency_limit=5
    )

    clear.click(lambda: [], None, chatbot, queue=False)

    regenerate_event = regenerate.click(
        regenerate_response,
        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
        chatbot,
        concurrency_limit=5
    )

    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], concurrency_limit=5)

    export_button.click(
        export_chat,
        inputs=[chatbot, system_prompt],
        outputs=[import_textbox],
        concurrency_limit=5
    )

    stop_btn.click(
        lambda: [task.cancel() for task in list(active_tasks.values())],
        None,
        None,
        cancels=[submit_event, regenerate_event],
        queue=False
    )

if __name__ == "__main__":
    demo.launch(debug=True, max_threads=20)