import gradio as gr
import requests
import json
import os
import datetime
from requests.exceptions import RequestException

# Endpoint and credentials are supplied via the environment.
API_URL = os.environ.get('API_URL')
API_KEY = os.environ.get('API_KEY')

headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
    'Referer': os.environ.get('REFERRER_URL')
}

# Temporary debug switches: needed to track down a provider-side issue
# that makes all subsequent responses die out of nowhere, for no
# apparent reason.
USER_LOGGING_ENABLED = False
RESPONSE_LOGGING_ENABLED = True

DEFAULT_PARAMS = {
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 40,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "repetition_penalty": 1.1,
    "max_tokens": 512
}


def get_timestamp():
    return datetime.datetime.now().strftime("%H:%M:%S")


def predict(message, history, system_prompt, temperature, top_p, top_k,
            frequency_penalty, presence_penalty, repetition_penalty,
            max_tokens, stop_flag):
    # Rebuild the conversation in OpenAI-style chat format.
    history_format = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        history_format.append({"role": "user", "content": human})
        if assistant:
            history_format.append({"role": "assistant", "content": assistant})
    history_format.append({"role": "user", "content": message})

    # Log user turns unless the message starts with '*' or '"'.
    if USER_LOGGING_ENABLED and not message.startswith(('*', '"')):
        print(f"<|system|> {system_prompt}")
        print(f"{get_timestamp()} <|user|> {message}")

    current_params = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens
    }
    # Log only the sampling parameters that differ from the defaults.
    non_default_params = {k: v for k, v in current_params.items()
                          if v != DEFAULT_PARAMS[k]}
    if USER_LOGGING_ENABLED and non_default_params and not message.startswith(('*', '"')):
        for param, value in non_default_params.items():
            print(f"{param}={value}")

    data = {
        "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
        "messages": history_format,
        "stream": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens
    }

    try:
        with requests.post(API_URL, headers=headers, data=json.dumps(data),
                           stream=True) as response:
            partial_message = ""
            for line in response.iter_lines():
                # Abort mid-stream when the Stop button flips the flag.
                if stop_flag[0]:
                    response.close()
                    break
                if line:
                    line = line.decode('utf-8')
                    if RESPONSE_LOGGING_ENABLED:
                        print(f"API Response: {line}")
                    if line.startswith("data: "):
                        if line.strip() == "data: [DONE]":
                            break
                        try:
                            json_data = json.loads(line[6:])
                            if 'choices' in json_data and json_data['choices']:
                                content = json_data['choices'][0]['delta'].get('content', '')
                                if content:
                                    partial_message += content
                                    yield partial_message
                        except json.JSONDecodeError:
                            continue
            if partial_message:
                yield partial_message
    except RequestException as e:
        print(f"Request error: {e}")
        yield f"An error occurred: {str(e)}"
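
# For reference, each server-sent line that predict() consumes looks roughly
# like the following (illustrative payloads; the exact fields depend on the
# provider, and only 'choices[0].delta.content' plus the 'data: ' / '[DONE]'
# framing are relied on above):
#
#   data: {"choices": [{"delta": {"content": "Hel"}}]}
#   data: {"choices": [{"delta": {"content": "lo"}}]}
#   data: [DONE]
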
def import_chat(custom_format_string):
    """Parse a chat serialized with <|system|>/<|user|>/<|assistant|> tags."""
    try:
        sections = custom_format_string.split('<|')
        imported_history = []
        system_prompt = ""
        for section in sections:
            if section.startswith('system|>'):
                system_prompt = section.replace('system|>', '').strip()
            elif section.startswith('user|>'):
                user_message = section.replace('user|>', '').strip()
                imported_history.append([user_message, None])
            elif section.startswith('assistant|>'):
                assistant_message = section.replace('assistant|>', '').strip()
                if imported_history:
                    imported_history[-1][1] = assistant_message
                else:
                    # Assistant turn with no preceding user turn.
                    imported_history.append(["", assistant_message])
        return imported_history, system_prompt
    except Exception as e:
        print(f"Error importing chat: {e}")
        return None, None


def export_chat(history, system_prompt):
    """Serialize the chat into the same tagged format import_chat reads."""
    export_data = f"<|system|> {system_prompt}\n\n"
    if history is not None:
        for user_msg, assistant_msg in history:
            export_data += f"<|user|> {user_msg}\n\n"
            if assistant_msg:
                export_data += f"<|assistant|> {assistant_msg}\n\n"
    return export_data
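
# A round-trip example of the format above (illustrative values):
#
#   <|system|> You are a helpful assistant.
#
#   <|user|> Hello
#
#   <|assistant|> Hi there!
#
# import_chat on that string returns
# ([["Hello", "Hi there!"]], "You are a helpful assistant.").
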
def stop_generation_func(stop_flag):
    stop_flag[0] = True
    return stop_flag


with gr.Blocks(theme='gradio/monochrome') as demo:
    # Mutable one-element list shared with the streaming loop so the Stop
    # button can interrupt generation.
    stop_flag = gr.State([False])

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(value=[])
            msg = gr.Textbox(label="Message (70B for now. The provider might bug out at random. The space may restart frequently)")
            with gr.Row():
                clear = gr.Button("Clear")
                regenerate = gr.Button("Regenerate")
                stop_btn = gr.Button("Stop")
            with gr.Row():
                with gr.Column(scale=4):
                    import_textbox = gr.Textbox(label="Import textbox", lines=5)
                with gr.Column(scale=1):
                    export_button = gr.Button("Export Chat")
                    import_button = gr.Button("Import Chat")
        with gr.Column(scale=1):
            system_prompt = gr.Textbox("", label="System Prompt", lines=5)
            temperature = gr.Slider(0, 2, value=0.8, step=0.01, label="Temperature")
            top_p = gr.Slider(0, 1, value=0.95, step=0.01, label="Top P")
            top_k = gr.Slider(1, 500, value=40, step=1, label="Top K")
            frequency_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Frequency Penalty")
            presence_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Presence Penalty")
            repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
            max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")

    def user(user_message, history):
        history = history or []
        return "", history + [[user_message, None]]

    def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty,
            presence_penalty, repetition_penalty, max_tokens, stop_flag):
        stop_flag[0] = False
        history = history or []
        if not history:
            return
        user_message = history[-1][0]
        bot_message = predict(user_message, history[:-1], system_prompt,
                              temperature, top_p, top_k, frequency_penalty,
                              presence_penalty, repetition_penalty, max_tokens,
                              stop_flag)
        history[-1][1] = ""
        for chunk in bot_message:
            if stop_flag[0]:
                history[-1][1] += " [Generation stopped]"
                break
            history[-1][1] = chunk
            yield history
        # Yield once more so a "[Generation stopped]" suffix appended after
        # the break is actually rendered.
        yield history

    def regenerate_response(history, system_prompt, temperature, top_p, top_k,
                            frequency_penalty, presence_penalty,
                            repetition_penalty, max_tokens, stop_flag):
        if history:
            # Clear the last assistant turn and stream a fresh reply to the
            # same user message.
            history[-1][1] = None
            yield from bot(history, system_prompt, temperature, top_p, top_k,
                           frequency_penalty, presence_penalty,
                           repetition_penalty, max_tokens, stop_flag)
        else:
            yield []

    def import_chat_wrapper(custom_format_string):
        imported_history, imported_system_prompt = import_chat(custom_format_string)
        return imported_history, imported_system_prompt

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty,
         presence_penalty, repetition_penalty, max_tokens, stop_flag],
        chatbot
    )

    clear.click(lambda: None, None, chatbot, queue=False)

    regenerate.click(
        regenerate_response,
        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty,
         presence_penalty, repetition_penalty, max_tokens, stop_flag],
        chatbot
    )

    import_button.click(import_chat_wrapper, inputs=[import_textbox],
                        outputs=[chatbot, system_prompt])
    export_button.click(export_chat, inputs=[chatbot, system_prompt],
                        outputs=[import_textbox])
    stop_btn.click(stop_generation_func, inputs=[stop_flag], outputs=[stop_flag])

if __name__ == "__main__":
    demo.queue(max_size=3, default_concurrency_limit=3).launch(debug=True)
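
# Run note (a sketch; the URL and filename below are placeholders): the app
# reads API_URL, API_KEY, and REFERRER_URL from the environment, so a local
# launch looks something like
#
#   API_URL=https://example.com/v1/chat/completions API_KEY=... python app.py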