anpigon committed on
Commit
ceaa913
•
1 Parent(s): 43e63ae

refactor: Update app.py and libs/llm.py to improve model selection and configuration

Files changed (2)
  1. app.py +22 -22
  2. libs/llm.py +1 -1
app.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import gradio as gr
 from kiwipiepy import Kiwi
+from typing import List, Tuple, Generator, Union
 
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnablePassthrough, RunnableLambda
@@ -24,16 +25,17 @@ embeddings = get_embeddings()
 retriever = load_retrievers(embeddings)
 
 
-def kiwi_tokenize(text):
-    kiwi = Kiwi()
-    return [token.form for token in kiwi.tokenize(text)]
-
-
-embeddings = get_embeddings()
-retriever = load_retrievers(embeddings)
+# Available models (key: model identifier, value: label shown to the user)
+AVAILABLE_MODELS = {
+    "gpt_3_5_turbo": "GPT-3.5 Turbo",
+    "gpt_4o": "GPT-4o",
+    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
+    "gemini_1_5_flash": "Gemini 1.5 Flash",
+    "llama3_70b": "Llama3 70b",
+}
 
 
-def create_rag_chain(chat_history, model):
+def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
     llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
     prompt = get_prompt(chat_history)
 
@@ -49,32 +51,30 @@ def create_rag_chain(chat_history, model):
     )
 
 
-def respond_stream(message, history, model):
+def get_model_key(label):
+    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)
+
+
+def respond_stream(
+    message: str, history: List[Tuple[str, str]], model: str
+) -> Generator[str, None, None]:
     rag_chain = create_rag_chain(history, model)
     for chunk in rag_chain.stream(message):
         yield chunk
 
 
-def respond(message, history, model):
+def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
     rag_chain = create_rag_chain(history, model)
     return rag_chain.invoke(message)
 
 
-# Available models (key: model identifier, value: label shown to the user)
-AVAILABLE_MODELS = {
-    "gpt_3_5_turbo": "GPT-3.5 Turbo",
-    "gpt_4o": "GPT-4o",
-    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
-    "gemini_1_5_flash": "Gemini 1.5 Flash",
-    "llama3_70b": "Llama3 70b",
-}
-
-
-def get_model_key(label):
+def get_model_key(label: str) -> str:
     return next(key for key, value in AVAILABLE_MODELS.items() if value == label)
 
 
-def chat_function(message, history, model_label):
+def chat_function(
+    message: str, history: List[Tuple[str, str]], model_label: str
+) -> Generator[str, None, None]:
     model_key = get_model_key(model_label)
     if STREAMING:
         response = ""
libs/llm.py CHANGED
@@ -30,7 +30,7 @@ def get_llm(streaming=True):
             callbacks=[StreamCallback()],
         ),
         gpt_3_5_turbo=ChatOpenAI(
-            model="gpt-3.5-turbo",
+            model="gpt-3.5-turbo-0125",
             temperature=0,
             streaming=streaming,
             callbacks=[StreamCallback()],
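
The keyword-style alternatives in get_llm (gpt_3_5_turbo=ChatOpenAI(...)) together with the with_config(configurable={"llm": model}) call in app.py match LangChain's configurable_alternatives pattern. A minimal sketch of that pattern, under the assumption that get_llm is built this way (the surrounding lines are outside this hunk):

from langchain_core.runnables import ConfigurableField
from langchain_openai import ChatOpenAI

def get_llm(streaming=True):
    # Base model plus named alternatives; the field id "llm" is what
    # with_config(configurable={"llm": <key>}) selects on at runtime.
    return ChatOpenAI(
        model="gpt-4o", temperature=0, streaming=streaming
    ).configurable_alternatives(
        ConfigurableField(id="llm"),
        default_key="gpt_4o",
        gpt_3_5_turbo=ChatOpenAI(
            model="gpt-3.5-turbo-0125", temperature=0, streaming=streaming
        ),
    )

# Selecting an alternative per request, as app.py does:
llm = get_llm().with_config(configurable={"llm": "gpt_3_5_turbo"})

Pinning "gpt-3.5-turbo-0125" rather than the floating "gpt-3.5-turbo" alias keeps behavior stable when OpenAI repoints the alias to a newer snapshot.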