File size: 4,266 Bytes
9c880cb
 
f9c7426
5bdf9aa
cfab2e6
 
 
 
 
 
32957d4
 
cfab2e6
f9c7426
d345b65
 
 
 
cfab2e6
 
 
a41f6e0
cfab2e6
 
 
 
a41f6e0
cfab2e6
4b530fd
0f2c971
4b530fd
 
 
 
 
 
 
 
 
 
f9c7426
0f2c971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c35470f
 
 
 
cfab2e6
f9c7426
c14f735
f9c7426
a41f6e0
4d79cf7
 
a41f6e0
 
 
 
 
 
 
 
2cc406d
4d79cf7
 
a41f6e0
ec8a9b8
 
 
 
a41f6e0
 
 
 
 
 
4d79cf7
f9c7426
4d79cf7
 
a41f6e0
 
f9c7426
 
cfab2e6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
from huggingface_hub import InferenceClient
import os

# (display name, Hugging Face Hub model id) pairs, listed in the order they are
# offered as choices in the UI.
_MODEL_PAIRS = (
    ("Zephyr 7B Beta", "HuggingFaceH4/zephyr-7b-beta"),
    ("DeepSeek Coder V2", "deepseek-ai/DeepSeek-Coder-V2-Instruct"),
    ("Meta Llama 3.1 8B", "meta-llama/Meta-Llama-3.1-8B-Instruct"),
    ("Mixtral 8x7B", "mistralai/Mixtral-8x7B-Instruct-v0.1"),
    ("Cohere Command R+", "CohereForAI/c4ai-command-r-plus"),
)

# Maps each UI display name to the model id handed to InferenceClient.
MODELS = dict(_MODEL_PAIRS)

def get_client(model_name):
    """Create an ``InferenceClient`` for the model registered as *model_name*.

    The access token is read from the ``HF_TOKEN`` environment variable.

    Raises:
        KeyError: if *model_name* is not a key of ``MODELS``.
        ValueError: if ``HF_TOKEN`` is unset or empty.
    """
    # Resolve the model first, so an unknown name fails before the token check.
    repo_id = MODELS[model_name]
    token = os.environ.get("HF_TOKEN")
    if token:
        return InferenceClient(repo_id, token=token)
    raise ValueError("HF_TOKEN environment variable is required")

def respond(
    message,
    chat_history,
    model_name,
    max_tokens,
    temperature,
    top_p,
    system_message,
):
    """Generate a model reply to *message* and stream it into the chat history.

    This is a generator used as a Gradio event handler. Each ``yield`` sends
    the current history to the ``gr.Chatbot`` output, so the reply renders
    progressively.

    Args:
        message: The user's new message.
        chat_history: Earlier turns as ``(user, assistant)`` pairs from the
            ``gr.Chatbot`` component. ``None`` is treated as empty. The list
            passed in is not modified.
        model_name: Display name of the model (a key of ``MODELS``). Names
            containing ``"Cohere"`` are queried without streaming.
        max_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        system_message: System prompt placed before the conversation.

    Yields:
        The updated history list. A missing-token configuration error or any
        inference error is shown as the assistant text of the turn, not raised.
    """
    # Work on a copy so the caller's list is never modified in place.
    chat_history = list(chat_history or [])

    # Nothing to send: re-emit the unchanged history instead of calling the model.
    if not (message or "").strip():
        yield chat_history
        return

    try:
        client = get_client(model_name)
    except ValueError as e:
        # This function is a generator, so `return value` would end it without
        # emitting anything. Yield so the error actually appears in the chat.
        chat_history.append((message, str(e)))
        yield chat_history
        return

    messages = [{"role": "system", "content": system_message}]
    for human, assistant in chat_history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    # Generation settings shared by both request styles. max_tokens is a token
    # count, so coerce the slider value to int in case it arrives as a float.
    params = {
        "max_tokens": int(max_tokens),
        "temperature": temperature,
        "top_p": top_p,
    }

    # Reserve a dedicated row for this turn and show the user's message at once.
    # Updating that row by index, rather than matching the last row's user text,
    # keeps the previous reply intact when the same message is sent twice in a row.
    chat_history.append((message, ""))
    turn = len(chat_history) - 1
    yield chat_history

    try:
        if "Cohere" in model_name:
            # Cohere models are queried without streaming.
            response = client.chat_completion(messages, **params)
            chat_history[turn] = (message, response.choices[0].message.content)
            yield chat_history
        else:
            # All other models stream the reply incrementally.
            partial_message = ""
            for chunk in client.chat_completion(messages, stream=True, **params):
                # Ignore chunks that carry no choice or no text.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content
                if delta is None:
                    continue
                partial_message += delta
                chat_history[turn] = (message, partial_message)
                yield chat_history
    except Exception as e:  # UI boundary: surface any inference failure in the chat
        error_message = f"An error occurred: {str(e)}"
        if chat_history[turn][1]:
            # Keep the partial reply visible and report the error as a new row.
            chat_history.append((message, error_message))
        else:
            chat_history[turn] = (message, error_message)
        yield chat_history

def clear_conversation():
    """Return a brand-new empty chat history, used to reset the chatbot display."""
    empty_history = list()
    return empty_history

# Build the Gradio UI: settings on the left, chat on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Prompting AI Chatbot")
    gr.Markdown("언어모델별 프롬프트 테스트 챗봇입니다.")

    with gr.Row():
        # Left column: model choice and generation settings.
        with gr.Column(scale=1):
            model_name = gr.Radio(
                choices=list(MODELS.keys()),
                label="Language Model",
                value="Zephyr 7B Beta",
            )
            # Minimum is 100, not 0: with step=100 a minimum of 0 let the user
            # request max_tokens=0, i.e. a request that can generate nothing.
            max_tokens = gr.Slider(
                minimum=100, maximum=2000, value=500, step=100, label="Max Tokens"
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"
            )
            system_message = gr.Textbox(
                value="""반드시 한글로 답변할 것.
너는 최고의 비서이다.
내가 요구하는것들을 최대한 자세하고 정확하게 답변하라.
""",
                label="System Message",
                lines=3,
            )

        # Right column: conversation display, input box and action buttons.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="메세지를 입력하세요")
            with gr.Row():
                submit_button = gr.Button("전송")
                clear_button = gr.Button("대화 내역 지우기")

    # Pressing Enter and clicking the send button trigger the same handler with
    # the same inputs, in the order respond() expects them.
    chat_inputs = [msg, chatbot, model_name, max_tokens, temperature, top_p, system_message]
    msg.submit(respond, chat_inputs, chatbot)
    submit_button.click(respond, chat_inputs, chatbot)
    clear_button.click(clear_conversation, outputs=chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()