refactor: Update app.py and libs/llm.py to improve model selection and configuration
Files changed:
- app.py (+22, -22)
- libs/llm.py (+1, -1)
app.py

@@ -1,6 +1,7 @@
 import os
 import gradio as gr
 from kiwipiepy import Kiwi
+from typing import List, Tuple, Generator, Union
 
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnablePassthrough, RunnableLambda
@@ -24,16 +25,17 @@ embeddings = get_embeddings()
 retriever = load_retrievers(embeddings)
 
 
-
-
-
-
-
-
-
+# Available models (key: model identifier, value: label shown to the user)
+AVAILABLE_MODELS = {
+    "gpt_3_5_turbo": "GPT-3.5 Turbo",
+    "gpt_4o": "GPT-4o",
+    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
+    "gemini_1_5_flash": "Gemini 1.5 Flash",
+    "llama3_70b": "Llama3 70b",
+}
 
 
-def create_rag_chain(chat_history, model):
+def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
     llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
     prompt = get_prompt(chat_history)
 
@@ -49,32 +51,30 @@ def create_rag_chain(chat_history, model):
 )
 
 
-def respond_stream(message, history, model):
+def get_model_key(label):
+    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)
+
+
+def respond_stream(
+    message: str, history: List[Tuple[str, str]], model: str
+) -> Generator[str, None, None]:
     rag_chain = create_rag_chain(history, model)
     for chunk in rag_chain.stream(message):
         yield chunk
 
 
-def respond(message, history, model):
+def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
     rag_chain = create_rag_chain(history, model)
     return rag_chain.invoke(message)
 
 
-
-AVAILABLE_MODELS = {
-    "gpt_3_5_turbo": "GPT-3.5 Turbo",
-    "gpt_4o": "GPT-4o",
-    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
-    "gemini_1_5_flash": "Gemini 1.5 Flash",
-    "llama3_70b": "Llama3 70b",
-}
-
-
-def get_model_key(label):
+def get_model_key(label: str) -> str:
     return next(key for key, value in AVAILABLE_MODELS.items() if value == label)
 
 
-def chat_function(message, history, model_label):
+def chat_function(
+    message: str, history: List[Tuple[str, str]], model_label: str
+) -> Generator[str, None, None]:
     model_key = get_model_key(model_label)
     if STREAMING:
         response = ""
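For context on the model-selection flow: Gradio passes chat_function the human-readable label chosen in the UI, and get_model_key reverse-maps that label to its AVAILABLE_MODELS key (e.g. "GPT-4o" -> "gpt_4o") before the chain is built. Note that after this change app.py defines get_model_key twice; the second, type-annotated definition shadows the first. The interface wiring itself falls outside this diff, so the following is only a minimal sketch of how chat_function could be attached to a model picker; the demo and additional_inputs names are illustrative assumptions, not the file's actual code.

# Hypothetical wiring; not part of the commit.
model_dropdown = gr.Dropdown(
    choices=list(AVAILABLE_MODELS.values()),  # user-facing labels
    value="GPT-3.5 Turbo",
    label="Model",
)

demo = gr.ChatInterface(
    fn=chat_function,                    # called as chat_function(message, history, model_label)
    additional_inputs=[model_dropdown],  # dropdown value arrives as model_label
)

if __name__ == "__main__":
    demo.launch()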
libs/llm.py

@@ -30,7 +30,7 @@ def get_llm(streaming=True):
         callbacks=[StreamCallback()],
     ),
     gpt_3_5_turbo=ChatOpenAI(
-        model="gpt-3.5-turbo",
+        model="gpt-3.5-turbo-0125",
         temperature=0,
         streaming=streaming,
         callbacks=[StreamCallback()],