|
import gradio as gr |
|
import requests |
|
import json |
|
import os |
|
import datetime |
|
from requests.exceptions import RequestException |
|
|
|
# --- Upstream API configuration (secrets come from the environment) ---
# NOTE(review): both may be None if the env vars are unset; requests.post
# would then fail at request time — confirm deployment always provides them.
API_URL = os.environ.get('API_URL')

API_KEY = os.environ.get('API_KEY')

# Headers sent with every completion request.
headers = {

    "Authorization": f"Bearer {API_KEY}",

    "Content-Type": "application/json",

    # Some providers gate access by origin; optional, may be None.
    'Referer': os.environ.get('REFERRER_URL')

}


# When True, predict() echoes the system prompt and user messages to stdout
# (messages starting with '*' or '"' are treated as private and skipped).
USER_LOGGING_ENABLED = False

# When True, predict() prints every raw SSE line received from the API.
RESPONSE_LOGGING_ENABLED = True

# Slider defaults; predict() logs only parameters that differ from these.
DEFAULT_PARAMS = {

    "temperature": 0.8,

    "top_p": 0.95,

    "top_k": 40,

    "frequency_penalty": 0,

    "presence_penalty": 0,

    "repetition_penalty": 1.1,

    "max_tokens": 512

}
|
|
|
def get_timestamp():
    """Return the current local time formatted as HH:MM:SS."""
    return f"{datetime.datetime.now():%H:%M:%S}"
|
|
|
def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
    """Stream a chat completion from the upstream API.

    Args:
        message: Latest user message.
        history: Prior turns as (user, assistant) pairs; assistant may be None.
        system_prompt: System message prepended to the conversation.
        temperature, top_p, top_k, frequency_penalty, presence_penalty,
        repetition_penalty, max_tokens: Sampling parameters sent verbatim.
        stop_flag: One-element mutable list; setting stop_flag[0] = True from
            another event handler aborts the stream.

    Yields:
        The accumulated assistant reply after each received content chunk,
        or a single error string if the request fails.
    """
    history_format = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        history_format.append({"role": "user", "content": human})
        if assistant:
            history_format.append({"role": "assistant", "content": assistant})
    history_format.append({"role": "user", "content": message})

    # Messages starting with '*' or '"' are treated as private and not logged.
    if USER_LOGGING_ENABLED and not message.startswith(('*', '"')):
        print(f"<|system|> {system_prompt}")
        print(f"{get_timestamp()} <|user|> {message}")

    current_params = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens
    }

    # Log only the knobs the user actually moved away from the defaults.
    non_default_params = {k: v for k, v in current_params.items() if v != DEFAULT_PARAMS[k]}
    if USER_LOGGING_ENABLED and non_default_params and not message.startswith(('*', '"')):
        for param, value in non_default_params.items():
            print(f"{param}={value}")

    data = {
        "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
        "messages": history_format,
        "stream": True,
        **current_params
    }

    try:
        # FIX: json= lets requests handle serialization; timeout prevents the
        # worker from hanging forever on a dead upstream (10s to connect,
        # 600s max silence between streamed chunks).
        with requests.post(API_URL, headers=headers, json=data, stream=True,
                           timeout=(10, 600)) as response:
            # FIX: surface HTTP error statuses (401/429/5xx) as a RequestException
            # instead of silently trying to parse the error body as SSE.
            response.raise_for_status()
            partial_message = ""
            for line in response.iter_lines():
                if stop_flag[0]:
                    response.close()
                    break
                if not line:
                    continue
                line = line.decode('utf-8')
                if RESPONSE_LOGGING_ENABLED:
                    print(f"API Response: {line}")
                if line.startswith("data: "):
                    if line.strip() == "data: [DONE]":
                        break
                    try:
                        json_data = json.loads(line[6:])  # strip "data: " prefix
                        if 'choices' in json_data and json_data['choices']:
                            content = json_data['choices'][0]['delta'].get('content', '')
                            if content:
                                partial_message += content
                                yield partial_message
                    except json.JSONDecodeError:
                        # Keep-alive or malformed line: skip and keep streaming.
                        continue
            # FIX: removed the redundant trailing `yield partial_message`,
            # which re-emitted the final accumulated reply a second time.
    except RequestException as e:
        print(f"Request error: {e}")
        yield f"An error occurred: {str(e)}"
|
|
|
def import_chat(custom_format_string):
    """Parse a chat exported in the "<|role|> text" format.

    Args:
        custom_format_string: Text containing <|system|>, <|user|> and
            <|assistant|> sections, as produced by export_chat().

    Returns:
        (history, system_prompt) where history is a list of
        [user_message, assistant_message] pairs (assistant_message may be
        None), or (None, None) if parsing fails.
    """
    try:
        sections = custom_format_string.split('<|')

        imported_history = []
        system_prompt = ""

        for section in sections:
            # FIX: slice off the role tag instead of str.replace(), which
            # deleted EVERY occurrence of e.g. "user|>" and so corrupted
            # messages containing that literal text.
            if section.startswith('system|>'):
                system_prompt = section[len('system|>'):].strip()
            elif section.startswith('user|>'):
                user_message = section[len('user|>'):].strip()
                imported_history.append([user_message, None])
            elif section.startswith('assistant|>'):
                assistant_message = section[len('assistant|>'):].strip()
                if imported_history:
                    imported_history[-1][1] = assistant_message
                else:
                    # Assistant turn with no preceding user turn: pair it
                    # with an empty user message so nothing is lost.
                    imported_history.append(["", assistant_message])

        return imported_history, system_prompt
    except Exception as e:
        # Deliberate catch-all: a malformed import must never crash the UI.
        print(f"Error importing chat: {e}")
        return None, None
|
|
|
def export_chat(history, system_prompt):
    """Serialize a chat to the "<|role|> text" format used by import_chat().

    Args:
        history: List of [user_message, assistant_message] pairs, or None.
        system_prompt: System prompt text to emit first.

    Returns:
        A single string with one "<|role|> text\\n\\n" section per turn;
        assistant sections are omitted when the message is empty/None.
    """
    parts = [f"<|system|> {system_prompt}\n\n"]
    for user_msg, assistant_msg in (history if history is not None else []):
        parts.append(f"<|user|> {user_msg}\n\n")
        if assistant_msg:
            parts.append(f"<|assistant|> {assistant_msg}\n\n")
    return "".join(parts)
|
|
|
def stop_generation_func(stop_flag):
    """Request cancellation of any in-flight generation.

    Mutates the shared one-element flag list in place and returns it so
    Gradio writes the updated state back.
    """
    flag = stop_flag
    flag[0] = True
    return flag
|
|
|
# --- Gradio UI layout ---------------------------------------------------
with gr.Blocks(theme='gradio/monochrome') as demo:

    # Shared one-element list used as a mutable stop signal across handlers.
    stop_flag = gr.State([False])

    with gr.Row():

        # Left column: chat area, action buttons, and import/export tools.
        with gr.Column(scale=2):

            chatbot = gr.Chatbot(value=[])

            msg = gr.Textbox(label="Message (70B for now. The provider might bug out at random. The space may restart frequently)")

            with gr.Row():

                clear = gr.Button("Clear")

                regenerate = gr.Button("Regenerate")

                stop_btn = gr.Button("Stop")

            # Import/export of chats in the "<|role|> text" format.
            with gr.Row():

                with gr.Column(scale=4):

                    import_textbox = gr.Textbox(label="Import textbox", lines=5)

                with gr.Column(scale=1):

                    export_button = gr.Button("Export Chat")

                    import_button = gr.Button("Import Chat")

        # Right column: sampling controls; defaults mirror DEFAULT_PARAMS.
        with gr.Column(scale=1):

            system_prompt = gr.Textbox("", label="System Prompt", lines=5)

            temperature = gr.Slider(0, 2, value=0.8, step=0.01, label="Temperature")

            top_p = gr.Slider(0, 1, value=0.95, step=0.01, label="Top P")

            top_k = gr.Slider(1, 500, value=40, step=1, label="Top K")

            frequency_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Frequency Penalty")

            presence_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Presence Penalty")

            repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")

            max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")
|
|
|
def user(user_message, history): |
|
history = history or [] |
|
return "", history + [[user_message, None]] |
|
|
|
def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag): |
|
stop_flag[0] = False |
|
history = history or [] |
|
if not history: |
|
return history |
|
user_message = history[-1][0] |
|
bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag) |
|
history[-1][1] = "" |
|
for chunk in bot_message: |
|
if stop_flag[0]: |
|
history[-1][1] += " [Generation stopped]" |
|
break |
|
history[-1][1] = chunk |
|
yield history |
|
|
|
def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag): |
|
if history and len(history) > 0: |
|
last_user_message = history[-1][0] |
|
history[-1][1] = None |
|
for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag): |
|
yield new_history |
|
else: |
|
yield [] |
|
|
|
def import_chat_wrapper(custom_format_string): |
|
imported_history, imported_system_prompt = import_chat(custom_format_string) |
|
return imported_history, imported_system_prompt |
|
|
|
    # Submit: append the user turn immediately (unqueued), then stream the
    # assistant reply into the chatbot.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(

        bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag], chatbot

    )

    # Clear the chat display (does not touch the system prompt).
    clear.click(lambda: None, None, chatbot, queue=False)

    # Re-stream the last assistant reply with the current parameters.
    regenerate.click(

        regenerate_response,

        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag],

        chatbot

    )

    # Replace chat history and system prompt from the import textbox.
    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt])

    # Serialize the current chat back into the import textbox.
    export_button.click(

        export_chat,

        inputs=[chatbot, system_prompt],

        outputs=[import_textbox]

    )

    # Flip the shared stop flag so the streaming loops in bot()/predict() exit.
    stop_btn.click(stop_generation_func, inputs=[stop_flag], outputs=[stop_flag])
|
|
|
if __name__ == "__main__":

    # Small queue keeps the shared Space responsive; debug=True prints
    # tracebacks to the console instead of swallowing them.
    demo.queue(max_size=3, default_concurrency_limit=3).launch(debug=True)