blockenters committed on
Commit e0ed46b · 1 Parent(s): 0e0b085
Files changed (1)
  1. app.py +34 -22
app.py CHANGED
@@ -2,17 +2,23 @@ import streamlit as st
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Load the model (DialoGPT-medium example)
+# Load the model (DeepSeek-R1-Distill-Qwen-1.5B example)
 @st.cache_resource
-def load_model(model_name="microsoft/DialoGPT-medium"):
+def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        device_map="auto",
+        torch_dtype=torch.float16,
+        trust_remote_code=True  # enable if the model requires custom code
+    )
     return tokenizer, model
 
 # Main app function
 def main():
-    st.title("ChatGPT-style chat demo")
-    st.write("This is a simple demo for testing conversations with the DialoGPT model.")
+    st.set_page_config(page_title="DeepSeek-R1 Chatbot", page_icon="🤖")
+    st.title("DeepSeek-R1-based conversational chatbot")
+    st.write("A demo for testing Korean-language conversations with the DeepSeek-R1-Distill-Qwen-1.5B model.")
 
     # Initialize session state
     if "chat_history_ids" not in st.session_state:
@@ -21,9 +27,10 @@ def main():
         st.session_state["past_user_inputs"] = []
     if "generated_responses" not in st.session_state:
         st.session_state["generated_responses"] = []
-
-    tokenizer, model = load_model(model_name="deepseek-ai/DeepSeek-R1")
-
+
+    # Load the model and tokenizer
+    tokenizer, model = load_model()
+
     # Display the existing conversation history
     if st.session_state["past_user_inputs"]:
         for user_text, bot_text in zip(st.session_state["past_user_inputs"], st.session_state["generated_responses"]):
@@ -33,40 +40,45 @@ def main():
             # Bot message
             with st.chat_message("assistant"):
                 st.write(bot_text)
-
+
     # Chat input box
    user_input = st.chat_input("Type a message...")
-
+
     if user_input:
         # Display the user message
         with st.chat_message("user"):
             st.write(user_input)
-
-        # Tokenize the new input
-        new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
-
+
+        # Preprocess the model input
+        new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt').to(model.device)
+
         if st.session_state["chat_history_ids"] is not None:
             # Append to the existing history
             bot_input_ids = torch.cat([st.session_state["chat_history_ids"], new_user_input_ids], dim=-1)
         else:
             bot_input_ids = new_user_input_ids
-
+
         # Model inference
         with torch.no_grad():
             chat_history_ids = model.generate(
                 bot_input_ids,
-                max_length=1000,
-                pad_token_id=tokenizer.eos_token_id
+                max_length=32768,                 # max length recommended by the model card
+                temperature=0.6,                  # temperature recommended by the model card
+                top_p=0.95,                       # top-p recommended by the model card
+                pad_token_id=tokenizer.eos_token_id,
+                do_sample=True,
+                num_return_sequences=1
             )
-
-        # Decode the result
-        bot_text = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
-
+
+        # Decode only the newly generated tokens
+        bot_output_ids = chat_history_ids[:, bot_input_ids.shape[-1]:]
+        bot_text = tokenizer.decode(bot_output_ids[0], skip_special_tokens=True)
+
         # Update the conversation in session state
         st.session_state["past_user_inputs"].append(user_input)
         st.session_state["generated_responses"].append(bot_text)
         st.session_state["chat_history_ids"] = chat_history_ids
-
+
         # Display the bot message
         with st.chat_message("assistant"):
             st.write(bot_text)
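A note on the prompt format, not part of this commit: DeepSeek-R1-Distill-Qwen-1.5B is a chat-tuned model that ships with its own chat template, while the `user_input + tokenizer.eos_token` concatenation kept in app.py is a DialoGPT convention. Below is a minimal sketch of building the model input with `tokenizer.apply_chat_template` instead; the stand-in lists mirror the app's session-state keys, and the sample strings are purely illustrative.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

# Stand-ins for st.session_state["past_user_inputs"] / ["generated_responses"].
past_user_inputs = ["Hello!"]
generated_responses = ["Hi, how can I help you?"]
user_input = "Tell me about Streamlit."

# Reassemble the conversation as role-tagged messages; apply_chat_template
# then inserts the special tokens the model was fine-tuned with.
messages = []
for user_text, bot_text in zip(past_user_inputs, generated_responses):
    messages.append({"role": "user", "content": user_text})
    messages.append({"role": "assistant", "content": bot_text})
messages.append({"role": "user", "content": user_input})

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant prefix so the model replies
    return_tensors="pt",
)
```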
 
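One more hedged note on the new generation settings: `max_length=32768` bounds prompt plus completion together, and since `chat_history_ids` grows with every turn, later turns leave less and less room under that cap. If the intent is to bound only the reply, `generate` also accepts `max_new_tokens`. A sketch under that assumption, reusing the `model`, `tokenizer`, and `bot_input_ids` names from app.py; the 512 and 4096 budgets are illustrative, not from the commit.

```python
import torch

def generate_reply(model, tokenizer, bot_input_ids, max_new_tokens=512):
    """Generate one reply, bounding only the newly generated tokens."""
    with torch.no_grad():
        chat_history_ids = model.generate(
            bot_input_ids,
            max_new_tokens=max_new_tokens,   # caps the reply, not prompt + reply
            temperature=0.6,                 # sampling settings from the model card
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Trim the stored history so the next prompt stays inside the context window.
    max_history_tokens = 4096                # illustrative budget
    if chat_history_ids.shape[-1] > max_history_tokens:
        chat_history_ids = chat_history_ids[:, -max_history_tokens:]
    return chat_history_ids
```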