# Hugging Face Space (status: Running) — file size: 4,266 bytes.
import gradio as gr
from huggingface_hub import InferenceClient
import os
# Maps the display name shown in the UI's model selector to the Hugging Face
# Hub repository id passed to InferenceClient. Insertion order is the order
# of the radio choices.
MODELS = dict(
    [
        ("Zephyr 7B Beta", "HuggingFaceH4/zephyr-7b-beta"),
        ("DeepSeek Coder V2", "deepseek-ai/DeepSeek-Coder-V2-Instruct"),
        ("Meta Llama 3.1 8B", "meta-llama/Meta-Llama-3.1-8B-Instruct"),
        ("Mixtral 8x7B", "mistralai/Mixtral-8x7B-Instruct-v0.1"),
        ("Cohere Command R+", "CohereForAI/c4ai-command-r-plus"),
    ]
)
def get_client(model_name):
    """Return an InferenceClient for the model registered as ``model_name``.

    The Hub repository id is looked up in ``MODELS`` and the API token is read
    from the ``HF_TOKEN`` environment variable.

    Raises:
        KeyError: ``model_name`` is not a key of ``MODELS`` (checked first).
        ValueError: ``HF_TOKEN`` is unset or empty.
    """
    repo_id = MODELS[model_name]
    api_token = os.environ.get("HF_TOKEN")
    if api_token:
        return InferenceClient(repo_id, token=api_token)
    raise ValueError("HF_TOKEN environment variable is required")
def respond(
    message,
    chat_history,
    model_name,
    max_tokens,
    temperature,
    top_p,
    system_message,
):
    """Generate the assistant's reply to ``message``, yielding chat-history updates.

    Used as a Gradio event handler. Because this is a generator, every update
    (streamed tokens, the final answer, or error text) must be *yielded* for
    the Chatbot component to display it.

    Args:
        message: The new user message.
        chat_history: Previous turns as ``(user, assistant)`` pairs. It is
            copied, not mutated.
        model_name: Display name of the model; a key of ``MODELS``.
        max_tokens: Upper bound on generated tokens (cast to ``int``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        system_message: System prompt sent as the first chat message.

    Yields:
        The chat history including the current turn's ``(message, reply)``
        pair. A missing token or API failures are shown as the reply text
        instead of being raised. A ``KeyError`` from ``get_client`` for an
        unknown ``model_name`` is not caught.
    """
    history = list(chat_history or [])

    # Ignore blank submissions instead of sending an empty user turn.
    if not message or not message.strip():
        yield history
        return

    try:
        client = get_client(model_name)
    except ValueError as e:
        # Inside a generator, ``return history`` only ends iteration and the
        # value never reaches the UI, so the error has to be yielded.
        history.append((message, str(e)))
        yield history
        return

    messages = [{"role": "system", "content": system_message}]
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    # Reserve this turn's slot once and always write to that index. Matching on
    # ``history[-1][0] == message`` overwrote the preceding turn whenever the
    # user sent the same text twice in a row.
    history.append((message, ""))
    turn = len(history) - 1
    generation_kwargs = {
        "max_tokens": int(max_tokens),
        "temperature": temperature,
        "top_p": top_p,
    }

    try:
        if "Cohere" in model_name:
            # Cohere model: non-streaming request.
            response = client.chat_completion(messages, **generation_kwargs)
            history[turn] = (message, response.choices[0].message.content or "")
            yield history
        else:
            # All other models: streaming request.
            stream = client.chat_completion(messages, stream=True, **generation_kwargs)
            partial_message = ""
            for chunk in stream:
                # Skip chunks that carry no choices.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content
                if delta:
                    partial_message += delta
                    history[turn] = (message, partial_message)
                    yield history
            # Final update, so an empty stream still yields the history once.
            yield history
    except Exception as e:
        # UI boundary: report the failure in the chat instead of crashing the
        # handler. Overwrite this turn's slot (no duplicate turn) and keep any
        # text already streamed.
        error_message = f"An error occurred: {str(e)}"
        partial_reply = history[turn][1]
        history[turn] = (
            message,
            f"{partial_reply}\n\n{error_message}" if partial_reply else error_message,
        )
        yield history
def clear_conversation():
    """Return a fresh, empty chat history used to reset the Chatbot."""
    empty_history = list()
    return empty_history
# Gradio UI: model and generation settings on the left, chat panel on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Prompting AI Chatbot")
    gr.Markdown("언어모델별 프롬프트 테스트 챗봇입니다.")
    with gr.Row():
        with gr.Column(scale=1):
            model_name = gr.Radio(
                choices=list(MODELS.keys()),
                label="Language Model",
                value="Zephyr 7B Beta",
            )
            # Start at one step (100) instead of 0: a zero-token budget can
            # never produce a reply.
            max_tokens = gr.Slider(
                minimum=100, maximum=2000, value=500, step=100, label="Max Tokens"
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"
            )
            system_message = gr.Textbox(
                value="""반드시 한글로 답변할 것.
너는 최고의 비서이다.
내가 요구하는것들을 최대한 자세하고 정확하게 답변하라.
""",
                label="System Message",
                lines=3,
            )
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="메세지를 입력하세요")
            with gr.Row():
                submit_button = gr.Button("전송")
                clear_button = gr.Button("대화 내역 지우기")

    # Pressing Enter in the textbox and clicking the send button run the same
    # handler; one shared input list keeps the two wirings from drifting apart.
    chat_inputs = [msg, chatbot, model_name, max_tokens, temperature, top_p, system_message]
    # After the reply is produced, clear the input box for the next message.
    msg.submit(respond, chat_inputs, chatbot).then(lambda: "", None, msg)
    submit_button.click(respond, chat_inputs, chatbot).then(lambda: "", None, msg)
    clear_button.click(clear_conversation, outputs=chatbot, queue=False)
# Launch the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()