Joshua Sundance Bailey committed on
Commit
c132355
1 Parent(s): d94be33

rm bm25 & fix docstore kwarg

Browse files
langchain-streamlit-demo/app.py CHANGED
@@ -17,9 +17,8 @@ from streamlit_feedback import streamlit_feedback
17
  from defaults import default_values
18
 
19
  from llm_resources import (
20
- get_runnable,
21
- get_llm,
22
- get_texts_and_retriever,
23
  get_texts_and_multiretriever,
24
  StreamHandler,
25
  )
 
17
  from defaults import default_values
18
 
19
  from llm_resources import (
20
+ get_runnable,
21
+ get_llm,
 
22
  get_texts_and_multiretriever,
23
  StreamHandler,
24
  )
langchain-streamlit-demo/llm_resources.py CHANGED
@@ -11,7 +11,7 @@ from langchain.chat_models import (
11
  )
12
  from langchain.document_loaders import PyPDFLoader
13
  from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
14
- from langchain.retrievers import BM25Retriever, EnsembleRetriever
15
  from langchain.schema import Document, BaseRetriever
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from langchain.vectorstores import FAISS
@@ -116,48 +116,6 @@ def get_llm(
116
  return None
117
 
118
 
119
- def get_texts_and_retriever(
120
- uploaded_file_bytes: bytes,
121
- openai_api_key: str,
122
- chunk_size: int = DEFAULT_CHUNK_SIZE,
123
- chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
124
- k: int = DEFAULT_RETRIEVER_K,
125
- azure_kwargs: Optional[Dict[str, str]] = None,
126
- use_azure: bool = False,
127
- ) -> Tuple[List[Document], BaseRetriever]:
128
- with NamedTemporaryFile() as temp_file:
129
- temp_file.write(uploaded_file_bytes)
130
- temp_file.seek(0)
131
-
132
- loader = PyPDFLoader(temp_file.name)
133
- documents = loader.load()
134
- text_splitter = RecursiveCharacterTextSplitter(
135
- chunk_size=chunk_size,
136
- chunk_overlap=chunk_overlap,
137
- )
138
- texts = text_splitter.split_documents(documents)
139
- embeddings_kwargs = {"openai_api_key": openai_api_key}
140
- if use_azure and azure_kwargs:
141
- azure_kwargs["azure_endpoint"] = azure_kwargs.pop("openai_api_base")
142
- embeddings_kwargs.update(azure_kwargs)
143
- embeddings = AzureOpenAIEmbeddings(**embeddings_kwargs)
144
- else:
145
- embeddings = OpenAIEmbeddings(**embeddings_kwargs)
146
-
147
- bm25_retriever = BM25Retriever.from_documents(texts)
148
- bm25_retriever.k = k
149
-
150
- faiss_vectorstore = FAISS.from_documents(texts, embeddings)
151
- faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": k})
152
-
153
- ensemble_retriever = EnsembleRetriever(
154
- retrievers=[bm25_retriever, faiss_retriever],
155
- weights=[0.5, 0.5],
156
- )
157
-
158
- return texts, ensemble_retriever
159
-
160
-
161
  def get_texts_and_multiretriever(
162
  uploaded_file_bytes: bytes,
163
  openai_api_key: str,
@@ -204,7 +162,7 @@ def get_texts_and_multiretriever(
204
  multivectorstore = FAISS.from_documents(sub_texts, embeddings)
205
  multivector_retriever = MultiVectorRetriever(
206
  vectorstore=multivectorstore,
207
- base_store=store,
208
  id_key=id_key,
209
  )
210
  multivector_retriever.docstore.mset(list(zip(text_ids, texts)))
 
11
  )
12
  from langchain.document_loaders import PyPDFLoader
13
  from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
14
+ from langchain.retrievers import EnsembleRetriever
15
  from langchain.schema import Document, BaseRetriever
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from langchain.vectorstores import FAISS
 
116
  return None
117
 
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  def get_texts_and_multiretriever(
120
  uploaded_file_bytes: bytes,
121
  openai_api_key: str,
 
162
  multivectorstore = FAISS.from_documents(sub_texts, embeddings)
163
  multivector_retriever = MultiVectorRetriever(
164
  vectorstore=multivectorstore,
165
+ docstore=store,
166
  id_key=id_key,
167
  )
168
  multivector_retriever.docstore.mset(list(zip(text_ids, texts)))
requirements.txt CHANGED
@@ -7,7 +7,6 @@ openai==1.3.8
7
  pillow>=10.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
8
  pyarrow>=14.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
9
  pypdf==3.17.2
10
- rank_bm25==0.2.2
11
  streamlit==1.29.0
12
  streamlit-feedback==0.1.3
13
  tiktoken==0.5.2
 
7
  pillow>=10.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
8
  pyarrow>=14.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
9
  pypdf==3.17.2
 
10
  streamlit==1.29.0
11
  streamlit-feedback==0.1.3
12
  tiktoken==0.5.2