feat: vectorsearch-based QA
Browse files- app.py +112 -2
- embeddings.py +37 -0
- requirements.txt +83 -0
app.py
CHANGED
@@ -1,4 +1,114 @@
|
|
|
|
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from embeddings import KorRobertaEmbeddings
|
2 |
+
|
3 |
import streamlit as st
|
4 |
+
from streamlit import session_state as sst
|
5 |
+
|
6 |
+
from langchain_core.runnables import (
|
7 |
+
RunnablePassthrough,
|
8 |
+
RunnableParallel,
|
9 |
+
)
|
10 |
+
|
11 |
+
# Pinecone API key, read from Streamlit's secrets store under the
# "PINECONE_API_KEY" entry. Used by the Pinecone helpers defined below.
PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
|
12 |
+
|
13 |
+
|
14 |
+
def create_or_get_pinecone_index(index_name: str, dimension: int = 768):
    """Return a handle to the Pinecone index ``index_name``, creating it if absent.

    Args:
        index_name: Name of the index to look up in ``client.list_indexes()``.
        dimension: Vector dimension used only when a new index is created.

    Returns:
        The object returned by ``client.Index(index_name)``.

    A new index is created with the cosine metric and
    ``ServerlessSpec("aws", "us-west-2")``. Progress and the index stats are
    printed to stdout.
    """
    from pinecone import Pinecone, ServerlessSpec

    client = Pinecone(api_key=PINECONE_API_KEY)
    known_names = [entry["name"] for entry in client.list_indexes()]

    if index_name not in known_names:
        client.create_index(
            name=index_name,
            dimension=dimension,
            metric="cosine",
            spec=ServerlessSpec("aws", "us-west-2"),
        )
        status_message = "☑️ Created a new Pinecone index"
    else:
        status_message = "☑️ Got the existing Pinecone index"

    # Both branches end up with a handle to the same named index.
    pc_index = client.Index(index_name)
    print(status_message)

    print(pc_index.describe_index_stats())
    return pc_index
|
33 |
+
|
34 |
+
|
35 |
+
def get_pinecone_vectorstore(
    index_name: str,
    embedding_fn=None,
    dimension: int = 768,
    namespace: str = None,
):
    """Build a LangChain Pinecone vector store on top of ``index_name``.

    Args:
        index_name: Pinecone index to use; it is created if it does not exist
            (see ``create_or_get_pinecone_index``).
        embedding_fn: Embeddings object handed to the vector store. When
            ``None`` (the default) a fresh ``KorRobertaEmbeddings()`` is
            created per call.
        dimension: Vector dimension forwarded to index creation.
        namespace: Optional Pinecone namespace forwarded to the vector store.

    Returns:
        The constructed ``langchain_pinecone.Pinecone`` vector store.
    """
    from langchain_pinecone import Pinecone

    # A call in the signature default would run once at import time and be
    # shared by every caller; build the default lazily instead.
    if embedding_fn is None:
        embedding_fn = KorRobertaEmbeddings()

    index = create_or_get_pinecone_index(
        index_name,
        dimension,
    )
    vs = Pinecone(
        index,
        embedding_fn,
        pinecone_api_key=PINECONE_API_KEY,
        index_name=index_name,
        namespace=namespace,
    )
    print(vs)
    return vs
|
56 |
+
|
57 |
+
|
58 |
+
def build_pinecone_retrieval_chain(vectorstore):
    """Wrap ``vectorstore`` in a runnable that returns context plus the question.

    The returned ``RunnableParallel`` maps its input to a dict with two keys:
    ``"context"`` (output of ``vectorstore.as_retriever()``) and
    ``"question"`` (the input passed through unchanged).
    """
    return RunnableParallel(
        {
            "context": vectorstore.as_retriever(),
            "question": RunnablePassthrough(),
        }
    )
|
65 |
+
|
66 |
+
|
67 |
+
@st.cache_resource
def get_pinecone_retrieval_chain(collection_name):
    """Build the retrieval chain for the Pinecone index ``collection_name``.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the built
    chain instead of rebuilding it. The vector store uses
    ``KorRobertaEmbeddings``, dimension 768 and the fixed namespace "0221".
    """
    print("☑️ Building a new pinecone retrieval chain...")
    vectorstore = get_pinecone_vectorstore(
        index_name=collection_name,
        embedding_fn=KorRobertaEmbeddings(),
        dimension=768,
        namespace="0221",
    )
    return build_pinecone_retrieval_chain(vectorstore)
|
80 |
+
|
81 |
+
|
82 |
+
def rerun() -> None:
    """Ask Streamlit to rerun the script by delegating to ``st.rerun()``."""
    st.rerun()
|
84 |
+
|
85 |
+
|
86 |
+
st.title("이노션 데모")

# Build the retrieval chain (get_pinecone_retrieval_chain is wrapped in
# st.cache_resource) and keep a handle to it in the session state.
with st.spinner("환경 설정 중"):
    sst.retrieval_chain = get_pinecone_retrieval_chain(
        collection_name="innocean",
    )

if prompt := st.chat_input("정보 검색"):

    # Display user message in chat message container
    with st.chat_message("human"):
        st.markdown(prompt)

    # Get assistant response: the chain returns {"context": docs, "question": prompt}
    outputs = sst.retrieval_chain.invoke(prompt)
    print(outputs)
    retrieval_docs = outputs["context"]

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        if not retrieval_docs:
            # Nothing was retrieved; indexing [0] below would raise IndexError.
            st.markdown("검색 결과가 없습니다.")
        else:
            # Only the first retrieved document is shown.
            top_doc = retrieval_docs[0]
            st.markdown(top_doc.metadata["answer"])

            with st.expander("출처 보기", expanded=True):
                st.info(f"출처 페이지: {top_doc.metadata['page']}")
                st.markdown(top_doc.metadata["source_passage"])
                # Disabled: one tab per retrieved document (content + metadata).
                # tabs = st.tabs([f"doc{i}" for i in range(len(retrieval_docs))])
                # for i in range(len(retrieval_docs)):
                #     tabs[i].write(retrieval_docs[i].page_content)
                #     tabs[i].write(retrieval_docs[i].metadata)
|
embeddings.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import os
|
4 |
+
from langchain_core.embeddings import Embeddings
|
5 |
+
from transformers import AutoModel, AutoTokenizer
|
6 |
+
|
7 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
8 |
+
|
9 |
+
|
10 |
+
# Process-wide cache of loaded (model, tokenizer) pairs, keyed by model name,
# so the weights are not reloaded on every embedding call.
_ROBERTA_CACHE = {}


def _load_roberta(model_name: str):
    """Return the cached ``(model, tokenizer)`` pair for ``model_name``, loading it once."""
    if model_name not in _ROBERTA_CACHE:
        _ROBERTA_CACHE[model_name] = (
            AutoModel.from_pretrained(model_name),
            AutoTokenizer.from_pretrained(model_name),
        )
    return _ROBERTA_CACHE[model_name]


def get_roberta_embeddings(
    sentences: List[str], model_name: str = "BM-K/KoSimCSE-roberta"
) -> List[List[float]]:
    """Get sentence features of Korean input texts with BM-K/KoSimCSE-roberta.

    Args:
        sentences: Texts to embed; they are tokenized as one padded, truncated
            batch.
        model_name: Hugging Face model identifier to load. Defaults to the
            model this module was written for.

    Returns:
        One vector of floats per input sentence, in input order: the model's
        first output taken at token position 0 of each sequence. The vectors
        are expected to have 768 dimensions for the default model (matches
        ``KorRobertaEmbeddings.dimension``). An empty input yields ``[]``.
    """
    if not sentences:
        # Nothing to embed; skip tokenizing an empty batch and loading the model.
        return []

    # Local import: torch is needed here only for the no_grad context.
    import torch

    model, tokenizer = _load_roberta(model_name)
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

    # Inference only: do not build an autograd graph for the forward pass.
    with torch.no_grad():
        embeddings, _ = model(**inputs, return_dict=False)

    # Index 0 along the second axis selects the first-token vector of every
    # sequence, the same value the per-item ``embedding[0]`` lookup produced.
    return embeddings[:, 0].tolist()
|
26 |
+
|
27 |
+
|
28 |
+
class KorRobertaEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter backed by BM-K/KoSimCSE-roberta.

    Both methods delegate to ``get_roberta_embeddings``.
    """

    # Embedding size advertised by this adapter.
    dimension = 768

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in ``texts`` and return one vector per text."""
        vectors = get_roberta_embeddings(texts)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query by running it as a one-element batch."""
        batch = get_roberta_embeddings([text])
        return batch[0]
|
requirements.txt
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-i https://pypi.org/simple
|
2 |
+
aiohttp==3.9.3; python_version >= '3.8'
|
3 |
+
aiosignal==1.3.1; python_version >= '3.7'
|
4 |
+
altair==5.2.0; python_version >= '3.8'
|
5 |
+
annotated-types==0.6.0; python_version >= '3.8'
|
6 |
+
anyio==4.3.0; python_version >= '3.8'
|
7 |
+
async-timeout==4.0.3; python_version < '3.11'
|
8 |
+
attrs==23.2.0; python_version >= '3.7'
|
9 |
+
blinker==1.7.0; python_version >= '3.8'
|
10 |
+
cachetools==5.3.2; python_version >= '3.7'
|
11 |
+
certifi==2024.2.2; python_version >= '3.6'
|
12 |
+
charset-normalizer==3.3.2; python_full_version >= '3.7.0'
|
13 |
+
click==8.1.7; python_version >= '3.7'
|
14 |
+
dataclasses-json==0.6.4; python_version >= '3.7' and python_version < '4.0'
|
15 |
+
exceptiongroup==1.2.0; python_version < '3.11'
|
16 |
+
filelock==3.13.1; python_version >= '3.8'
|
17 |
+
frozenlist==1.4.1; python_version >= '3.8'
|
18 |
+
fsspec==2024.2.0; python_version >= '3.8'
|
19 |
+
gitdb==4.0.11; python_version >= '3.7'
|
20 |
+
gitpython==3.1.42; python_version >= '3.7'
|
21 |
+
huggingface-hub==0.20.3; python_full_version >= '3.8.0'
|
22 |
+
idna==3.6; python_version >= '3.5'
|
23 |
+
importlib-metadata==7.0.1; python_version >= '3.8'
|
24 |
+
jinja2==3.1.3; python_version >= '3.7'
|
25 |
+
jsonpatch==1.33; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6'
|
26 |
+
jsonpointer==2.4; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6'
|
27 |
+
jsonschema==4.21.1; python_version >= '3.8'
|
28 |
+
jsonschema-specifications==2023.12.1; python_version >= '3.8'
|
29 |
+
langchain==0.1.8; python_version < '4.0' and python_full_version >= '3.8.1'
|
30 |
+
langchain-community==0.0.21; python_version < '4.0' and python_full_version >= '3.8.1'
|
31 |
+
langchain-core==0.1.25; python_version < '4.0' and python_full_version >= '3.8.1'
|
32 |
+
langchain-pinecone==0.0.2; python_version < '3.13' and python_full_version >= '3.8.1'
|
33 |
+
langsmith==0.1.5; python_version < '4.0' and python_full_version >= '3.8.1'
|
34 |
+
markdown-it-py==3.0.0; python_version >= '3.8'
|
35 |
+
markupsafe==2.1.5; python_version >= '3.7'
|
36 |
+
marshmallow==3.20.2; python_version >= '3.8'
|
37 |
+
mdurl==0.1.2; python_version >= '3.7'
|
38 |
+
mpmath==1.3.0
|
39 |
+
multidict==6.0.5; python_version >= '3.7'
|
40 |
+
mypy-extensions==1.0.0; python_version >= '3.5'
|
41 |
+
networkx==3.2.1; python_version >= '3.9'
|
42 |
+
numpy==1.26.4; python_version >= '3.9'
|
43 |
+
packaging==23.2; python_version >= '3.7'
|
44 |
+
pandas==2.2.0; python_version >= '3.9'
|
45 |
+
pillow==10.2.0; python_version >= '3.8'
|
46 |
+
pinecone-client==3.0.3; python_version < '3.13' and python_version >= '3.8'
|
47 |
+
protobuf==4.25.3; python_version >= '3.8'
|
48 |
+
pyarrow==15.0.0; python_version >= '3.8'
|
49 |
+
pydantic==2.6.1; python_version >= '3.8'
|
50 |
+
pydantic-core==2.16.2; python_version >= '3.8'
|
51 |
+
pydeck==0.8.1b0; python_version >= '3.7'
|
52 |
+
pygments==2.17.2; python_version >= '3.7'
|
53 |
+
python-dateutil==2.8.2; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
|
54 |
+
pytz==2024.1
|
55 |
+
pyyaml==6.0.1; python_version >= '3.6'
|
56 |
+
referencing==0.33.0; python_version >= '3.8'
|
57 |
+
regex==2023.12.25; python_version >= '3.7'
|
58 |
+
requests==2.31.0; python_version >= '3.7'
|
59 |
+
rich==13.7.0; python_full_version >= '3.7.0'
|
60 |
+
rpds-py==0.18.0; python_version >= '3.8'
|
61 |
+
safetensors==0.4.2; python_version >= '3.7'
|
62 |
+
six==1.16.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
|
63 |
+
smmap==5.0.1; python_version >= '3.7'
|
64 |
+
sniffio==1.3.0; python_version >= '3.7'
|
65 |
+
sqlalchemy==2.0.27; python_version >= '3.7'
|
66 |
+
streamlit==1.31.1; python_version >= '3.8' and python_full_version != '3.9.7'
|
67 |
+
sympy==1.12; python_version >= '3.8'
|
68 |
+
tenacity==8.2.3; python_version >= '3.7'
|
69 |
+
tokenizers==0.15.2; python_version >= '3.7'
|
70 |
+
toml==0.10.2; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'
|
71 |
+
toolz==0.12.1; python_version >= '3.7'
|
72 |
+
torch==2.2.0; python_full_version >= '3.8.0'
|
73 |
+
tornado==6.4; python_version >= '3.8'
|
74 |
+
tqdm==4.66.2; python_version >= '3.7'
|
75 |
+
transformers==4.37.2; python_full_version >= '3.8.0'
|
76 |
+
typing-extensions==4.9.0; python_version >= '3.8'
|
77 |
+
typing-inspect==0.9.0
|
78 |
+
tzdata==2024.1; python_version >= '2'
|
79 |
+
tzlocal==5.2; python_version >= '3.8'
|
80 |
+
urllib3==2.2.1; python_version >= '3.8'
|
81 |
+
validators==0.22.0; python_version >= '3.8'
|
82 |
+
yarl==1.9.4; python_version >= '3.7'
|
83 |
+
zipp==3.17.0; python_version >= '3.8'
|