AIRider committed on
Commit
87dda7a
·
verified ·
1 Parent(s): b16cf8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -29
app.py CHANGED
@@ -21,15 +21,14 @@ MAX_HISTORY_LENGTH = 5 # 히스토리에 유지할 최대 대화 수
21
  def truncate_history(history):
22
  return history[-MAX_HISTORY_LENGTH:] if len(history) > MAX_HISTORY_LENGTH else history
23
 
24
- def respond(message, history, system_message, max_tokens, temperature, top_p, selected_model):
25
  stop_event.clear()
26
  client = InferenceClient(model=selected_model, token=hf_token)
27
 
28
- truncated_history = truncate_history(history)
29
-
30
- messages = [{"role": "system", "content": system_message + "\n사용자의 입력에만 직접적으로 답변하세요. 추가 질문을 생성하거나 사용자의 입력을 확장하지 마세요."}]
31
- messages.extend([{"role": "user" if i % 2 == 0 else "assistant", "content": m} for h in truncated_history for i, m in enumerate(h) if m])
32
- messages.append({"role": "user", "content": message})
33
 
34
  try:
35
  response = ""
@@ -44,26 +43,15 @@ def respond(message, history, system_message, max_tokens, temperature, top_p, se
44
  break
45
  if chunk:
46
  response += chunk
47
- if response.startswith(message):
48
- response = response[len(message):].lstrip()
49
- yield truncated_history + [(message, response)]
50
 
51
  except Exception as e:
52
- yield truncated_history + [(message, f"오류 발생: {str(e)}")]
53
-
54
- def continue_writing(message, history, system_message, max_tokens, temperature, top_p, selected_model):
55
- if not history:
56
- yield [("시스템", "대화 내역이 없습니다.")]
57
- return
58
 
59
- truncated_history = truncate_history(history)
60
- last_assistant_message = truncated_history[-1][1]
61
-
62
- prompt = f"이전 대화를 간단히 요약하고 이어서 작성해주세요. 마지막 응답: {last_assistant_message[:100]}..."
63
 
64
- for response in respond(prompt, truncated_history[:-1], system_message, max_tokens, temperature, top_p, selected_model):
65
- yield response
66
-
67
  def stop_generation():
68
  stop_event.set()
69
  return "생성이 중단되었습니다."
@@ -74,14 +62,42 @@ def regenerate(chat_history, system_message, max_tokens, temperature, top_p, sel
74
  last_user_message = chat_history[-1][0]
75
  return respond(last_user_message, chat_history[:-1], system_message, max_tokens, temperature, top_p, selected_model)
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  with gr.Blocks() as demo:
78
  chatbot = gr.Chatbot()
79
- msg = gr.Textbox(label="메시지 입력", placeholder="메시지를 입력하세요. Enter로 전송, Shift+Enter로 줄바꿈")
80
 
81
  with gr.Row():
82
  send = gr.Button("전송")
83
  continue_btn = gr.Button("계속 작성")
84
- regenerate_btn = gr.Button("🔄 재생성")
85
  stop = gr.Button("🛑 생성 중단")
86
  clear = gr.Button("🗑️ 대화 내역 지우기")
87
 
@@ -97,12 +113,11 @@ with gr.Blocks() as demo:
97
  model = gr.Radio(list(models.keys()), value=list(models.keys())[0], label="언어 모델 선택", info="사용할 언어 모델을 선택하세요")
98
 
99
  # Event handlers
100
- msg.submit(respond, [msg, chatbot, system_message, max_tokens, temperature, top_p, model], [chatbot])
101
- send.click(respond, [msg, chatbot, system_message, max_tokens, temperature, top_p, model], [chatbot])
102
  continue_btn.click(continue_writing,
103
- inputs=[msg, chatbot, system_message, max_tokens, temperature, top_p, model],
104
- outputs=[chatbot])
105
- regenerate_btn.click(regenerate, [chatbot, system_message, max_tokens, temperature, top_p, model], [chatbot])
106
  stop.click(stop_generation, outputs=[msg])
107
  clear.click(lambda: None, outputs=[chatbot])
108
 
 
21
  def truncate_history(history):
22
  return history[-MAX_HISTORY_LENGTH:] if len(history) > MAX_HISTORY_LENGTH else history
23
 
24
+ def respond(message, system_message, max_tokens, temperature, top_p, selected_model):
25
  stop_event.clear()
26
  client = InferenceClient(model=selected_model, token=hf_token)
27
 
28
+ messages = [
29
+ {"role": "system", "content": system_message},
30
+ {"role": "user", "content": message}
31
+ ]
 
32
 
33
  try:
34
  response = ""
 
43
  break
44
  if chunk:
45
  response += chunk
46
+ yield [(message, response)]
 
 
47
 
48
  except Exception as e:
49
+ yield [(message, f"오류 발생: {str(e)}")]
 
 
 
 
 
50
 
51
+ def stop_generation():
52
+ stop_event.set()
53
+ return "생성이 중단되었습니다."
 
54
 
 
 
 
55
  def stop_generation():
56
  stop_event.set()
57
  return "생성이 중단되었습니다."
 
62
  last_user_message = chat_history[-1][0]
63
  return respond(last_user_message, chat_history[:-1], system_message, max_tokens, temperature, top_p, selected_model)
64
 
65
+ def continue_writing(last_response, system_message, max_tokens, temperature, top_p, selected_model):
66
+ stop_event.clear()
67
+ client = InferenceClient(model=selected_model, token=hf_token)
68
+
69
+ prompt = f"이전 응답을 이어서 작성해주세요. 이전 응답: {last_response}"
70
+ messages = [
71
+ {"role": "system", "content": system_message},
72
+ {"role": "user", "content": prompt}
73
+ ]
74
+
75
+ try:
76
+ response = last_response
77
+ for chunk in client.text_generation(
78
+ prompt="\n".join([f"{m['role']}: {m['content']}" for m in messages]),
79
+ max_new_tokens=max_tokens,
80
+ temperature=temperature,
81
+ top_p=top_p,
82
+ stream=True
83
+ ):
84
+ if stop_event.is_set():
85
+ break
86
+ if chunk:
87
+ response += chunk
88
+ yield [("계속 작성", response)]
89
+
90
+ except Exception as e:
91
+ yield [("계속 작성", f"오류 발생: {str(e)}")]
92
+
93
+ # Gradio 인터페이스 수정
94
  with gr.Blocks() as demo:
95
  chatbot = gr.Chatbot()
96
+ msg = gr.Textbox(label="메시지 입력")
97
 
98
  with gr.Row():
99
  send = gr.Button("전송")
100
  continue_btn = gr.Button("계속 작성")
 
101
  stop = gr.Button("🛑 생성 중단")
102
  clear = gr.Button("🗑️ 대화 내역 지우기")
103
 
 
113
  model = gr.Radio(list(models.keys()), value=list(models.keys())[0], label="언어 모델 선택", info="사용할 언어 모델을 선택하세요")
114
 
115
  # Event handlers
116
+ send.click(respond, inputs=[msg, system_message, max_tokens, temperature, top_p, model], outputs=[chatbot])
117
+ msg.submit(respond, inputs=[msg, system_message, max_tokens, temperature, top_p, model], outputs=[chatbot])
118
  continue_btn.click(continue_writing,
119
+ inputs=[lambda: chatbot[-1][1] if chatbot else "", system_message, max_tokens, temperature, top_p, model],
120
+ outputs=[chatbot])
 
121
  stop.click(stop_generation, outputs=[msg])
122
  clear.click(lambda: None, outputs=[chatbot])
123