import gradio as gr
import requests
import json
import os
import datetime
from requests.exceptions import RequestException
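# Gradio chat front-end that streams replies from an OpenAI-compatible
# chat-completions endpoint. Configuration comes from the environment,
# e.g. (hypothetical values):
#   API_URL=https://example.com/v1/chat/completions
#   API_KEY=sk-...
#   REFERRER_URL=https://huggingface.co/spaces/your-space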
API_URL = os.environ.get('API_URL')
API_KEY = os.environ.get('API_KEY')
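# "Referer" (single R) below is the HTTP spec's intentional spelling of
# the header; its value comes from the conventionally spelled REFERRER_URL.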
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
    "Referer": os.environ.get('REFERRER_URL')
}
# Temporary debug switches, used to trace a provider-side issue where
# streaming responses abruptly terminate partway through for no
# apparent reason.
USER_LOGGING_ENABLED = False
RESPONSE_LOGGING_ENABLED = True
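# Mirrors the UI slider defaults; used below to log only the sampling
# parameters a user has changed.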
DEFAULT_PARAMS = {
"temperature": 0.8,
"top_p": 0.95,
"top_k": 40,
"frequency_penalty": 0,
"presence_penalty": 0,
"repetition_penalty": 1.1,
"max_tokens": 512
}
def get_timestamp():
return datetime.datetime.now().strftime("%H:%M:%S")
def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
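    """Stream an assistant reply for `message` given the chat `history`.

    Builds an OpenAI-style message list (system prompt, prior turns, new
    user message), POSTs it with stream=True, and yields the accumulated
    reply as each server-sent "data: {...}" chunk arrives. `stop_flag` is
    a shared one-element list checked between chunks so the Stop button
    can abort mid-stream.
    """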
history_format = [{"role": "system", "content": system_prompt}]
for human, assistant in history:
history_format.append({"role": "user", "content": human})
if assistant:
history_format.append({"role": "assistant", "content": assistant})
history_format.append({"role": "user", "content": message})
if USER_LOGGING_ENABLED and not message.startswith(('*', '"')):
print(f"<|system|> {system_prompt}")
print(f"{get_timestamp()} <|user|> {message}")
current_params = {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"repetition_penalty": repetition_penalty,
"max_tokens": max_tokens
}
non_default_params = {k: v for k, v in current_params.items() if v != DEFAULT_PARAMS[k]}
if USER_LOGGING_ENABLED and non_default_params and not message.startswith(('*', '"')):
for param, value in non_default_params.items():
print(f"{param}={value}")
data = {
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
"messages": history_format,
"stream": True,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"repetition_penalty": repetition_penalty,
"max_tokens": max_tokens
}
try:
        with requests.post(API_URL, headers=headers, json=data, stream=True) as response:
            response.raise_for_status()  # surface HTTP errors instead of parsing an error body as SSE
partial_message = ""
for line in response.iter_lines():
if stop_flag[0]:
response.close()
break
if line:
line = line.decode('utf-8')
if RESPONSE_LOGGING_ENABLED:
print(f"API Response: {line}")
if line.startswith("data: "):
if line.strip() == "data: [DONE]":
break
try:
json_data = json.loads(line[6:])
if 'choices' in json_data and json_data['choices']:
content = json_data['choices'][0]['delta'].get('content', '')
if content:
partial_message += content
yield partial_message
except json.JSONDecodeError:
continue
if partial_message:
yield partial_message
except RequestException as e:
print(f"Request error: {e}")
yield f"An error occurred: {str(e)}"
def import_chat(custom_format_string):
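    """Parse a chat exported in the <|role|> text format back into
    (history, system_prompt).

    Expects the same shape export_chat produces, e.g.:
        <|system|> You are helpful.
        <|user|> Hi
        <|assistant|> Hello!
    Returns (None, None) if parsing fails.
    """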
try:
sections = custom_format_string.split('<|')
imported_history = []
system_prompt = ""
for section in sections:
if section.startswith('system|>'):
system_prompt = section.replace('system|>', '').strip()
elif section.startswith('user|>'):
user_message = section.replace('user|>', '').strip()
imported_history.append([user_message, None])
elif section.startswith('assistant|>'):
assistant_message = section.replace('assistant|>', '').strip()
if imported_history:
imported_history[-1][1] = assistant_message
else:
imported_history.append(["", assistant_message])
return imported_history, system_prompt
except Exception as e:
print(f"Error importing chat: {e}")
return None, None
def export_chat(history, system_prompt):
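    """Serialize the system prompt and history into the plain-text
    <|system|>/<|user|>/<|assistant|> format understood by import_chat."""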
export_data = f"<|system|> {system_prompt}\n\n"
if history is not None:
for user_msg, assistant_msg in history:
export_data += f"<|user|> {user_msg}\n\n"
if assistant_msg:
export_data += f"<|assistant|> {assistant_msg}\n\n"
return export_data
def stop_generation_func(stop_flag):
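    """Flip the shared stop flag so any in-flight generator stops streaming."""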
stop_flag[0] = True
return stop_flag
with gr.Blocks(theme='gradio/monochrome') as demo:
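    # One-element list wrapped in gr.State so callbacks can mutate the
    # flag in place while a generation is streaming.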
stop_flag = gr.State([False])
with gr.Row():
with gr.Column(scale=2):
chatbot = gr.Chatbot(value=[])
            msg = gr.Textbox(label="Message (Doesn't work right now; the provider went down. I'll switch to OpenRouter in a moment.)")
with gr.Row():
clear = gr.Button("Clear")
regenerate = gr.Button("Regenerate")
stop_btn = gr.Button("Stop")
with gr.Row():
with gr.Column(scale=4):
import_textbox = gr.Textbox(label="Import textbox", lines=5)
with gr.Column(scale=1):
export_button = gr.Button("Export Chat")
import_button = gr.Button("Import Chat")
with gr.Column(scale=1):
system_prompt = gr.Textbox("", label="System Prompt", lines=5)
temperature = gr.Slider(0, 2, value=0.8, step=0.01, label="Temperature")
top_p = gr.Slider(0, 1, value=0.95, step=0.01, label="Top P")
top_k = gr.Slider(1, 500, value=40, step=1, label="Top K")
frequency_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Frequency Penalty")
presence_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Presence Penalty")
repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")
def user(user_message, history):
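        """Append the new user turn (reply pending) and clear the textbox."""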
history = history or []
return "", history + [[user_message, None]]
def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
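        """Stream the reply for the latest user turn, updating the last
        history entry chunk by chunk."""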
stop_flag[0] = False
history = history or []
if not history:
return history
user_message = history[-1][0]
bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag)
history[-1][1] = ""
        for chunk in bot_message:
            if stop_flag[0]:
                history[-1][1] += " [Generation stopped]"
                yield history  # without this final yield, the stop marker never reaches the UI
                break
            history[-1][1] = chunk
            yield history
def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
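        """Discard the last assistant reply and re-run bot() for the same
        user message."""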
        if history:
            history[-1][1] = None  # clear the previous reply so bot() regenerates it
for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
yield new_history
else:
yield []
def import_chat_wrapper(custom_format_string):
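        """Thin wrapper around import_chat for the import button callback."""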
imported_history, imported_system_prompt = import_chat(custom_format_string)
return imported_history, imported_system_prompt
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag], chatbot
)
clear.click(lambda: None, None, chatbot, queue=False)
regenerate.click(
regenerate_response,
[chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag],
chatbot
)
import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt])
export_button.click(
export_chat,
inputs=[chatbot, system_prompt],
outputs=[import_textbox]
)
stop_btn.click(stop_generation_func, inputs=[stop_flag], outputs=[stop_flag])
if __name__ == "__main__":
demo.queue(max_size=3, default_concurrency_limit=3).launch(debug=True)