Spaces:
Sleeping
Sleeping
Sean-Case
commited on
Commit
·
78d71d4
1
Parent(s):
ba838fc
Added basic semantic search functionality
Browse files- app.py +347 -26
- search_funcs/chatfuncs.py +393 -0
- search_funcs/ingest.py +417 -0
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import nltk
|
2 |
-
|
3 |
nltk.download('names')
|
4 |
nltk.download('stopwords')
|
5 |
nltk.download('wordnet')
|
@@ -9,11 +9,26 @@ from search_funcs.fast_bm25 import BM25
|
|
9 |
from search_funcs.clean_funcs import initial_clean, get_lemma_tokens#, stem_sentence
|
10 |
from nltk import word_tokenize
|
11 |
|
|
|
12 |
|
13 |
import gradio as gr
|
14 |
import pandas as pd
|
15 |
import os
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
def prepare_input_data(in_file, text_column, clean="No", progress=gr.Progress()):
|
18 |
|
19 |
filename = in_file.name
|
@@ -63,7 +78,7 @@ def get_file_path_end(file_path):
|
|
63 |
|
64 |
return filename_without_extension
|
65 |
|
66 |
-
def save_prepared_data(in_file, prepared_text_list, in_df,
|
67 |
|
68 |
# Check if the list and the dataframe have the same length
|
69 |
if len(prepared_text_list) != len(in_df):
|
@@ -73,10 +88,10 @@ def save_prepared_data(in_file, prepared_text_list, in_df, in_column):
|
|
73 |
|
74 |
file_name = get_file_path_end(in_file.name) + "_cleaned" + file_end
|
75 |
|
76 |
-
prepared_text_df = pd.DataFrame(data={
|
77 |
|
78 |
# Drop original column from input file to reduce file size
|
79 |
-
in_df = in_df.drop(
|
80 |
|
81 |
prepared_df = pd.concat([in_df, prepared_text_df], axis = 1)
|
82 |
|
@@ -194,7 +209,7 @@ def read_file(filename):
|
|
194 |
elif file_type == 'parquet':
|
195 |
return pd.read_parquet(filename).reset_index().drop(["index", "Unnamed: 0"], axis=1, errors="ignore")
|
196 |
|
197 |
-
def put_columns_in_df(in_file,
|
198 |
'''
|
199 |
When file is loaded, update the column dropdown choices and change 'clean data' dropdown option to 'no'.
|
200 |
'''
|
@@ -213,12 +228,12 @@ def put_columns_in_df(in_file, in_column):
|
|
213 |
return gr.Dropdown(choices=concat_choices), gr.Dropdown(value="No", choices = ["Yes", "No"]),\
|
214 |
gr.Dropdown(choices=concat_choices)
|
215 |
|
216 |
-
def put_columns_in_join_df(in_file,
|
217 |
'''
|
218 |
When file is loaded, update the column dropdown choices and change 'clean data' dropdown option to 'no'.
|
219 |
'''
|
220 |
|
221 |
-
print("
|
222 |
|
223 |
new_choices = []
|
224 |
concat_choices = []
|
@@ -241,11 +256,293 @@ def dummy_function(gradio_component):
|
|
241 |
|
242 |
def display_info(info_component):
|
243 |
gr.Info(info_component)
|
244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
# ## Gradio app - BM25 search
|
246 |
block = gr.Blocks(theme = gr.themes.Base())
|
247 |
|
248 |
-
with block:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
|
250 |
corpus_state = gr.State()
|
251 |
data_state = gr.State(pd.DataFrame())
|
@@ -267,14 +564,18 @@ depends on factors such as the type of documents or queries. Information taken f
|
|
267 |
# Fast text search
|
268 |
Enter a text query below to search through a text data column and find relevant terms. It will only find terms containing the exact text you enter. Your data should contain at least 20 entries for the search to consistently return results.
|
269 |
""")
|
|
|
270 |
|
271 |
with gr.Tab(label="Search your data"):
|
|
|
|
|
|
|
272 |
with gr.Accordion(label = "Load in data", open=True):
|
273 |
-
|
274 |
with gr.Row():
|
275 |
-
|
276 |
|
277 |
-
|
278 |
|
279 |
|
280 |
with gr.Row():
|
@@ -291,6 +592,21 @@ depends on factors such as the type of documents or queries. Information taken f
|
|
291 |
with gr.Row():
|
292 |
output_single_text = gr.Textbox(label="Top result")
|
293 |
output_file = gr.File(label="File output")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
|
295 |
|
296 |
with gr.Tab(label="Advanced options"):
|
@@ -327,28 +643,33 @@ depends on factors such as the type of documents or queries. Information taken f
|
|
327 |
in_no_search_results_button.click(display_info, inputs=in_no_search_info)
|
328 |
|
329 |
|
330 |
-
|
|
|
331 |
in_join_file.upload(put_columns_in_join_df, inputs=[in_join_file, in_join_column], outputs=[in_join_column])
|
332 |
-
|
333 |
-
# Load in
|
334 |
-
|
335 |
then(fn=prepare_bm25, inputs=[corpus_state, in_k1, in_b, in_alpha], outputs=[load_finished_message]).\
|
336 |
-
then(fn=put_columns_in_df, inputs=[
|
337 |
-
|
338 |
-
#
|
|
|
|
|
339 |
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
|
|
344 |
|
345 |
-
|
346 |
-
|
347 |
|
348 |
# Dummy functions just to get dropdowns to work correctly with Gradio 3.50
|
349 |
-
|
350 |
search_df_join_column.change(dummy_function, search_df_join_column, None)
|
351 |
in_join_column.change(dummy_function, in_join_column, None)
|
|
|
352 |
|
353 |
block.queue().launch(debug=True)
|
354 |
|
|
|
1 |
import nltk
|
2 |
+
from typing import TypeVar
|
3 |
nltk.download('names')
|
4 |
nltk.download('stopwords')
|
5 |
nltk.download('wordnet')
|
|
|
9 |
from search_funcs.clean_funcs import initial_clean, get_lemma_tokens#, stem_sentence
|
10 |
from nltk import word_tokenize
|
11 |
|
12 |
+
PandasDataFrame = TypeVar('pd.core.frame.DataFrame')
|
13 |
|
14 |
import gradio as gr
|
15 |
import pandas as pd
|
16 |
import os
|
17 |
|
18 |
+
from itertools import compress
|
19 |
+
|
20 |
+
#from langchain.embeddings import HuggingFaceEmbeddings
|
21 |
+
#from langchain.vectorstores import FAISS
|
22 |
+
from transformers import AutoModel
|
23 |
+
|
24 |
+
import search_funcs.ingest as ing
|
25 |
+
import search_funcs.chatfuncs as chatf
|
26 |
+
|
27 |
+
# Import Chroma and instantiate a client. The default Chroma client is ephemeral, meaning it will not save to disk.
|
28 |
+
import chromadb
|
29 |
+
|
30 |
+
#collection = client.create_collection(name="my_collection")
|
31 |
+
|
32 |
def prepare_input_data(in_file, text_column, clean="No", progress=gr.Progress()):
|
33 |
|
34 |
filename = in_file.name
|
|
|
78 |
|
79 |
return filename_without_extension
|
80 |
|
81 |
+
def save_prepared_data(in_file, prepared_text_list, in_df, in_bm25_column):
|
82 |
|
83 |
# Check if the list and the dataframe have the same length
|
84 |
if len(prepared_text_list) != len(in_df):
|
|
|
88 |
|
89 |
file_name = get_file_path_end(in_file.name) + "_cleaned" + file_end
|
90 |
|
91 |
+
prepared_text_df = pd.DataFrame(data={in_bm25_column + "_cleaned":prepared_text_list})
|
92 |
|
93 |
# Drop original column from input file to reduce file size
|
94 |
+
in_df = in_df.drop(in_bm25_column, axis = 1)
|
95 |
|
96 |
prepared_df = pd.concat([in_df, prepared_text_df], axis = 1)
|
97 |
|
|
|
209 |
elif file_type == 'parquet':
|
210 |
return pd.read_parquet(filename).reset_index().drop(["index", "Unnamed: 0"], axis=1, errors="ignore")
|
211 |
|
212 |
+
def put_columns_in_df(in_file, in_bm25_column):
|
213 |
'''
|
214 |
When file is loaded, update the column dropdown choices and change 'clean data' dropdown option to 'no'.
|
215 |
'''
|
|
|
228 |
return gr.Dropdown(choices=concat_choices), gr.Dropdown(value="No", choices = ["Yes", "No"]),\
|
229 |
gr.Dropdown(choices=concat_choices)
|
230 |
|
231 |
+
def put_columns_in_join_df(in_file, in_bm25_column):
|
232 |
'''
|
233 |
When file is loaded, update the column dropdown choices and change 'clean data' dropdown option to 'no'.
|
234 |
'''
|
235 |
|
236 |
+
print("in_bm25_column")
|
237 |
|
238 |
new_choices = []
|
239 |
concat_choices = []
|
|
|
256 |
|
257 |
def display_info(info_component):
|
258 |
gr.Info(info_component)
|
259 |
+
|
260 |
+
embeddings_name = "jinaai/jina-embeddings-v2-small-en"
|
261 |
+
|
262 |
+
#embeddings_name = "BAAI/bge-base-en-v1.5"
|
263 |
+
import chromadb
|
264 |
+
from typing_extensions import Protocol
|
265 |
+
from chromadb import Documents, EmbeddingFunction, Embeddings
|
266 |
+
|
267 |
+
embeddings_model = AutoModel.from_pretrained(embeddings_name, trust_remote_code=True)
|
268 |
+
|
269 |
+
class MyEmbeddingFunction(EmbeddingFunction):
|
270 |
+
def __call__(self, input) -> Embeddings:
|
271 |
+
|
272 |
+
|
273 |
+
embeddings = []
|
274 |
+
for text in input:
|
275 |
+
embeddings.append(embeddings_model.encode(text))
|
276 |
+
|
277 |
+
return embeddings
|
278 |
+
|
279 |
+
|
280 |
+
def load_embeddings(embeddings_name = "jinaai/jina-embeddings-v2-small-en"):
|
281 |
+
'''
|
282 |
+
Load embeddings model and create a global variable based on it.
|
283 |
+
'''
|
284 |
+
|
285 |
+
# Import Chroma and instantiate a client. The default Chroma client is ephemeral, meaning it will not save to disk.
|
286 |
+
|
287 |
+
#else:
|
288 |
+
embeddings_func = AutoModel.from_pretrained(embeddings_name, trust_remote_code=True)
|
289 |
+
|
290 |
+
global embeddings
|
291 |
+
|
292 |
+
embeddings = embeddings_func
|
293 |
+
|
294 |
+
return embeddings
|
295 |
+
|
296 |
+
embeddings = load_embeddings(embeddings_name)
|
297 |
+
|
298 |
+
def docs_to_chroma_save(docs_out, embeddings=embeddings, progress=gr.Progress()):
|
299 |
+
'''
|
300 |
+
Takes a Langchain document class and saves it into a Chroma sqlite file.
|
301 |
+
'''
|
302 |
+
|
303 |
+
|
304 |
+
|
305 |
+
print(f"> Total split documents: {len(docs_out)}")
|
306 |
+
|
307 |
+
#print(docs_out)
|
308 |
+
|
309 |
+
page_contents = [doc.page_content for doc in docs_out]
|
310 |
+
page_meta = [doc.metadata for doc in docs_out]
|
311 |
+
ids_range = range(0,len(page_contents))
|
312 |
+
ids = [str(element) for element in ids_range]
|
313 |
+
|
314 |
+
embeddings_list = []
|
315 |
+
for page in progress.tqdm(page_contents, desc = "Preparing search index", unit = "rows"):
|
316 |
+
embeddings_list.append(embeddings.encode(sentences=page, max_length=1024).tolist())
|
317 |
+
|
318 |
+
|
319 |
+
client = chromadb.PersistentClient(path=".")
|
320 |
+
|
321 |
+
# Create a new Chroma collection to store the supporting evidence. We don't need to specify an embedding fuction, and the default will be used.
|
322 |
+
try:
|
323 |
+
collection = client.get_collection(name="my_collection")
|
324 |
+
client.delete_collection(name="my_collection")
|
325 |
+
except:
|
326 |
+
collection = client.create_collection(name="my_collection")
|
327 |
+
|
328 |
+
collection.add(
|
329 |
+
documents = page_contents,
|
330 |
+
embeddings = embeddings_list,
|
331 |
+
metadatas = page_meta,
|
332 |
+
ids = ids)
|
333 |
+
|
334 |
+
#chatf.vectorstore = vectorstore_func
|
335 |
+
|
336 |
+
out_message = "Document processing complete"
|
337 |
+
|
338 |
+
return out_message, collection
|
339 |
+
|
340 |
+
def jina_simple_retrieval(new_question_kworded, vectorstore, docs, k_val, out_passages,
|
341 |
+
vec_score_cut_off, vec_weight): # ,vectorstore, embeddings
|
342 |
+
|
343 |
+
from numpy.linalg import norm
|
344 |
+
|
345 |
+
cos_sim = lambda a,b: (a @ b.T) / (norm(a)*norm(b))
|
346 |
+
|
347 |
+
query = embeddings.encode(new_question_kworded)
|
348 |
+
|
349 |
+
# Calculate cosine similarity with each string in the list
|
350 |
+
cosine_similarities = [cos_sim(query, string_vector) for string_vector in vectorstore]
|
351 |
+
|
352 |
+
|
353 |
+
|
354 |
+
print(cosine_similarities)
|
355 |
+
|
356 |
+
|
357 |
+
|
358 |
+
#vectorstore=globals()["vectorstore"]
|
359 |
+
#embeddings=globals()["embeddings"]
|
360 |
+
doc_df = pd.DataFrame()
|
361 |
+
|
362 |
+
|
363 |
+
docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)
|
364 |
+
|
365 |
+
print("Docs from similarity search:")
|
366 |
+
print(docs)
|
367 |
+
|
368 |
+
# Keep only documents with a certain score
|
369 |
+
docs_len = [len(x[0].page_content) for x in docs]
|
370 |
+
docs_scores = [x[1] for x in docs]
|
371 |
+
|
372 |
+
# Only keep sources that are sufficiently relevant (i.e. similarity search score below threshold below)
|
373 |
+
score_more_limit = pd.Series(docs_scores) < vec_score_cut_off
|
374 |
+
docs_keep = list(compress(docs, score_more_limit))
|
375 |
+
|
376 |
+
if not docs_keep:
|
377 |
+
return [], pd.DataFrame(), []
|
378 |
+
|
379 |
+
# Only keep sources that are at least 100 characters long
|
380 |
+
length_more_limit = pd.Series(docs_len) >= 100
|
381 |
+
docs_keep = list(compress(docs_keep, length_more_limit))
|
382 |
+
|
383 |
+
if not docs_keep:
|
384 |
+
return [], pd.DataFrame(), []
|
385 |
+
|
386 |
+
docs_keep_as_doc = [x[0] for x in docs_keep]
|
387 |
+
docs_keep_length = len(docs_keep_as_doc)
|
388 |
+
|
389 |
+
|
390 |
+
|
391 |
+
if docs_keep_length == 1:
|
392 |
+
|
393 |
+
content=[]
|
394 |
+
meta_url=[]
|
395 |
+
score=[]
|
396 |
+
|
397 |
+
for item in docs_keep:
|
398 |
+
content.append(item[0].page_content)
|
399 |
+
meta_url.append(item[0].metadata['source'])
|
400 |
+
score.append(item[1])
|
401 |
+
|
402 |
+
# Create df from 'winning' passages
|
403 |
+
|
404 |
+
doc_df = pd.DataFrame(list(zip(content, meta_url, score)),
|
405 |
+
columns =['page_content', 'meta_url', 'score'])
|
406 |
+
|
407 |
+
docs_content = doc_df['page_content'].astype(str)
|
408 |
+
docs_url = doc_df['meta_url']
|
409 |
+
|
410 |
+
return docs_keep_as_doc, docs_content, docs_url
|
411 |
+
|
412 |
+
# Check for if more docs are removed than the desired output
|
413 |
+
if out_passages > docs_keep_length:
|
414 |
+
out_passages = docs_keep_length
|
415 |
+
k_val = docs_keep_length
|
416 |
+
|
417 |
+
vec_rank = [*range(1, docs_keep_length+1)]
|
418 |
+
vec_score = [(docs_keep_length/x)*vec_weight for x in vec_rank]
|
419 |
+
|
420 |
+
## Calculate final score based on three ranking methods
|
421 |
+
final_score = [a for a in zip(vec_score)]
|
422 |
+
final_rank = [sorted(final_score, reverse=True).index(x)+1 for x in final_score]
|
423 |
+
# Force final_rank to increment by 1 each time
|
424 |
+
final_rank = list(pd.Series(final_rank).rank(method='first'))
|
425 |
+
|
426 |
+
#print("final rank: " + str(final_rank))
|
427 |
+
#print("out_passages: " + str(out_passages))
|
428 |
+
|
429 |
+
best_rank_index_pos = []
|
430 |
+
|
431 |
+
for x in range(1,out_passages+1):
|
432 |
+
try:
|
433 |
+
best_rank_index_pos.append(final_rank.index(x))
|
434 |
+
except IndexError: # catch the error
|
435 |
+
pass
|
436 |
+
|
437 |
+
# Adjust best_rank_index_pos to
|
438 |
+
|
439 |
+
best_rank_pos_series = pd.Series(best_rank_index_pos)
|
440 |
+
|
441 |
+
|
442 |
+
docs_keep_out = [docs_keep[i] for i in best_rank_index_pos]
|
443 |
+
|
444 |
+
# Keep only 'best' options
|
445 |
+
docs_keep_as_doc = [x[0] for x in docs_keep_out]
|
446 |
+
|
447 |
+
# Make df of best options
|
448 |
+
doc_df = create_doc_df(docs_keep_out)
|
449 |
+
|
450 |
+
return docs_keep_as_doc, doc_df, docs_keep_out
|
451 |
+
|
452 |
+
def chroma_retrieval(new_question_kworded, vectorstore, docs, k_val, out_passages,
|
453 |
+
vec_score_cut_off, vec_weight): # ,vectorstore, embeddings
|
454 |
+
|
455 |
+
query = embeddings.encode(new_question_kworded).tolist()
|
456 |
+
|
457 |
+
docs = vectorstore.query(
|
458 |
+
query_embeddings=query,
|
459 |
+
n_results= 9999 # No practical limit on number of responses returned
|
460 |
+
#where={"metadata_field": "is_equal_to_this"},
|
461 |
+
#where_document={"$contains":"search_string"}
|
462 |
+
)
|
463 |
+
|
464 |
+
# Calculate cosine similarity with each string in the list
|
465 |
+
#cosine_similarities = [cos_sim(query, string_vector) for string_vector in vectorstore]
|
466 |
+
|
467 |
+
#print(docs)
|
468 |
+
|
469 |
+
#vectorstore=globals()["vectorstore"]
|
470 |
+
#embeddings=globals()["embeddings"]
|
471 |
+
df = pd.DataFrame(data={'ids': docs['ids'][0],
|
472 |
+
'documents': docs['documents'][0],
|
473 |
+
'metadatas':docs['metadatas'][0],
|
474 |
+
'distances':docs['distances'][0]#,
|
475 |
+
#'embeddings': docs['embeddings']
|
476 |
+
})
|
477 |
+
|
478 |
+
def create_docs_keep_from_df(df):
|
479 |
+
dict_out = {'ids' : [df['ids']],
|
480 |
+
'documents': [df['documents']],
|
481 |
+
'metadatas': [df['metadatas']],
|
482 |
+
'distances': [df['distances']],
|
483 |
+
'embeddings': None
|
484 |
+
}
|
485 |
+
return dict_out
|
486 |
+
|
487 |
+
# Prepare the DataFrame by transposing
|
488 |
+
df_docs = df#.apply(lambda x: x.explode()).reset_index(drop=True)
|
489 |
+
|
490 |
+
#print(df_docs)
|
491 |
+
|
492 |
+
|
493 |
+
# Keep only documents with a certain score
|
494 |
+
|
495 |
+
docs_scores = df_docs["distances"] #.astype(float)
|
496 |
+
|
497 |
+
#print(docs_scores)
|
498 |
+
|
499 |
+
# Only keep sources that are sufficiently relevant (i.e. similarity search score below threshold below)
|
500 |
+
score_more_limit = df_docs.loc[docs_scores < vec_score_cut_off, :]
|
501 |
+
docs_keep = create_docs_keep_from_df(score_more_limit) #list(compress(docs, score_more_limit))
|
502 |
+
|
503 |
+
#print(docs_keep)
|
504 |
+
|
505 |
+
if not docs_keep:
|
506 |
+
return 'No result found!', ""
|
507 |
+
|
508 |
+
# Only keep sources that are at least 100 characters long
|
509 |
+
docs_len = score_more_limit["documents"].str.len() >= 100
|
510 |
+
length_more_limit = score_more_limit.loc[docs_len, :] #pd.Series(docs_len) >= 100
|
511 |
+
docs_keep = create_docs_keep_from_df(length_more_limit) #list(compress(docs_keep, length_more_limit))
|
512 |
+
|
513 |
+
#print(docs_keep)
|
514 |
+
|
515 |
+
print(length_more_limit)
|
516 |
+
|
517 |
+
if not docs_keep:
|
518 |
+
return 'No result found!', ""
|
519 |
+
|
520 |
+
results_df_name = "semantic_search_result.csv"
|
521 |
+
length_more_limit.to_csv(results_df_name, index= None)
|
522 |
+
results_first_text = length_more_limit["documents"][0]
|
523 |
+
|
524 |
+
|
525 |
+
return results_first_text, results_df_name
|
526 |
+
|
527 |
# ## Gradio app - BM25 search
|
528 |
block = gr.Blocks(theme = gr.themes.Base())
|
529 |
|
530 |
+
with block:
|
531 |
+
|
532 |
+
ingest_text = gr.State()
|
533 |
+
ingest_metadata = gr.State()
|
534 |
+
ingest_docs = gr.State()
|
535 |
+
vectorstore_state = gr.State() # globals()["vectorstore"]
|
536 |
+
embeddings_state = gr.State() # globals()["embeddings"]
|
537 |
+
|
538 |
+
k_val = gr.State(100)
|
539 |
+
out_passages = gr.State(100)
|
540 |
+
vec_score_cut_off = gr.State(100)
|
541 |
+
vec_weight = gr.State(1)
|
542 |
+
|
543 |
+
docs_keep_as_doc_state = gr.State()
|
544 |
+
doc_df_state = gr.State()
|
545 |
+
docs_keep_out_state = gr.State()
|
546 |
|
547 |
corpus_state = gr.State()
|
548 |
data_state = gr.State(pd.DataFrame())
|
|
|
564 |
# Fast text search
|
565 |
Enter a text query below to search through a text data column and find relevant terms. It will only find terms containing the exact text you enter. Your data should contain at least 20 entries for the search to consistently return results.
|
566 |
""")
|
567 |
+
|
568 |
|
569 |
with gr.Tab(label="Search your data"):
|
570 |
+
with gr.Row():
|
571 |
+
current_source = gr.Textbox(label="Current data source(s)", value="None")
|
572 |
+
|
573 |
with gr.Accordion(label = "Load in data", open=True):
|
574 |
+
in_bm25_file = gr.File(label="Upload your search data here")
|
575 |
with gr.Row():
|
576 |
+
in_bm25_column = gr.Dropdown(label="Enter the name of the text column in the data file to search")
|
577 |
|
578 |
+
load_bm25_data_button = gr.Button(value="Load data")
|
579 |
|
580 |
|
581 |
with gr.Row():
|
|
|
592 |
with gr.Row():
|
593 |
output_single_text = gr.Textbox(label="Top result")
|
594 |
output_file = gr.File(label="File output")
|
595 |
+
|
596 |
+
|
597 |
+
with gr.Tab("Fuzzy/semantic search"):
|
598 |
+
with gr.Accordion("CSV/Excel file", open = True):
|
599 |
+
in_semantic_file = gr.File(label="Upload data file for semantic search")
|
600 |
+
in_semantic_column = gr.Dropdown(label="Enter the name of the text column in the data file to search")
|
601 |
+
load_semantic_data_button = gr.Button(value="Load in CSV/Excel file", variant="secondary", scale=0)
|
602 |
+
|
603 |
+
ingest_embed_out = gr.Textbox(label="File/web page preparation progress")
|
604 |
+
semantic_query = gr.Textbox(label="Enter semantic search query here")
|
605 |
+
semantic_submit = gr.Button(value="Start semantic search", variant="secondary", scale = 1)
|
606 |
+
|
607 |
+
with gr.Row():
|
608 |
+
semantic_output_single_text = gr.Textbox(label="Top result")
|
609 |
+
semantic_output_file = gr.File(label="File output")
|
610 |
|
611 |
|
612 |
with gr.Tab(label="Advanced options"):
|
|
|
643 |
in_no_search_results_button.click(display_info, inputs=in_no_search_info)
|
644 |
|
645 |
|
646 |
+
# Update dropdowns upon initial file load
|
647 |
+
in_bm25_file.upload(put_columns_in_df, inputs=[in_bm25_file, in_bm25_column], outputs=[in_bm25_column, in_clean_data, search_df_join_column])
|
648 |
in_join_file.upload(put_columns_in_join_df, inputs=[in_join_file, in_join_column], outputs=[in_join_column])
|
649 |
+
|
650 |
+
# Load in BM25 data
|
651 |
+
load_bm25_data_button.click(fn=prepare_input_data, inputs=[in_bm25_file, in_bm25_column, in_clean_data], outputs=[corpus_state, load_finished_message, data_state, output_file]).\
|
652 |
then(fn=prepare_bm25, inputs=[corpus_state, in_k1, in_b, in_alpha], outputs=[load_finished_message]).\
|
653 |
+
then(fn=put_columns_in_df, inputs=[in_bm25_file, in_bm25_column], outputs=[in_bm25_column, in_clean_data, search_df_join_column])
|
654 |
+
|
655 |
+
# BM25 search functions on click or enter
|
656 |
+
search_button.click(fn=bm25_search, inputs=[in_query, in_no_search_results, data_state, in_bm25_column, in_clean_data, in_join_file, in_join_column, search_df_join_column], outputs=[output_single_text, output_file, mod_query], api_name="search")
|
657 |
+
in_query.submit(fn=bm25_search, inputs=[in_query, in_no_search_results, data_state, in_bm25_column, in_clean_data, in_join_file, in_join_column, search_df_join_column], outputs=[output_single_text, output_file, mod_query])
|
658 |
|
659 |
+
# Load in a csv/excel file for semantic search
|
660 |
+
in_semantic_file.upload(put_columns_in_df, inputs=[in_semantic_file, in_semantic_column], outputs=[in_semantic_column, in_clean_data, search_df_join_column])
|
661 |
+
load_semantic_data_button.click(ing.parse_csv_or_excel, inputs=[in_semantic_file, in_semantic_column], outputs=[ingest_text, current_source]).\
|
662 |
+
then(ing.csv_excel_text_to_docs, inputs=[ingest_text, in_semantic_column], outputs=[ingest_docs]).\
|
663 |
+
then(docs_to_chroma_save, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state])
|
664 |
|
665 |
+
# Semantic search query
|
666 |
+
semantic_submit.click(chroma_retrieval, inputs=[semantic_query, vectorstore_state, ingest_docs, k_val,out_passages, vec_score_cut_off, vec_weight], outputs=[semantic_output_single_text, semantic_output_file], api_name="semantic")
|
667 |
|
668 |
# Dummy functions just to get dropdowns to work correctly with Gradio 3.50
|
669 |
+
in_bm25_column.change(dummy_function, in_bm25_column, None)
|
670 |
search_df_join_column.change(dummy_function, search_df_join_column, None)
|
671 |
in_join_column.change(dummy_function, in_join_column, None)
|
672 |
+
in_semantic_column.change(dummy_function, in_join_column, None)
|
673 |
|
674 |
block.queue().launch(debug=True)
|
675 |
|
search_funcs/chatfuncs.py
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import os
|
3 |
+
from typing import TypeVar, List
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
|
7 |
+
# Model packages
|
8 |
+
import torch.cuda
|
9 |
+
|
10 |
+
# Alternative model sources
|
11 |
+
#from dataclasses import asdict, dataclass
|
12 |
+
|
13 |
+
# Langchain functions
|
14 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
15 |
+
from langchain.docstore.document import Document
|
16 |
+
|
17 |
+
# For keyword extraction (not currently used)
|
18 |
+
#import nltk
|
19 |
+
#nltk.download('wordnet')
|
20 |
+
from nltk.corpus import stopwords
|
21 |
+
from nltk.tokenize import RegexpTokenizer
|
22 |
+
from nltk.stem import WordNetLemmatizer
|
23 |
+
|
24 |
+
# For Name Entity Recognition model
|
25 |
+
#from span_marker import SpanMarkerModel # Not currently used
|
26 |
+
|
27 |
+
|
28 |
+
import gradio as gr
|
29 |
+
|
30 |
+
torch.cuda.empty_cache()
|
31 |
+
|
32 |
+
PandasDataFrame = TypeVar('pd.core.frame.DataFrame')
|
33 |
+
|
34 |
+
embeddings = None # global variable setup
|
35 |
+
vectorstore = None # global variable setup
|
36 |
+
model_type = None # global variable setup
|
37 |
+
|
38 |
+
max_memory_length = 0 # How long should the memory of the conversation last?
|
39 |
+
|
40 |
+
full_text = "" # Define dummy source text (full text) just to enable highlight function to load
|
41 |
+
|
42 |
+
model = [] # Define empty list for model functions to run
|
43 |
+
tokenizer = [] # Define empty list for model functions to run
|
44 |
+
|
45 |
+
## Highlight text constants
|
46 |
+
hlt_chunk_size = 12
|
47 |
+
hlt_strat = [" ", ". ", "! ", "? ", ": ", "\n\n", "\n", ", "]
|
48 |
+
hlt_overlap = 4
|
49 |
+
|
50 |
+
## Initialise NER model ##
|
51 |
+
ner_model = []#SpanMarkerModel.from_pretrained("tomaarsen/span-marker-mbert-base-multinerd") # Not currently used
|
52 |
+
|
53 |
+
|
54 |
+
# Currently set gpu_layers to 0 even with cuda due to persistent bugs in implementation with cuda
|
55 |
+
if torch.cuda.is_available():
|
56 |
+
torch_device = "cuda"
|
57 |
+
gpu_layers = 0
|
58 |
+
else:
|
59 |
+
torch_device = "cpu"
|
60 |
+
gpu_layers = 0
|
61 |
+
|
62 |
+
print("Running on device:", torch_device)
|
63 |
+
threads = 6 #torch.get_num_threads()
|
64 |
+
print("CPU threads:", threads)
|
65 |
+
|
66 |
+
# Vectorstore funcs
|
67 |
+
|
68 |
+
# Prompt functions
|
69 |
+
|
70 |
+
def write_out_metadata_as_string(metadata_in):
|
71 |
+
metadata_string = [f"{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}" for d in metadata_in] # ['metadata']
|
72 |
+
return metadata_string
|
73 |
+
|
74 |
+
|
75 |
+
def determine_file_type(file_path):
|
76 |
+
"""
|
77 |
+
Determine the file type based on its extension.
|
78 |
+
|
79 |
+
Parameters:
|
80 |
+
file_path (str): Path to the file.
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
str: File extension (e.g., '.pdf', '.docx', '.txt', '.html').
|
84 |
+
"""
|
85 |
+
return os.path.splitext(file_path)[1].lower()
|
86 |
+
|
87 |
+
|
88 |
+
def create_doc_df(docs_keep_out):
|
89 |
+
# Extract content and metadata from 'winning' passages.
|
90 |
+
content=[]
|
91 |
+
meta=[]
|
92 |
+
meta_url=[]
|
93 |
+
page_section=[]
|
94 |
+
score=[]
|
95 |
+
|
96 |
+
doc_df = pd.DataFrame()
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
for item in docs_keep_out:
|
101 |
+
content.append(item[0].page_content)
|
102 |
+
meta.append(item[0].metadata)
|
103 |
+
meta_url.append(item[0].metadata['source'])
|
104 |
+
|
105 |
+
file_extension = determine_file_type(item[0].metadata['source'])
|
106 |
+
if (file_extension != ".csv") & (file_extension != ".xlsx"):
|
107 |
+
page_section.append(item[0].metadata['page_section'])
|
108 |
+
else: page_section.append("")
|
109 |
+
score.append(item[1])
|
110 |
+
|
111 |
+
# Create df from 'winning' passages
|
112 |
+
|
113 |
+
doc_df = pd.DataFrame(list(zip(content, meta, page_section, meta_url, score)),
|
114 |
+
columns =['page_content', 'metadata', 'page_section', 'meta_url', 'score'])
|
115 |
+
|
116 |
+
docs_content = doc_df['page_content'].astype(str)
|
117 |
+
doc_df['full_url'] = "https://" + doc_df['meta_url']
|
118 |
+
|
119 |
+
return doc_df
|
120 |
+
|
121 |
+
|
122 |
+
def get_expanded_passages(vectorstore, docs, width):
|
123 |
+
|
124 |
+
"""
|
125 |
+
Extracts expanded passages based on given documents and a width for context.
|
126 |
+
|
127 |
+
Parameters:
|
128 |
+
- vectorstore: The primary data source.
|
129 |
+
- docs: List of documents to be expanded.
|
130 |
+
- width: Number of documents to expand around a given document for context.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
- expanded_docs: List of expanded Document objects.
|
134 |
+
- doc_df: DataFrame representation of expanded_docs.
|
135 |
+
"""
|
136 |
+
|
137 |
+
from collections import defaultdict
|
138 |
+
|
139 |
+
def get_docs_from_vstore(vectorstore):
|
140 |
+
vector = vectorstore.docstore._dict
|
141 |
+
return list(vector.items())
|
142 |
+
|
143 |
+
def extract_details(docs_list):
|
144 |
+
docs_list_out = [tup[1] for tup in docs_list]
|
145 |
+
content = [doc.page_content for doc in docs_list_out]
|
146 |
+
meta = [doc.metadata for doc in docs_list_out]
|
147 |
+
return ''.join(content), meta[0], meta[-1]
|
148 |
+
|
149 |
+
def get_parent_content_and_meta(vstore_docs, width, target):
|
150 |
+
#target_range = range(max(0, target - width), min(len(vstore_docs), target + width + 1))
|
151 |
+
target_range = range(max(0, target), min(len(vstore_docs), target + width + 1)) # Now only selects extra passages AFTER the found passage
|
152 |
+
parent_vstore_out = [vstore_docs[i] for i in target_range]
|
153 |
+
|
154 |
+
content_str_out, meta_first_out, meta_last_out = [], [], []
|
155 |
+
for _ in parent_vstore_out:
|
156 |
+
content_str, meta_first, meta_last = extract_details(parent_vstore_out)
|
157 |
+
content_str_out.append(content_str)
|
158 |
+
meta_first_out.append(meta_first)
|
159 |
+
meta_last_out.append(meta_last)
|
160 |
+
return content_str_out, meta_first_out, meta_last_out
|
161 |
+
|
162 |
+
def merge_dicts_except_source(d1, d2):
|
163 |
+
merged = {}
|
164 |
+
for key in d1:
|
165 |
+
if key != "source":
|
166 |
+
merged[key] = str(d1[key]) + " to " + str(d2[key])
|
167 |
+
else:
|
168 |
+
merged[key] = d1[key] # or d2[key], based on preference
|
169 |
+
return merged
|
170 |
+
|
171 |
+
def merge_two_lists_of_dicts(list1, list2):
|
172 |
+
return [merge_dicts_except_source(d1, d2) for d1, d2 in zip(list1, list2)]
|
173 |
+
|
174 |
+
# Step 1: Filter vstore_docs
|
175 |
+
vstore_docs = get_docs_from_vstore(vectorstore)
|
176 |
+
doc_sources = {doc.metadata['source'] for doc, _ in docs}
|
177 |
+
vstore_docs = [(k, v) for k, v in vstore_docs if v.metadata.get('source') in doc_sources]
|
178 |
+
|
179 |
+
# Step 2: Group by source and proceed
|
180 |
+
vstore_by_source = defaultdict(list)
|
181 |
+
for k, v in vstore_docs:
|
182 |
+
vstore_by_source[v.metadata['source']].append((k, v))
|
183 |
+
|
184 |
+
expanded_docs = []
|
185 |
+
for doc, score in docs:
|
186 |
+
search_source = doc.metadata['source']
|
187 |
+
|
188 |
+
|
189 |
+
#if file_type == ".csv" | file_type == ".xlsx":
|
190 |
+
# content_str, meta_first, meta_last = get_parent_content_and_meta(vstore_by_source[search_source], 0, search_index)
|
191 |
+
|
192 |
+
#else:
|
193 |
+
search_section = doc.metadata['page_section']
|
194 |
+
parent_vstore_meta_section = [doc.metadata['page_section'] for _, doc in vstore_by_source[search_source]]
|
195 |
+
search_index = parent_vstore_meta_section.index(search_section) if search_section in parent_vstore_meta_section else -1
|
196 |
+
|
197 |
+
content_str, meta_first, meta_last = get_parent_content_and_meta(vstore_by_source[search_source], width, search_index)
|
198 |
+
meta_full = merge_two_lists_of_dicts(meta_first, meta_last)
|
199 |
+
|
200 |
+
expanded_doc = (Document(page_content=content_str[0], metadata=meta_full[0]), score)
|
201 |
+
expanded_docs.append(expanded_doc)
|
202 |
+
|
203 |
+
doc_df = pd.DataFrame()
|
204 |
+
|
205 |
+
doc_df = create_doc_df(expanded_docs) # Assuming you've defined the 'create_doc_df' function elsewhere
|
206 |
+
|
207 |
+
return expanded_docs, doc_df
|
208 |
+
|
209 |
+
def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hlt_chunk_size, hlt_strat:List=hlt_strat, hlt_overlap:int=hlt_overlap) -> str:
|
210 |
+
"""
|
211 |
+
Highlights occurrences of search_text within full_text.
|
212 |
+
|
213 |
+
Parameters:
|
214 |
+
- search_text (str): The text to be searched for within full_text.
|
215 |
+
- full_text (str): The text within which search_text occurrences will be highlighted.
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
- str: A string with occurrences of search_text highlighted.
|
219 |
+
|
220 |
+
Example:
|
221 |
+
>>> highlight_found_text("world", "Hello, world! This is a test. Another world awaits.")
|
222 |
+
'Hello, <mark style="color:black;">world</mark>! This is a test. Another <mark style="color:black;">world</mark> awaits.'
|
223 |
+
"""
|
224 |
+
|
225 |
+
def extract_text_from_input(text, i=0):
|
226 |
+
if isinstance(text, str):
|
227 |
+
return text.replace(" ", " ").strip()
|
228 |
+
elif isinstance(text, list):
|
229 |
+
return text[i][0].replace(" ", " ").strip()
|
230 |
+
else:
|
231 |
+
return ""
|
232 |
+
|
233 |
+
def extract_search_text_from_input(text):
|
234 |
+
if isinstance(text, str):
|
235 |
+
return text.replace(" ", " ").strip()
|
236 |
+
elif isinstance(text, list):
|
237 |
+
return text[-1][1].replace(" ", " ").strip()
|
238 |
+
else:
|
239 |
+
return ""
|
240 |
+
|
241 |
+
full_text = extract_text_from_input(full_text)
|
242 |
+
search_text = extract_search_text_from_input(search_text)
|
243 |
+
|
244 |
+
|
245 |
+
|
246 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
247 |
+
chunk_size=hlt_chunk_size,
|
248 |
+
separators=hlt_strat,
|
249 |
+
chunk_overlap=hlt_overlap,
|
250 |
+
)
|
251 |
+
sections = text_splitter.split_text(search_text)
|
252 |
+
|
253 |
+
found_positions = {}
|
254 |
+
for x in sections:
|
255 |
+
text_start_pos = 0
|
256 |
+
while text_start_pos != -1:
|
257 |
+
text_start_pos = full_text.find(x, text_start_pos)
|
258 |
+
if text_start_pos != -1:
|
259 |
+
found_positions[text_start_pos] = text_start_pos + len(x)
|
260 |
+
text_start_pos += 1
|
261 |
+
|
262 |
+
# Combine overlapping or adjacent positions
|
263 |
+
sorted_starts = sorted(found_positions.keys())
|
264 |
+
combined_positions = []
|
265 |
+
if sorted_starts:
|
266 |
+
current_start, current_end = sorted_starts[0], found_positions[sorted_starts[0]]
|
267 |
+
for start in sorted_starts[1:]:
|
268 |
+
if start <= (current_end + 10):
|
269 |
+
current_end = max(current_end, found_positions[start])
|
270 |
+
else:
|
271 |
+
combined_positions.append((current_start, current_end))
|
272 |
+
current_start, current_end = start, found_positions[start]
|
273 |
+
combined_positions.append((current_start, current_end))
|
274 |
+
|
275 |
+
# Construct pos_tokens
|
276 |
+
pos_tokens = []
|
277 |
+
prev_end = 0
|
278 |
+
for start, end in combined_positions:
|
279 |
+
if end-start > 15: # Only combine if there is a significant amount of matched text. Avoids picking up single words like 'and' etc.
|
280 |
+
pos_tokens.append(full_text[prev_end:start])
|
281 |
+
pos_tokens.append('<mark style="color:black;">' + full_text[start:end] + '</mark>')
|
282 |
+
prev_end = end
|
283 |
+
pos_tokens.append(full_text[prev_end:])
|
284 |
+
|
285 |
+
return "".join(pos_tokens)
|
286 |
+
|
287 |
+
|
288 |
+
# # Chat history functions
|
289 |
+
|
290 |
+
def clear_chat(chat_history_state, sources, chat_message, current_topic):
|
291 |
+
chat_history_state = []
|
292 |
+
sources = ''
|
293 |
+
chat_message = ''
|
294 |
+
current_topic = ''
|
295 |
+
|
296 |
+
return chat_history_state, sources, chat_message, current_topic
|
297 |
+
|
298 |
+
|
299 |
+
# Keyword functions
|
300 |
+
|
301 |
+
def remove_q_stopwords(question): # Remove stopwords from question. Not used at the moment
|
302 |
+
# Prepare keywords from question by removing stopwords
|
303 |
+
text = question.lower()
|
304 |
+
|
305 |
+
# Remove numbers
|
306 |
+
text = re.sub('[0-9]', '', text)
|
307 |
+
|
308 |
+
tokenizer = RegexpTokenizer(r'\w+')
|
309 |
+
text_tokens = tokenizer.tokenize(text)
|
310 |
+
#text_tokens = word_tokenize(text)
|
311 |
+
tokens_without_sw = [word for word in text_tokens if not word in stopwords]
|
312 |
+
|
313 |
+
# Remove duplicate words while preserving order
|
314 |
+
ordered_tokens = set()
|
315 |
+
result = []
|
316 |
+
for word in tokens_without_sw:
|
317 |
+
if word not in ordered_tokens:
|
318 |
+
ordered_tokens.add(word)
|
319 |
+
result.append(word)
|
320 |
+
|
321 |
+
|
322 |
+
|
323 |
+
new_question_keywords = ' '.join(result)
|
324 |
+
return new_question_keywords
|
325 |
+
|
326 |
+
def remove_q_ner_extractor(question):
|
327 |
+
|
328 |
+
predict_out = ner_model.predict(question)
|
329 |
+
|
330 |
+
|
331 |
+
|
332 |
+
predict_tokens = [' '.join(v for k, v in d.items() if k == 'span') for d in predict_out]
|
333 |
+
|
334 |
+
# Remove duplicate words while preserving order
|
335 |
+
ordered_tokens = set()
|
336 |
+
result = []
|
337 |
+
for word in predict_tokens:
|
338 |
+
if word not in ordered_tokens:
|
339 |
+
ordered_tokens.add(word)
|
340 |
+
result.append(word)
|
341 |
+
|
342 |
+
|
343 |
+
|
344 |
+
new_question_keywords = ' '.join(result).lower()
|
345 |
+
return new_question_keywords
|
346 |
+
|
347 |
+
def apply_lemmatize(text, wnl=WordNetLemmatizer()):
|
348 |
+
|
349 |
+
def prep_for_lemma(text):
|
350 |
+
|
351 |
+
# Remove numbers
|
352 |
+
text = re.sub('[0-9]', '', text)
|
353 |
+
print(text)
|
354 |
+
|
355 |
+
tokenizer = RegexpTokenizer(r'\w+')
|
356 |
+
text_tokens = tokenizer.tokenize(text)
|
357 |
+
#text_tokens = word_tokenize(text)
|
358 |
+
|
359 |
+
return text_tokens
|
360 |
+
|
361 |
+
tokens = prep_for_lemma(text)
|
362 |
+
|
363 |
+
def lem_word(word):
|
364 |
+
|
365 |
+
if len(word) > 3: out_word = wnl.lemmatize(word)
|
366 |
+
else: out_word = word
|
367 |
+
|
368 |
+
return out_word
|
369 |
+
|
370 |
+
return [lem_word(token) for token in tokens]
|
371 |
+
|
372 |
+
def keybert_keywords(text, n, kw_model):
|
373 |
+
tokens_lemma = apply_lemmatize(text)
|
374 |
+
lemmatised_text = ' '.join(tokens_lemma)
|
375 |
+
|
376 |
+
keywords_text = KeyBERT(model=kw_model).extract_keywords(lemmatised_text, stop_words='english', top_n=n,
|
377 |
+
keyphrase_ngram_range=(1, 1))
|
378 |
+
keywords_list = [item[0] for item in keywords_text]
|
379 |
+
|
380 |
+
return keywords_list
|
381 |
+
|
382 |
+
# Gradio functions
|
383 |
+
def turn_off_interactivity(user_message, history):
|
384 |
+
return gr.update(value="", interactive=False), history + [[user_message, None]]
|
385 |
+
|
386 |
+
def restore_interactivity():
|
387 |
+
return gr.update(interactive=True)
|
388 |
+
|
389 |
+
def update_message(dropdown_value):
|
390 |
+
return gr.Textbox.update(value=dropdown_value)
|
391 |
+
|
392 |
+
def hide_block():
|
393 |
+
return gr.Radio.update(visible=False)
|
search_funcs/ingest.py
ADDED
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ---
|
2 |
+
# jupyter:
|
3 |
+
# jupytext:
|
4 |
+
# formats: ipynb,py:light
|
5 |
+
# text_representation:
|
6 |
+
# extension: .py
|
7 |
+
# format_name: light
|
8 |
+
# format_version: '1.5'
|
9 |
+
# jupytext_version: 1.14.6
|
10 |
+
# kernelspec:
|
11 |
+
# display_name: Python 3 (ipykernel)
|
12 |
+
# language: python
|
13 |
+
# name: python3
|
14 |
+
# ---
|
15 |
+
|
16 |
+
# # Ingest website to FAISS
|
17 |
+
|
18 |
+
# ## Install/ import stuff we need
|
19 |
+
|
20 |
+
import os
|
21 |
+
from pathlib import Path
|
22 |
+
import re
|
23 |
+
import pandas as pd
|
24 |
+
from typing import TypeVar, List
|
25 |
+
|
26 |
+
#from langchain.embeddings import HuggingFaceEmbeddings # HuggingFaceInstructEmbeddings,
|
27 |
+
from langchain.vectorstores.faiss import FAISS
|
28 |
+
from langchain.vectorstores import Chroma
|
29 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
30 |
+
from langchain.docstore.document import Document
|
31 |
+
|
32 |
+
#from bs4 import BeautifulSoup
|
33 |
+
#from docx import Document as Doc
|
34 |
+
#from pypdf import PdfReader
|
35 |
+
|
36 |
+
PandasDataFrame = TypeVar('pd.core.frame.DataFrame')
|
37 |
+
# -
|
38 |
+
|
39 |
+
split_strat = ["\n\n", "\n", ". ", "! ", "? "]
|
40 |
+
chunk_size = 500
|
41 |
+
chunk_overlap = 0
|
42 |
+
start_index = True
|
43 |
+
|
44 |
+
## Parse files
|
45 |
+
def determine_file_type(file_path):
|
46 |
+
"""
|
47 |
+
Determine the file type based on its extension.
|
48 |
+
|
49 |
+
Parameters:
|
50 |
+
file_path (str): Path to the file.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
str: File extension (e.g., '.pdf', '.docx', '.txt', '.html').
|
54 |
+
"""
|
55 |
+
return os.path.splitext(file_path)[1].lower()
|
56 |
+
|
57 |
+
def parse_file(file_paths, text_column='text'):
|
58 |
+
"""
|
59 |
+
Accepts a list of file paths, determines each file's type based on its extension,
|
60 |
+
and passes it to the relevant parsing function.
|
61 |
+
|
62 |
+
Parameters:
|
63 |
+
file_paths (list): List of file paths.
|
64 |
+
text_column (str): Name of the column in CSV/Excel files that contains the text content.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
dict: A dictionary with file paths as keys and their parsed content (or error message) as values.
|
68 |
+
"""
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
if not isinstance(file_paths, list):
|
73 |
+
raise ValueError("Expected a list of file paths.")
|
74 |
+
|
75 |
+
extension_to_parser = {
|
76 |
+
# '.pdf': parse_pdf,
|
77 |
+
# '.docx': parse_docx,
|
78 |
+
# '.txt': parse_txt,
|
79 |
+
# '.html': parse_html,
|
80 |
+
# '.htm': parse_html, # Considering both .html and .htm for HTML files
|
81 |
+
'.csv': lambda file_path: parse_csv_or_excel(file_path, text_column),
|
82 |
+
'.xlsx': lambda file_path: parse_csv_or_excel(file_path, text_column)
|
83 |
+
}
|
84 |
+
|
85 |
+
parsed_contents = {}
|
86 |
+
file_names = []
|
87 |
+
|
88 |
+
for file_path in file_paths:
|
89 |
+
print(file_path.name)
|
90 |
+
#file = open(file_path.name, 'r')
|
91 |
+
#print(file)
|
92 |
+
file_extension = determine_file_type(file_path.name)
|
93 |
+
if file_extension in extension_to_parser:
|
94 |
+
parsed_contents[file_path.name] = extension_to_parser[file_extension](file_path.name)
|
95 |
+
else:
|
96 |
+
parsed_contents[file_path.name] = f"Unsupported file type: {file_extension}"
|
97 |
+
|
98 |
+
filename_end = get_file_path_end(file_path.name)
|
99 |
+
|
100 |
+
file_names.append(filename_end)
|
101 |
+
|
102 |
+
return parsed_contents, file_names
|
103 |
+
|
104 |
+
def text_regex_clean(text):
|
105 |
+
# Merge hyphenated words
|
106 |
+
text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)
|
107 |
+
# If a double newline ends in a letter, add a full stop.
|
108 |
+
text = re.sub(r'(?<=[a-zA-Z])\n\n', '.\n\n', text)
|
109 |
+
# Fix newlines in the middle of sentences
|
110 |
+
text = re.sub(r"(?<!\n\s)\n(?!\s\n)", " ", text.strip())
|
111 |
+
# Remove multiple newlines
|
112 |
+
text = re.sub(r"\n\s*\n", "\n\n", text)
|
113 |
+
text = re.sub(r" ", " ", text)
|
114 |
+
# Add full stops and new lines between words with no space between where the second one has a capital letter
|
115 |
+
text = re.sub(r'(?<=[a-z])(?=[A-Z])', '. \n\n', text)
|
116 |
+
|
117 |
+
return text
|
118 |
+
|
119 |
+
def parse_csv_or_excel(file_path, text_column = "text"):
|
120 |
+
"""
|
121 |
+
Read in a CSV or Excel file.
|
122 |
+
|
123 |
+
Parameters:
|
124 |
+
file_path (str): Path to the CSV file.
|
125 |
+
text_column (str): Name of the column in the CSV file that contains the text content.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
Pandas DataFrame: Dataframe output from file read
|
129 |
+
"""
|
130 |
+
|
131 |
+
#out_df = pd.DataFrame()
|
132 |
+
|
133 |
+
#for file_path in file_paths:
|
134 |
+
file_extension = determine_file_type(file_path.name)
|
135 |
+
file_name = get_file_path_end(file_path.name)
|
136 |
+
file_names = [file_name]
|
137 |
+
|
138 |
+
if file_extension == ".csv":
|
139 |
+
df = pd.read_csv(file_path.name, low_memory=False)
|
140 |
+
if text_column not in df.columns: return pd.DataFrame(), ['Please choose a valid column name']
|
141 |
+
df['source'] = file_name
|
142 |
+
df['page_section'] = ""
|
143 |
+
elif file_extension == ".xlsx":
|
144 |
+
df = pd.read_excel(file_path.name, engine='openpyxl')
|
145 |
+
if text_column not in df.columns: return pd.DataFrame(), ['Please choose a valid column name']
|
146 |
+
df['source'] = file_name
|
147 |
+
df['page_section'] = ""
|
148 |
+
else:
|
149 |
+
print(f"Unsupported file type: {file_extension}")
|
150 |
+
return pd.DataFrame(), ['Please choose a valid file type']
|
151 |
+
|
152 |
+
# file_names.append(file_name)
|
153 |
+
# out_df = pd.concat([out_df, df])
|
154 |
+
|
155 |
+
#if text_column not in df.columns:
|
156 |
+
# return f"Column '{text_column}' not found in {file_path}"
|
157 |
+
#text_out = " ".join(df[text_column].dropna().astype(str))
|
158 |
+
return df, file_names
|
159 |
+
|
160 |
+
def parse_excel(file_path, text_column):
|
161 |
+
"""
|
162 |
+
Read text from an Excel file.
|
163 |
+
|
164 |
+
Parameters:
|
165 |
+
file_path (str): Path to the Excel file.
|
166 |
+
text_column (str): Name of the column in the Excel file that contains the text content.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
Pandas DataFrame: Dataframe output from file read
|
170 |
+
"""
|
171 |
+
df = pd.read_excel(file_path, engine='openpyxl')
|
172 |
+
#if text_column not in df.columns:
|
173 |
+
# return f"Column '{text_column}' not found in {file_path}"
|
174 |
+
#text_out = " ".join(df[text_column].dropna().astype(str))
|
175 |
+
return df
|
176 |
+
|
177 |
+
def get_file_path_end(file_path):
|
178 |
+
match = re.search(r'(.*[\/\\])?(.+)$', file_path)
|
179 |
+
|
180 |
+
filename_end = match.group(2) if match else ''
|
181 |
+
|
182 |
+
return filename_end
|
183 |
+
|
184 |
+
# +
|
185 |
+
# Convert parsed text to docs
|
186 |
+
# -
|
187 |
+
|
188 |
+
def text_to_docs(text_dict: dict, chunk_size: int = chunk_size) -> List[Document]:
|
189 |
+
"""
|
190 |
+
Converts the output of parse_file (a dictionary of file paths to content)
|
191 |
+
to a list of Documents with metadata.
|
192 |
+
"""
|
193 |
+
|
194 |
+
doc_sections = []
|
195 |
+
parent_doc_sections = []
|
196 |
+
|
197 |
+
for file_path, content in text_dict.items():
|
198 |
+
ext = os.path.splitext(file_path)[1].lower()
|
199 |
+
|
200 |
+
# Depending on the file extension, handle the content
|
201 |
+
# if ext == '.pdf':
|
202 |
+
# docs, page_docs = pdf_text_to_docs(content, chunk_size)
|
203 |
+
# elif ext in ['.html', '.htm', '.txt', '.docx']:
|
204 |
+
# docs = html_text_to_docs(content, chunk_size)
|
205 |
+
if ext in ['.csv', '.xlsx']:
|
206 |
+
docs, page_docs = csv_excel_text_to_docs(content, chunk_size)
|
207 |
+
else:
|
208 |
+
print(f"Unsupported file type {ext} for {file_path}. Skipping.")
|
209 |
+
continue
|
210 |
+
|
211 |
+
|
212 |
+
filename_end = get_file_path_end(file_path)
|
213 |
+
|
214 |
+
#match = re.search(r'(.*[\/\\])?(.+)$', file_path)
|
215 |
+
#filename_end = match.group(2) if match else ''
|
216 |
+
|
217 |
+
# Add filename as metadata
|
218 |
+
for doc in docs: doc.metadata["source"] = filename_end
|
219 |
+
#for parent_doc in parent_docs: parent_doc.metadata["source"] = filename_end
|
220 |
+
|
221 |
+
doc_sections.extend(docs)
|
222 |
+
#parent_doc_sections.extend(parent_docs)
|
223 |
+
|
224 |
+
return doc_sections#, page_docs
|
225 |
+
|
226 |
+
|
227 |
+
def write_out_metadata_as_string(metadata_in):
|
228 |
+
# If metadata_in is a single dictionary, wrap it in a list
|
229 |
+
if isinstance(metadata_in, dict):
|
230 |
+
metadata_in = [metadata_in]
|
231 |
+
|
232 |
+
metadata_string = [f"{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}" for d in metadata_in] # ['metadata']
|
233 |
+
return metadata_string
|
234 |
+
|
235 |
+
def csv_excel_text_to_docs(df, text_column='text', chunk_size=None) -> List[Document]:
|
236 |
+
"""Converts a DataFrame's content to a list of Documents with metadata."""
|
237 |
+
|
238 |
+
doc_sections = []
|
239 |
+
df[text_column] = df[text_column].astype(str) # Ensure column is a string column
|
240 |
+
|
241 |
+
# For each row in the dataframe
|
242 |
+
for idx, row in df.iterrows():
|
243 |
+
# Extract the text content for the document
|
244 |
+
doc_content = row[text_column]
|
245 |
+
|
246 |
+
# Generate metadata containing other columns' data
|
247 |
+
metadata = {"row": idx + 1}
|
248 |
+
for col, value in row.items():
|
249 |
+
if col != text_column:
|
250 |
+
metadata[col] = value
|
251 |
+
|
252 |
+
metadata_string = write_out_metadata_as_string(metadata)[0]
|
253 |
+
|
254 |
+
|
255 |
+
|
256 |
+
# If chunk_size is provided, split the text into chunks
|
257 |
+
if chunk_size:
|
258 |
+
# Assuming you have a text splitter function similar to the PDF handling
|
259 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
260 |
+
chunk_size=chunk_size,
|
261 |
+
# Other arguments as required by the splitter
|
262 |
+
)
|
263 |
+
sections = text_splitter.split_text(doc_content)
|
264 |
+
|
265 |
+
|
266 |
+
# For each section, create a Document object
|
267 |
+
for i, section in enumerate(sections):
|
268 |
+
#section = '. '.join([metadata_string, section])
|
269 |
+
doc = Document(page_content=section,
|
270 |
+
metadata={**metadata, "section": i, "row_section": f"{metadata['row']}-{i}"})
|
271 |
+
doc_sections.append(doc)
|
272 |
+
else:
|
273 |
+
# If no chunk_size is provided, create a single Document object for the row
|
274 |
+
#doc_content = '. '.join([metadata_string, doc_content])
|
275 |
+
doc = Document(page_content=doc_content, metadata=metadata)
|
276 |
+
doc_sections.append(doc)
|
277 |
+
|
278 |
+
return doc_sections
|
279 |
+
|
280 |
+
# # Functions for working with documents after loading them back in
|
281 |
+
|
282 |
+
def pull_out_data(series):
|
283 |
+
|
284 |
+
# define a lambda function to convert each string into a tuple
|
285 |
+
to_tuple = lambda x: eval(x)
|
286 |
+
|
287 |
+
# apply the lambda function to each element of the series
|
288 |
+
series_tup = series.apply(to_tuple)
|
289 |
+
|
290 |
+
series_tup_content = list(zip(*series_tup))[1]
|
291 |
+
|
292 |
+
series = pd.Series(list(series_tup_content))#.str.replace("^Main post content", "", regex=True).str.strip()
|
293 |
+
|
294 |
+
return series
|
295 |
+
|
296 |
+
def docs_from_csv(df):
|
297 |
+
|
298 |
+
import ast
|
299 |
+
|
300 |
+
documents = []
|
301 |
+
|
302 |
+
page_content = pull_out_data(df["0"])
|
303 |
+
metadatas = pull_out_data(df["1"])
|
304 |
+
|
305 |
+
for x in range(0,len(df)):
|
306 |
+
new_doc = Document(page_content=page_content[x], metadata=metadatas[x])
|
307 |
+
documents.append(new_doc)
|
308 |
+
|
309 |
+
return documents
|
310 |
+
|
311 |
+
def docs_from_lists(docs, metadatas):
|
312 |
+
|
313 |
+
documents = []
|
314 |
+
|
315 |
+
for x, doc in enumerate(docs):
|
316 |
+
new_doc = Document(page_content=doc, metadata=metadatas[x])
|
317 |
+
documents.append(new_doc)
|
318 |
+
|
319 |
+
return documents
|
320 |
+
|
321 |
+
def docs_elements_from_csv_save(docs_path="documents.csv"):
|
322 |
+
|
323 |
+
documents = pd.read_csv(docs_path)
|
324 |
+
|
325 |
+
docs_out = docs_from_csv(documents)
|
326 |
+
|
327 |
+
out_df = pd.DataFrame(docs_out)
|
328 |
+
|
329 |
+
docs_content = pull_out_data(out_df[0].astype(str))
|
330 |
+
|
331 |
+
docs_meta = pull_out_data(out_df[1].astype(str))
|
332 |
+
|
333 |
+
doc_sources = [d['source'] for d in docs_meta]
|
334 |
+
|
335 |
+
return out_df, docs_content, docs_meta, doc_sources
|
336 |
+
|
337 |
+
# ## Create embeddings and save faiss vector store to the path specified in `save_to`
|
338 |
+
|
339 |
+
def load_embeddings(model_name = "BAAI/bge-base-en-v1.5"):
|
340 |
+
|
341 |
+
#if model_name == "hkunlp/instructor-large":
|
342 |
+
# embeddings_func = HuggingFaceInstructEmbeddings(model_name=model_name,
|
343 |
+
# embed_instruction="Represent the paragraph for retrieval: ",
|
344 |
+
# query_instruction="Represent the question for retrieving supporting documents: "
|
345 |
+
# )
|
346 |
+
|
347 |
+
#else:
|
348 |
+
embeddings_func = HuggingFaceEmbeddings(model_name=model_name)
|
349 |
+
|
350 |
+
global embeddings
|
351 |
+
|
352 |
+
embeddings = embeddings_func
|
353 |
+
|
354 |
+
return embeddings_func
|
355 |
+
|
356 |
+
def embed_faiss_save_to_zip(docs_out, save_to="faiss_lambeth_census_embedding", model_name = "BAAI/bge-base-en-v1.5"):
|
357 |
+
|
358 |
+
load_embeddings(model_name=model_name)
|
359 |
+
|
360 |
+
#embeddings_fast = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
361 |
+
|
362 |
+
print(f"> Total split documents: {len(docs_out)}")
|
363 |
+
|
364 |
+
vectorstore = FAISS.from_documents(documents=docs_out, embedding=embeddings)
|
365 |
+
|
366 |
+
|
367 |
+
if Path(save_to).exists():
|
368 |
+
vectorstore.save_local(folder_path=save_to)
|
369 |
+
|
370 |
+
print("> DONE")
|
371 |
+
print(f"> Saved to: {save_to}")
|
372 |
+
|
373 |
+
### Save as zip, then remove faiss/pkl files to allow for upload to huggingface
|
374 |
+
|
375 |
+
import shutil
|
376 |
+
|
377 |
+
shutil.make_archive(save_to, 'zip', save_to)
|
378 |
+
|
379 |
+
os.remove(save_to + "/index.faiss")
|
380 |
+
os.remove(save_to + "/index.pkl")
|
381 |
+
|
382 |
+
shutil.move(save_to + '.zip', save_to + "/" + save_to + '.zip')
|
383 |
+
|
384 |
+
return vectorstore
|
385 |
+
|
386 |
+
def docs_to_chroma_save(embeddings, docs_out:PandasDataFrame, save_to:str):
|
387 |
+
print(f"> Total split documents: {len(docs_out)}")
|
388 |
+
|
389 |
+
vectordb = Chroma.from_documents(documents=docs_out,
|
390 |
+
embedding=embeddings,
|
391 |
+
persist_directory=save_to)
|
392 |
+
|
393 |
+
# persiste the db to disk
|
394 |
+
vectordb.persist()
|
395 |
+
|
396 |
+
print("> DONE")
|
397 |
+
print(f"> Saved to: {save_to}")
|
398 |
+
|
399 |
+
return vectordb
|
400 |
+
|
401 |
+
def sim_search_local_saved_vec(query, k_val, save_to="faiss_lambeth_census_embedding"):
|
402 |
+
|
403 |
+
load_embeddings()
|
404 |
+
|
405 |
+
docsearch = FAISS.load_local(folder_path=save_to, embeddings=embeddings)
|
406 |
+
|
407 |
+
|
408 |
+
display(Markdown(question))
|
409 |
+
|
410 |
+
search = docsearch.similarity_search_with_score(query, k=k_val)
|
411 |
+
|
412 |
+
for item in search:
|
413 |
+
print(item[0].page_content)
|
414 |
+
print(f"Page: {item[0].metadata['source']}")
|
415 |
+
print(f"Date: {item[0].metadata['date']}")
|
416 |
+
print(f"Score: {item[1]}")
|
417 |
+
print("---")
|