WJL committed on
Commit
9e88bc1
·
1 Parent(s): ab1dc24

feat: vectorsearch-based QA

Browse files
Files changed (3) hide show
  1. app.py +112 -2
  2. embeddings.py +37 -0
  3. requirements.txt +83 -0
app.py CHANGED
@@ -1,4 +1,114 @@
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
1
+ from embeddings import KorRobertaEmbeddings
2
+
3
  import streamlit as st
4
+ from streamlit import session_state as sst
5
+
6
+ from langchain_core.runnables import (
7
+ RunnablePassthrough,
8
+ RunnableParallel,
9
+ )
10
+
11
# Pinecone credential, read from Streamlit's secrets store at import time so
# that it never has to be hard-coded in this file.
PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
12
+
13
+
14
def create_or_get_pinecone_index(index_name: str, dimension: int = 768):
    """Return a handle to the Pinecone index called *index_name*.

    If no index with that name is listed for the account, a new serverless
    index (cosine metric, ``aws`` / ``us-west-2``) is created first.

    Args:
        index_name: Name of the Pinecone index to open or create.
        dimension: Vector dimension used only when the index is created.

    Returns:
        The index object returned by ``client.Index(index_name)``.
    """
    from pinecone import Pinecone, ServerlessSpec

    client = Pinecone(api_key=PINECONE_API_KEY)
    known_names = [index["name"] for index in client.list_indexes()]
    already_exists = index_name in known_names

    if not already_exists:
        client.create_index(
            name=index_name,
            dimension=dimension,
            metric="cosine",
            spec=ServerlessSpec("aws", "us-west-2"),
        )

    pc_index = client.Index(index_name)
    if already_exists:
        print("☑️ Got the existing Pinecone index")
    else:
        print("☑️ Created a new Pinecone index")

    # Dump the index statistics to stdout for quick inspection.
    print(pc_index.describe_index_stats())
    return pc_index
33
+
34
+
35
def get_pinecone_vectorstore(
    index_name: str,
    embedding_fn=None,
    dimension: int = 768,
    namespace: str = None,
):
    """Build a LangChain Pinecone vector store on top of *index_name*.

    Args:
        index_name: Name of the Pinecone index to open (created if missing).
        embedding_fn: Embeddings object handed to the vector store. When
            omitted, a fresh ``KorRobertaEmbeddings`` is created per call
            instead of one instance being built at import time and shared
            through the default argument.
        dimension: Vector dimension, forwarded to index creation.
        namespace: Optional Pinecone namespace to scope the store to.

    Returns:
        The ``langchain_pinecone.Pinecone`` vector store instance.
    """
    from langchain_pinecone import Pinecone

    if embedding_fn is None:
        embedding_fn = KorRobertaEmbeddings()

    index = create_or_get_pinecone_index(
        index_name,
        dimension,
    )
    vs = Pinecone(
        index,
        embedding_fn,
        pinecone_api_key=PINECONE_API_KEY,
        index_name=index_name,
        namespace=namespace,
    )
    print(vs)
    return vs
56
+
57
+
58
def build_pinecone_retrieval_chain(vectorstore):
    """Wrap *vectorstore* in a runnable that returns context and question.

    The returned ``RunnableParallel`` feeds its input to the store's
    retriever under the ``"context"`` key and passes the same input through
    unchanged under the ``"question"`` key.
    """
    branches = {
        "context": vectorstore.as_retriever(),
        "question": RunnablePassthrough(),
    }
    return RunnableParallel(branches)
65
+
66
+
67
@st.cache_resource
def get_pinecone_retrieval_chain(collection_name):
    """Create the retrieval chain for the index named *collection_name*.

    Decorated with ``st.cache_resource``; the store is built with
    ``KorRobertaEmbeddings``, dimension 768 and the ``"0221"`` namespace.
    """
    print("☑️ Building a new pinecone retrieval chain...")
    return build_pinecone_retrieval_chain(
        get_pinecone_vectorstore(
            index_name=collection_name,
            embedding_fn=KorRobertaEmbeddings(),
            dimension=768,
            namespace="0221",
        )
    )
80
+
81
+
82
def rerun():
    """Thin wrapper that simply calls ``st.rerun()``."""
    st.rerun()
84
+
85
+
86
st.title("이노션 데모")

# Build the retrieval chain (get_pinecone_retrieval_chain is wrapped in
# st.cache_resource) and keep a reference in the session state.
with st.spinner("환경 설정 중"):
    sst.retrieval_chain = get_pinecone_retrieval_chain(
        collection_name="innocean",
    )

if prompt := st.chat_input("정보 검색"):

    # Display user message in chat message container
    with st.chat_message("human"):
        st.markdown(prompt)

    # Get assistant response
    outputs = sst.retrieval_chain.invoke(prompt)
    print(outputs)
    retrieval_docs = outputs["context"]

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        if not retrieval_docs:
            # The retriever may return no documents; indexing [0] would
            # raise IndexError, so tell the user instead.
            st.warning("검색 결과가 없습니다.")
        else:
            # Only the top-ranked document is shown.
            top_metadata = retrieval_docs[0].metadata
            st.markdown(top_metadata["answer"])

            with st.expander("출처 보기", expanded=True):
                st.info(f"출처 페이지: {top_metadata['page']}")
                st.markdown(top_metadata["source_passage"])
                # tabs = st.tabs([f"doc{i}" for i in range(len(retrieval_docs))])
                # for i in range(len(retrieval_docs)):
                #     tabs[i].write(retrieval_docs[i].page_content)
                #     tabs[i].write(retrieval_docs[i].metadata)
embeddings.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import os
4
+ from langchain_core.embeddings import Embeddings
5
+ from transformers import AutoModel, AutoTokenizer
6
+
7
# Turn off parallelism in the `tokenizers` library for this process
# (presumably to avoid its fork-related warning; confirm before removing).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
+
9
+
10
# Process-wide cache so the model and tokenizer are loaded only once instead
# of on every embedding call.
_ROBERTA_CACHE = {}


def _load_roberta(model_name: str = "BM-K/KoSimCSE-roberta"):
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use."""
    if model_name not in _ROBERTA_CACHE:
        _ROBERTA_CACHE[model_name] = (
            AutoModel.from_pretrained(model_name),
            AutoTokenizer.from_pretrained(model_name),
        )
    return _ROBERTA_CACHE[model_name]


def get_roberta_embeddings(sentences: List[str]) -> List[List[float]]:
    """Get features of Korean input texts w/ BM-K/KoSimCSE-roberta.

    Args:
        sentences: Texts to embed. An empty list yields an empty result
            without loading the model.

    Returns:
        One vector (list of floats, dimension 768) per input sentence, in
        input order.
    """
    if not sentences:
        return []

    # Imported locally so this function does not depend on a file-level
    # torch import.
    import torch

    model, tokenizer = _load_roberta()
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        embeddings, _ = model(**inputs, return_dict=False)

    # Keep the first token's vector of each sequence (presumably the [CLS]
    # position) as that sentence's embedding.
    return [embedding[0].detach().numpy().tolist() for embedding in embeddings]
26
+
27
+
28
class KorRobertaEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around ``get_roberta_embeddings``
    (BM-K/KoSimCSE-roberta feature extraction)."""

    # Size of each embedding vector.
    dimension = 768

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in *texts* and return one vector per text."""
        vectors = get_roberta_embeddings(texts)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector."""
        vectors = get_roberta_embeddings([text])
        return vectors[0]
requirements.txt ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -i https://pypi.org/simple
2
+ aiohttp==3.9.3; python_version >= '3.8'
3
+ aiosignal==1.3.1; python_version >= '3.7'
4
+ altair==5.2.0; python_version >= '3.8'
5
+ annotated-types==0.6.0; python_version >= '3.8'
6
+ anyio==4.3.0; python_version >= '3.8'
7
+ async-timeout==4.0.3; python_version < '3.11'
8
+ attrs==23.2.0; python_version >= '3.7'
9
+ blinker==1.7.0; python_version >= '3.8'
10
+ cachetools==5.3.2; python_version >= '3.7'
11
+ certifi==2024.2.2; python_version >= '3.6'
12
+ charset-normalizer==3.3.2; python_full_version >= '3.7.0'
13
+ click==8.1.7; python_version >= '3.7'
14
+ dataclasses-json==0.6.4; python_version >= '3.7' and python_version < '4.0'
15
+ exceptiongroup==1.2.0; python_version < '3.11'
16
+ filelock==3.13.1; python_version >= '3.8'
17
+ frozenlist==1.4.1; python_version >= '3.8'
18
+ fsspec==2024.2.0; python_version >= '3.8'
19
+ gitdb==4.0.11; python_version >= '3.7'
20
+ gitpython==3.1.42; python_version >= '3.7'
21
+ huggingface-hub==0.20.3; python_full_version >= '3.8.0'
22
+ idna==3.6; python_version >= '3.5'
23
+ importlib-metadata==7.0.1; python_version >= '3.8'
24
+ jinja2==3.1.3; python_version >= '3.7'
25
+ jsonpatch==1.33; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6'
26
+ jsonpointer==2.4; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6'
27
+ jsonschema==4.21.1; python_version >= '3.8'
28
+ jsonschema-specifications==2023.12.1; python_version >= '3.8'
29
+ langchain==0.1.8; python_version < '4.0' and python_full_version >= '3.8.1'
30
+ langchain-community==0.0.21; python_version < '4.0' and python_full_version >= '3.8.1'
31
+ langchain-core==0.1.25; python_version < '4.0' and python_full_version >= '3.8.1'
32
+ langchain-pinecone==0.0.2; python_version < '3.13' and python_full_version >= '3.8.1'
33
+ langsmith==0.1.5; python_version < '4.0' and python_full_version >= '3.8.1'
34
+ markdown-it-py==3.0.0; python_version >= '3.8'
35
+ markupsafe==2.1.5; python_version >= '3.7'
36
+ marshmallow==3.20.2; python_version >= '3.8'
37
+ mdurl==0.1.2; python_version >= '3.7'
38
+ mpmath==1.3.0
39
+ multidict==6.0.5; python_version >= '3.7'
40
+ mypy-extensions==1.0.0; python_version >= '3.5'
41
+ networkx==3.2.1; python_version >= '3.9'
42
+ numpy==1.26.4; python_version >= '3.9'
43
+ packaging==23.2; python_version >= '3.7'
44
+ pandas==2.2.0; python_version >= '3.9'
45
+ pillow==10.2.0; python_version >= '3.8'
46
+ pinecone-client==3.0.3; python_version < '3.13' and python_version >= '3.8'
47
+ protobuf==4.25.3; python_version >= '3.8'
48
+ pyarrow==15.0.0; python_version >= '3.8'
49
+ pydantic==2.6.1; python_version >= '3.8'
50
+ pydantic-core==2.16.2; python_version >= '3.8'
51
+ pydeck==0.8.1b0; python_version >= '3.7'
52
+ pygments==2.17.2; python_version >= '3.7'
53
+ python-dateutil==2.8.2; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
54
+ pytz==2024.1
55
+ pyyaml==6.0.1; python_version >= '3.6'
56
+ referencing==0.33.0; python_version >= '3.8'
57
+ regex==2023.12.25; python_version >= '3.7'
58
+ requests==2.31.0; python_version >= '3.7'
59
+ rich==13.7.0; python_full_version >= '3.7.0'
60
+ rpds-py==0.18.0; python_version >= '3.8'
61
+ safetensors==0.4.2; python_version >= '3.7'
62
+ six==1.16.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
63
+ smmap==5.0.1; python_version >= '3.7'
64
+ sniffio==1.3.0; python_version >= '3.7'
65
+ sqlalchemy==2.0.27; python_version >= '3.7'
66
+ streamlit==1.31.1; python_version >= '3.8' and python_full_version != '3.9.7'
67
+ sympy==1.12; python_version >= '3.8'
68
+ tenacity==8.2.3; python_version >= '3.7'
69
+ tokenizers==0.15.2; python_version >= '3.7'
70
+ toml==0.10.2; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'
71
+ toolz==0.12.1; python_version >= '3.7'
72
+ torch==2.2.0; python_full_version >= '3.8.0'
73
+ tornado==6.4; python_version >= '3.8'
74
+ tqdm==4.66.2; python_version >= '3.7'
75
+ transformers==4.37.2; python_full_version >= '3.8.0'
76
+ typing-extensions==4.9.0; python_version >= '3.8'
77
+ typing-inspect==0.9.0
78
+ tzdata==2024.1; python_version >= '2'
79
+ tzlocal==5.2; python_version >= '3.8'
80
+ urllib3==2.2.1; python_version >= '3.8'
81
+ validators==0.22.0; python_version >= '3.8'
82
+ yarl==1.9.4; python_version >= '3.7'
83
+ zipp==3.17.0; python_version >= '3.8'