anpigon committed
Commit 56487d0 · 1 Parent(s): 47cc7d0

chore: Add utils, config, embeddings, retrievers, prompt, and llm modules
Files changed (7)
  1. app.py +43 -54
  2. config.py +13 -0
  3. embeddings.py +12 -0
  4. llm.py +54 -0
  5. prompt.py +31 -0
  6. retrievers.py +22 -0
  7. utils.py +11 -0
app.py CHANGED
@@ -1,63 +1,52 @@
+# app.py
+import os
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough, RunnableLambda
+from langchain_community.document_transformers import LongContextReorder
+from config import LLM_MODEL, STREAMING
+from embeddings import get_embeddings
+from retrievers import load_retrievers
+from llm import get_llm
+from prompt import get_prompt
+
+
+def create_rag_chain(chat_history):
+    embeddings = get_embeddings()
+    retriever = load_retrievers(embeddings)
+    llm = get_llm(streaming=STREAMING)
+    prompt = get_prompt(chat_history)
+
+    return (
+        {
+            "context": retriever
+            | RunnableLambda(LongContextReorder().transform_documents),
+            "question": RunnablePassthrough(),
+        }
+        | prompt
+        | llm.with_config(configurable={"llm": LLM_MODEL})
+        | StrOutputParser()
+    )
+
+
+def respond_stream(message, history):
+    rag_chain = create_rag_chain(history)
     response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
+    for chunk in rag_chain.stream(message):
+        response += chunk
+        yield response
+
+
+def respond(message, history):
+    rag_chain = create_rag_chain(history)
+    return rag_chain.invoke(message)
+
+
 demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
+    respond_stream if STREAMING else respond,
+    title="Ask about court precedents!",
+    description="Hello!\nI am an AI QA bot for court precedents. I have deep knowledge of case law. Whenever you need help with a precedent, feel free to ask!",
 )
 
-
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
config.py ADDED
@@ -0,0 +1,13 @@
+# config.py
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+FAISS_DB_INDEX = "./index_faiss"
+BM25_INDEX = "./index_bm25/kiwi.pkl"
+CHUNK_SIZE = 2000
+CHUNK_OVERLAP = 200
+EMBEDDING_MODEL = "BAAI/bge-m3"
+LLM_MODEL = os.getenv("MODEL_KEY", "gemini")
+STREAMING = os.getenv("STREAMING", "true").lower() == "true"
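Note: a minimal usage sketch, not part of the commit, showing how the two environment-driven settings in config.py resolve. The .env values below are assumptions for illustration; only the MODEL_KEY and STREAMING variable names come from the module above.

    # Hypothetical .env contents (assumption; no .env file is included in this commit):
    #   MODEL_KEY=claude
    #   STREAMING=false
    from config import LLM_MODEL, STREAMING

    print(LLM_MODEL)   # -> "claude"; falls back to "gemini" when MODEL_KEY is unset
    print(STREAMING)   # -> False; only the string "true" (case-insensitive) enables streaming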
embeddings.py ADDED
@@ -0,0 +1,12 @@
+# embeddings.py
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from config import EMBEDDING_MODEL
+from utils import get_device
+
+
+def get_embeddings():
+    return HuggingFaceBgeEmbeddings(
+        model_name=EMBEDDING_MODEL,
+        model_kwargs={"device": get_device()},
+        encode_kwargs={"normalize_embeddings": True},
+    )
llm.py ADDED
@@ -0,0 +1,54 @@
+# llm.py
+from langchain_openai import ChatOpenAI
+from langchain_anthropic import ChatAnthropic
+from langchain_google_genai import GoogleGenerativeAI
+from langchain_groq import ChatGroq
+from langchain_community.chat_models import ChatOllama
+from langchain_core.runnables import ConfigurableField
+from langchain.callbacks.base import BaseCallbackHandler
+
+
+class StreamCallback(BaseCallbackHandler):
+    def on_llm_new_token(self, token: str, **kwargs):
+        print(token, end="", flush=True)
+
+
+def get_llm(streaming=True):
+    return ChatOpenAI(
+        model="gpt-4",
+        temperature=0,
+        streaming=streaming,
+        callbacks=[StreamCallback()],
+    ).configurable_alternatives(
+        ConfigurableField(id="llm"),
+        default_key="gpt4",
+        claude=ChatAnthropic(
+            model="claude-3-opus-20240229",
+            temperature=0,
+            streaming=streaming,
+            callbacks=[StreamCallback()],
+        ),
+        gpt3=ChatOpenAI(
+            model="gpt-3.5-turbo",
+            temperature=0,
+            streaming=streaming,
+            callbacks=[StreamCallback()],
+        ),
+        gemini=GoogleGenerativeAI(
+            model="gemini-1.5-flash",
+            temperature=0,
+            streaming=streaming,
+            callbacks=[StreamCallback()],
+        ),
+        llama3=ChatGroq(
+            model_name="llama3-70b-8192",
+            temperature=0,
+            streaming=streaming,
+            callbacks=[StreamCallback()],
+        ),
+        ollama=ChatOllama(
+            model="EEVE-Korean-10.8B:long",
+            streaming=streaming,
+            callbacks=[StreamCallback()],
+        ),
+    )
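Note: a minimal sketch, not part of the commit, of how the configurable alternatives registered in get_llm can be switched at call time. It mirrors the with_config(configurable={"llm": LLM_MODEL}) usage in app.py; "gpt4" is the default key and "claude" is one of the alternatives defined above.

    from llm import get_llm

    llm = get_llm(streaming=False)
    # Default alternative is "gpt4" (ChatOpenAI with gpt-4).
    llm.invoke("Hello")
    # Route the same call through the Claude alternative without rebuilding anything.
    llm.with_config(configurable={"llm": "claude"}).invoke("Hello")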
prompt.py ADDED
@@ -0,0 +1,31 @@
+# prompt.py
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+
+PROMPT_TEMPLATE = """You are a judge and a legal expert with 20 years of experience. Answer the given question by making the most of the information in the documents. The questioner will describe their situation, and you must explain precedents that resemble it. Introduce the most recent cases first, and explain them in as much detail and as simply as possible, so that an elementary school student could understand. Structure the answer as [Case name 1]..., [Case name 2].... If the answer cannot be found in the documents, reply "The documents do not contain an answer."
+
+Always cite the sources of the answer. List each source in the order: precedent serial number, case name, and case number from the metadata.
+
+---
+
+# Given documents:
+{context}
+
+# Question: {question}
+
+# Answer:
+
+# Sources:
+- source1
+- source2
+- ...
+"""
+
+
+def get_prompt(chat_history):
+    return ChatPromptTemplate.from_messages(
+        [
+            ("system", PROMPT_TEMPLATE),
+            MessagesPlaceholder(variable_name="history"),
+            ("human", "{question}"),
+        ]
+    ).partial(history=chat_history.messages)
retrievers.py ADDED
@@ -0,0 +1,22 @@
+# retrievers.py
+import pickle
+from langchain.vectorstores import FAISS
+from langchain.retrievers import EnsembleRetriever
+from kiwipiepy import Kiwi
+from config import FAISS_DB_INDEX, BM25_INDEX
+
+
+def load_retrievers(embeddings):
+    faiss_db = FAISS.load_local(
+        FAISS_DB_INDEX, embeddings, allow_dangerous_deserialization=True
+    )
+    faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
+
+    with open(BM25_INDEX, "rb") as f:
+        bm25_retriever = pickle.load(f)
+
+    return EnsembleRetriever(
+        retrievers=[bm25_retriever, faiss_retriever],
+        weights=[0.7, 0.3],
+        search_type="mmr",
+    )
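Note: a minimal sketch, not part of the commit, of exercising the hybrid retriever on its own. It assumes the FAISS and pickled BM25 indexes referenced in config.py already exist on disk, and the query string is purely illustrative.

    from embeddings import get_embeddings
    from retrievers import load_retrievers

    retriever = load_retrievers(get_embeddings())
    # EnsembleRetriever merges BM25 and FAISS results using the 0.7 / 0.3 weights above.
    docs = retriever.get_relevant_documents("A question about a similar precedent")
    for doc in docs[:3]:
        print(doc.metadata, doc.page_content[:100])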
utils.py ADDED
@@ -0,0 +1,11 @@
+# utils.py
+import torch
+
+
+def get_device():
+    if torch.cuda.is_available():
+        return "cuda:0"
+    elif torch.backends.mps.is_available():
+        return "mps"
+    else:
+        return "cpu"