# -- Hugging Face Space page chrome captured by the scrape; commented out so
# -- the file parses as Python. (author avatar, commit f99e888, raw/history
# -- links, virus-scan note, file size 9.01 kB)
import gradio as gr
import requests
import json
import os
import datetime
from requests.exceptions import RequestException
# Endpoint and credentials come from the environment (Space secrets).
API_URL = os.environ.get('API_URL')
API_KEY = os.environ.get('API_KEY')
# Headers sent with every completion request.
# NOTE(review): if API_KEY is unset this sends "Bearer None" -- confirm the
# environment variables are always configured in deployment.
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
    'Referer': os.environ.get('REFERRER_URL')
}
# debug switches
USER_LOGGING_ENABLED = False      # when True, prints prompts (except those starting with * or ")
RESPONSE_LOGGING_ENABLED = True   # when True, prints every raw SSE line from the API
# Default sampling parameters: predict() logs only values that differ from
# these, and the Gradio sliders below are initialized to match them.
DEFAULT_PARAMS = {
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 40,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "repetition_penalty": 1.1,
    "max_tokens": 512
}
def get_timestamp():
    """Return the current wall-clock time formatted as HH:MM:SS."""
    now = datetime.datetime.now()
    return f"{now:%H:%M:%S}"
def predict(message, history, system_prompt, temperature, top_p, top_k,
            frequency_penalty, presence_penalty, repetition_penalty,
            max_tokens, stop_flag):
    """Stream a chat completion from the OpenAI-compatible endpoint.

    Builds an OpenAI-style message list from ``history`` plus the new
    ``message``, POSTs it with streaming enabled, and yields the
    progressively accumulated assistant reply.

    Args:
        message: latest user message.
        history: list of [user, assistant] pairs (assistant may be None).
        system_prompt: system message prepended to the conversation.
        temperature..max_tokens: sampling parameters forwarded verbatim.
        stop_flag: single-element list; [True] aborts the stream mid-flight.

    Yields:
        The partial reply text, growing with each received chunk.
    """
    # Assemble the message list: system prompt, then alternating turns.
    history_format = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        history_format.append({"role": "user", "content": human})
        if assistant:
            history_format.append({"role": "assistant", "content": assistant})
    history_format.append({"role": "user", "content": message})

    # Messages starting with '*' or '"' are treated as private and not logged.
    if USER_LOGGING_ENABLED and not message.startswith(('*', '"')):
        print(f"<|system|> {system_prompt}")
        print(f"{get_timestamp()} <|user|> {message}")

    current_params = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens,
    }
    # Log only the parameters the user changed from their defaults.
    non_default_params = {k: v for k, v in current_params.items()
                          if v != DEFAULT_PARAMS[k]}
    if USER_LOGGING_ENABLED and non_default_params and not message.startswith(('*', '"')):
        for param, value in non_default_params.items():
            print(f"{param}={value}")

    data = {
        "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
        "messages": history_format,
        "stream": True,
        # Same keys/order as before, without duplicating the literal dict.
        **current_params,
    }

    try:
        # FIX: added a timeout so a stalled connection cannot hang the worker
        # forever (10s to connect, 300s max silence between stream chunks).
        # `json=data` lets requests serialize the payload itself.
        with requests.post(API_URL, headers=headers, json=data, stream=True,
                           timeout=(10, 300)) as response:
            # FIX: surface HTTP errors (4xx/5xx) instead of silently yielding
            # nothing; HTTPError is a RequestException, so it is caught below.
            response.raise_for_status()
            partial_message = ""
            for line in response.iter_lines():
                if stop_flag[0]:
                    response.close()
                    break
                if not line:
                    continue  # skip SSE keep-alive blank lines
                line = line.decode('utf-8')
                if RESPONSE_LOGGING_ENABLED:
                    print(f"API Response: {line}")
                if line.startswith("data: "):
                    if line.strip() == "data: [DONE]":
                        break
                    try:
                        json_data = json.loads(line[6:])
                        if 'choices' in json_data and json_data['choices']:
                            content = json_data['choices'][0]['delta'].get('content', '')
                            if content:
                                partial_message += content
                                yield partial_message
                    except json.JSONDecodeError:
                        # Ignore malformed or partial SSE payload lines.
                        continue
            # Re-emit the final text so the UI always holds the complete reply.
            if partial_message:
                yield partial_message
    except RequestException as e:
        print(f"Request error: {e}")
        yield f"An error occurred: {str(e)}"
def import_chat(custom_format_string):
    """Parse a "<|system|> ... <|user|> ... <|assistant|> ..." transcript.

    Args:
        custom_format_string: exported chat text in the tagged exchange format.

    Returns:
        (history, system_prompt) where history is a list of
        [user_message, assistant_message] pairs, or (None, None) on error.
    """
    try:
        sections = custom_format_string.split('<|')
        imported_history = []
        system_prompt = ""
        for section in sections:
            # FIX: strip only the leading role tag. The previous
            # str.replace(tag, '') removed the tag text anywhere inside the
            # message body too, corrupting messages that mention e.g. "user|>".
            if section.startswith('system|>'):
                system_prompt = section[len('system|>'):].strip()
            elif section.startswith('user|>'):
                user_message = section[len('user|>'):].strip()
                imported_history.append([user_message, None])
            elif section.startswith('assistant|>'):
                assistant_message = section[len('assistant|>'):].strip()
                if imported_history:
                    imported_history[-1][1] = assistant_message
                else:
                    # Assistant turn with no preceding user turn.
                    imported_history.append(["", assistant_message])
        return imported_history, system_prompt
    except Exception as e:
        print(f"Error importing chat: {e}")
        return None, None
def export_chat(history, system_prompt):
    """Serialize the conversation into the "<|role|> text" exchange format.

    Unanswered turns (assistant side falsy) emit only the user line.
    """
    parts = [f"<|system|> {system_prompt}\n\n"]
    if history is not None:
        for user_msg, assistant_msg in history:
            parts.append(f"<|user|> {user_msg}\n\n")
            if assistant_msg:
                parts.append(f"<|assistant|> {assistant_msg}\n\n")
    return "".join(parts)
def stop_generation_func(stop_flag):
    """Flip the shared stop flag so any in-flight generation loop bails out.

    Mutates the list in place: predict()/bot() hold a reference to the same
    object, so they observe the change on their next iteration.
    """
    stop_flag[0] = True
    return stop_flag
# UI layout. NOTE(review): indentation reconstructed from a whitespace-mangled
# source -- verify the nesting against the rendered Space.
with gr.Blocks(theme='gradio/monochrome') as demo:
    # Shared stop flag; a one-element list so callbacks can mutate it in place.
    stop_flag = gr.State([False])
    with gr.Row():
        with gr.Column(scale=2):
            # Main chat pane plus message entry and action buttons.
            chatbot = gr.Chatbot(value=[])
            msg = gr.Textbox(label="Message (70B for now. The provider might bug out at random. The space may restart frequently)")
            with gr.Row():
                clear = gr.Button("Clear")
                regenerate = gr.Button("Regenerate")
                stop_btn = gr.Button("Stop")
            # Import/export of the "<|role|> text" transcript format.
            with gr.Row():
                with gr.Column(scale=4):
                    import_textbox = gr.Textbox(label="Import textbox", lines=5)
                with gr.Column(scale=1):
                    export_button = gr.Button("Export Chat")
                    import_button = gr.Button("Import Chat")
        with gr.Column(scale=1):
            # Sampling controls; initial values mirror DEFAULT_PARAMS.
            system_prompt = gr.Textbox("", label="System Prompt", lines=5)
            temperature = gr.Slider(0, 2, value=0.8, step=0.01, label="Temperature")
            top_p = gr.Slider(0, 1, value=0.95, step=0.01, label="Top P")
            top_k = gr.Slider(1, 500, value=40, step=1, label="Top K")
            frequency_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Frequency Penalty")
            presence_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Presence Penalty")
            repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
            max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")
def user(user_message, history):
history = history or []
return "", history + [[user_message, None]]
def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
stop_flag[0] = False
history = history or []
if not history:
return history
user_message = history[-1][0]
bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag)
history[-1][1] = ""
for chunk in bot_message:
if stop_flag[0]:
history[-1][1] += " [Generation stopped]"
break
history[-1][1] = chunk
yield history
def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
if history and len(history) > 0:
last_user_message = history[-1][0]
history[-1][1] = None
for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
yield new_history
else:
yield []
def import_chat_wrapper(custom_format_string):
imported_history, imported_system_prompt = import_chat(custom_format_string)
return imported_history, imported_system_prompt
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag], chatbot
)
clear.click(lambda: None, None, chatbot, queue=False)
regenerate.click(
regenerate_response,
[chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag],
chatbot
)
import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt])
export_button.click(
export_chat,
inputs=[chatbot, system_prompt],
outputs=[import_textbox]
)
stop_btn.click(stop_generation_func, inputs=[stop_flag], outputs=[stop_flag])
if __name__ == "__main__":
    # Small queue/concurrency caps keep the shared endpoint responsive.
    demo.queue(max_size=3, default_concurrency_limit=3).launch(debug=True)