chore: Add utils, config, embeddings, retrievers, prompt, and llm modules
- app.py +43 -54
- config.py +13 -0
- embeddings.py +12 -0
- llm.py +54 -0
- prompt.py +31 -0
- retrievers.py +22 -0
- utils.py +11 -0
app.py
CHANGED
@@ -1,63 +1,52 @@
 import gradio as gr
-from …
-…
     response = ""

-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content

-…

-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
 demo = gr.ChatInterface(
-    respond,
-…
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
 )

-…
 if __name__ == "__main__":
-    demo.launch()
+# app.py
+import os
 import gradio as gr
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough, RunnableLambda
+from langchain_community.document_transformers import LongContextReorder
+from config import LLM_MODEL, STREAMING
+from embeddings import get_embeddings
+from retrievers import load_retrievers
+from llm import get_llm
+from prompt import get_prompt
+
+
+def create_rag_chain(chat_history):
+    embeddings = get_embeddings()
+    retriever = load_retrievers(embeddings)
+    llm = get_llm(streaming=STREAMING)
+    prompt = get_prompt(chat_history)
+
+    return (
+        {
+            "context": retriever
+            | RunnableLambda(LongContextReorder().transform_documents),
+            "question": RunnablePassthrough(),
+        }
+        | prompt
+        | llm.with_config(configurable={"llm": LLM_MODEL})
+        | StrOutputParser()
+    )
+
+
+def respond_stream(message, history):
+    rag_chain = create_rag_chain(history)
     response = ""
+    for chunk in rag_chain.stream(message):
+        response += chunk
+        yield response


+def respond(message, history):
+    rag_chain = create_rag_chain(history)
+    return rag_chain.invoke(message)
+

 demo = gr.ChatInterface(
+    respond_stream if STREAMING else respond,
+    title="Ask me about court precedents!",
+    description="Hello!\nI am an AI QA bot for court precedents. I have in-depth knowledge of case law. Whenever you need help with a precedent, feel free to ask!",
 )

 if __name__ == "__main__":
+    demo.launch()
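
A quick way to exercise the new chain outside the Gradio UI is sketched below. This is not part of the commit; it assumes the FAISS and BM25 indexes exist at the paths in config.py, that the API key for the configured model is set, and it uses LangChain's in-memory ChatMessageHistory only because create_rag_chain ultimately reads a `.messages` attribute from the history it is given.

    # sketch: invoking the RAG chain directly, without Gradio
    from langchain_community.chat_message_histories import ChatMessageHistory
    from app import create_rag_chain

    chain = create_rag_chain(ChatMessageHistory())   # empty history exposing .messages
    print(chain.invoke("..."))                       # one-shot answer, like respond()
    for chunk in chain.stream("..."):                # token-by-token, like respond_stream()
        print(chunk, end="", flush=True)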
config.py
ADDED
@@ -0,0 +1,13 @@
+# config.py
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+FAISS_DB_INDEX = "./index_faiss"
+BM25_INDEX = "./index_bm25/kiwi.pkl"
+CHUNK_SIZE = 2000
+CHUNK_OVERLAP = 200
+EMBEDDING_MODEL = "BAAI/bge-m3"
+LLM_MODEL = os.getenv("MODEL_KEY", "gemini")
+STREAMING = os.getenv("STREAMING", "true").lower() == "true"
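
Since config.py only reads MODEL_KEY and STREAMING, a hypothetical `.env` picked up by load_dotenv() might look like the sketch below; the API-key names are assumptions based on the providers wired up in llm.py, not values this commit reads itself.

    # .env (hypothetical example)
    MODEL_KEY=gemini        # one of: gpt4, gpt3, claude, gemini, llama3, ollama
    STREAMING=true
    # Keys below are presumably read by the individual LangChain integrations:
    # OPENAI_API_KEY=...
    # ANTHROPIC_API_KEY=...
    # GOOGLE_API_KEY=...
    # GROQ_API_KEY=...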
embeddings.py
ADDED
@@ -0,0 +1,12 @@
+# embeddings.py
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from config import EMBEDDING_MODEL
+from utils import get_device
+
+
+def get_embeddings():
+    return HuggingFaceBgeEmbeddings(
+        model_name=EMBEDDING_MODEL,
+        model_kwargs={"device": get_device()},
+        encode_kwargs={"normalize_embeddings": True},
+    )
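
A minimal sanity check for the embeddings (a sketch, not part of the commit): embed_query should return a dense vector, 1024-dimensional for BAAI/bge-m3, with roughly unit norm because normalize_embeddings=True.

    # sketch: verify the embedding model loads and normalizes
    from embeddings import get_embeddings

    emb = get_embeddings()
    vec = emb.embed_query("sample query")
    print(len(vec))                          # 1024 for BAAI/bge-m3
    print(sum(v * v for v in vec) ** 0.5)    # ~1.0, since vectors are normalized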
llm.py
ADDED
@@ -0,0 +1,54 @@
+# llm.py
+from langchain_openai import ChatOpenAI
+from langchain_anthropic import ChatAnthropic
+from langchain_google_genai import GoogleGenerativeAI
+from langchain_groq import ChatGroq
+from langchain_community.chat_models import ChatOllama
+from langchain_core.runnables import ConfigurableField
+from langchain.callbacks.base import BaseCallbackHandler
+
+
+class StreamCallback(BaseCallbackHandler):
+    def on_llm_new_token(self, token: str, **kwargs):
+        print(token, end="", flush=True)
+
+
+def get_llm(streaming=True):
+    return ChatOpenAI(
+        model="gpt-4",
+        temperature=0,
+        streaming=streaming,
+        callbacks=[StreamCallback()],
+    ).configurable_alternatives(
+        ConfigurableField(id="llm"),
+        default_key="gpt4",
+        claude=ChatAnthropic(
+            model="claude-3-opus-20240229",
+            temperature=0,
+            streaming=streaming,
+            callbacks=[StreamCallback()],
+        ),
+        gpt3=ChatOpenAI(
+            model="gpt-3.5-turbo",
+            temperature=0,
+            streaming=streaming,
+            callbacks=[StreamCallback()],
+        ),
+        gemini=GoogleGenerativeAI(
+            model="gemini-1.5-flash",
+            temperature=0,
+            streaming=streaming,
+            callbacks=[StreamCallback()],
+        ),
+        llama3=ChatGroq(
+            model_name="llama3-70b-8192",
+            temperature=0,
+            streaming=streaming,
+            callbacks=[StreamCallback()],
+        ),
+        ollama=ChatOllama(
+            model="EEVE-Korean-10.8B:long",
+            streaming=streaming,
+            callbacks=[StreamCallback()],
+        ),
+    )
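
At run time the model is chosen through the `llm` configurable field; app.py does this with `.with_config(configurable={"llm": LLM_MODEL})`. A minimal sketch of switching between the alternatives defined above (assuming the corresponding API keys are set):

    # sketch: selecting an alternative model via the "llm" configurable field
    from llm import get_llm

    llm = get_llm(streaming=False)
    print(llm.with_config(configurable={"llm": "gpt4"}).invoke("ping"))    # default_key
    print(llm.with_config(configurable={"llm": "claude"}).invoke("ping"))  # ChatAnthropic alternative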
prompt.py
ADDED
@@ -0,0 +1,31 @@
+# prompt.py
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+
+PROMPT_TEMPLATE = """You are a judge and a legal expert with 20 years of experience. Answer the question using the information in the given documents as fully as possible. The questioner will describe their own situation, and you must explain court precedents similar to that situation. Present the most recent cases first, and explain them in as much detail and as plainly as possible, so that an elementary school student could understand. Structure the answer as [Case name 1]..., [Case name 2].... If the answer cannot be found in the documents, reply "The documents do not contain an answer."
+
+Always cite your sources in the answer. List each source as the precedent serial number, case name, and case number from the document metadata, in that order.
+
+---
+
+# Given documents:
+{context}
+
+# Question: {question}
+
+# Answer:
+
+# Sources:
+- source1
+- source2
+- ...
+"""
+
+
+def get_prompt(chat_history):
+    return ChatPromptTemplate.from_messages(
+        [
+            ("system", PROMPT_TEMPLATE),
+            MessagesPlaceholder(variable_name="history"),
+            ("human", "{question}"),
+        ]
+    ).partial(history=chat_history.messages)
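
One caveat worth noting: get_prompt reads chat_history.messages, while Gradio's ChatInterface passes history as a plain list of (user, assistant) pairs. If the list form is what actually arrives, an adapter along these lines would be needed before the .partial call; this is an observation and a hypothetical helper, not part of the commit.

    # sketch: converting Gradio-style history pairs into message tuples
    def history_to_messages(history):
        messages = []
        for user_msg, ai_msg in history or []:
            messages.append(("human", user_msg))
            if ai_msg is not None:
                messages.append(("ai", ai_msg))
        return messages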
retrievers.py
ADDED
@@ -0,0 +1,22 @@
+# retrievers.py
+import pickle
+from langchain.vectorstores import FAISS
+from langchain.retrievers import EnsembleRetriever
+from kiwipiepy import Kiwi
+from config import FAISS_DB_INDEX, BM25_INDEX
+
+
+def load_retrievers(embeddings):
+    faiss_db = FAISS.load_local(
+        FAISS_DB_INDEX, embeddings, allow_dangerous_deserialization=True
+    )
+    faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
+
+    with open(BM25_INDEX, "rb") as f:
+        bm25_retriever = pickle.load(f)
+
+    return EnsembleRetriever(
+        retrievers=[bm25_retriever, faiss_retriever],
+        weights=[0.7, 0.3],
+        search_type="mmr",
+    )
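
Kiwi is imported but never called here; presumably it only needs to be importable so that unpickling index_bm25/kiwi.pkl can resolve a Kiwi-based tokenizer baked into the BM25 retriever. The pickle itself is not part of this commit, but it could plausibly have been produced by something like the sketch below, where `docs` is a hypothetical list of chunked Documents.

    # sketch (assumption): how index_bm25/kiwi.pkl might have been built
    import pickle
    from kiwipiepy import Kiwi
    from langchain_community.retrievers import BM25Retriever

    kiwi = Kiwi()

    def kiwi_tokenize(text):
        # morpheme-level tokens suit Korean BM25 far better than whitespace splitting
        return [token.form for token in kiwi.tokenize(text)]

    bm25 = BM25Retriever.from_documents(docs, preprocess_func=kiwi_tokenize)
    bm25.k = 10
    with open("./index_bm25/kiwi.pkl", "wb") as f:
        pickle.dump(bm25, f)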
utils.py
ADDED
@@ -0,0 +1,11 @@
+# utils.py
+import torch
+
+
+def get_device():
+    if torch.cuda.is_available():
+        return "cuda:0"
+    elif torch.backends.mps.is_available():
+        return "mps"
+    else:
+        return "cpu"