# RAG embedding utilities: chunk documents, embed them with a
# SentenceTransformer model, store/retrieve vectors via ChromaDB,
# and rank chunks by cosine similarity against a query.
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from modules.pdfExtractor import PdfConverter
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
# model = SentenceTransformer(
# "thenlper/gte-base", # switch to en/zh for English or Chinese
# trust_remote_code=True
# )
# model.save(os.path.join(os.getcwd(), "embeddingModel"))
def contextChunks(document_text, chunk_size, chunk_overlap):
    """Split raw document text into overlapping character chunks.

    Args:
        document_text: full text of the document as one string.
        chunk_size: target size (characters) of each chunk.
        chunk_overlap: characters shared between consecutive chunks.

    Returns:
        List of chunk strings, in document order.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    pieces = splitter.split_documents([Document(page_content=document_text)])
    return [piece.page_content for piece in pieces]
def contextEmbedding(model, text_content_chunks):
    """Embed each chunk independently.

    Each chunk is encoded as a batch of one, so the result is one
    batch-of-one encoding per chunk (not a single stacked matrix).

    Args:
        model: object exposing `encode(list_of_str)`.
        text_content_chunks: chunk strings to embed.

    Returns:
        List with one encoding result per input chunk.
    """
    embeddings = []
    for chunk in text_content_chunks:
        embeddings.append(model.encode([chunk]))
    return embeddings
def contextEmbeddingChroma(model, text_content_chunks, db_client, db_path):
    """Embed chunks and persist documents + vectors + ids in Chroma.

    Args:
        model: object exposing `encode(list_of_str)`.
        text_content_chunks: chunk strings to embed and store.
        db_client: Chroma client used to obtain the collection.
        db_path: NOTE(review) — accepted but never used here; presumably
            the client is already bound to its storage path. Confirm with
            callers before removing.

    Returns:
        The list of embedding vectors that were stored.
    """
    vectors = [model.encode([chunk])[0] for chunk in text_content_chunks]
    # Deterministic positional ids so re-runs address the same records.
    chunk_ids = [f"id_{index}" for index, _ in enumerate(text_content_chunks)]
    collection = db_client.get_or_create_collection("embeddings_collection")
    collection.add(
        documents=text_content_chunks,
        embeddings=vectors,
        ids=chunk_ids,
    )
    return vectors
def retrieveEmbeddingsChroma(db_client):
    """Fetch all stored chunks and their embeddings from Chroma.

    Bug fix: Chroma's ``Collection.get()`` does NOT return embeddings
    unless they are explicitly requested via ``include`` (the field comes
    back as None), so this previously always yielded an empty embeddings
    list. We now request them exactly as ``similarityChroma`` does.

    Args:
        db_client: Chroma client holding "embeddings_collection".

    Returns:
        Tuple ``(embeddings, text_chunks)``; both are empty lists when
        the collection has no records.
    """
    collection = db_client.get_collection("embeddings_collection")
    records = collection.get(include=["documents", "embeddings"])
    embeddings, text_chunks = [], []
    docs = records.get("documents") if records else None
    vecs = records.get("embeddings") if records else None
    if docs is None or vecs is None:
        print("No documents or embeddings found in the collection.")
    else:
        # list() rather than `or []`: newer Chroma versions return a numpy
        # array here, whose truth value is ambiguous in a boolean context.
        text_chunks = list(docs)
        embeddings = list(vecs)
    return embeddings, text_chunks
def ragQuery(model, query):
    """Embed a single query string.

    Args:
        model: object exposing `encode(list_of_str)`.
        query: the query text.

    Returns:
        The model's encoding of a batch containing only `query`.
    """
    batch = [query]
    return model.encode(batch)
def similarity(query_embedding, text_contents_embeddings, text_content_chunks, top_k):
    """Rank chunks by cosine similarity to the query and format the best.

    Args:
        query_embedding: batch-of-one query encoding (first row is used).
        text_contents_embeddings: one embedding per chunk, aligned with
            `text_content_chunks`.
        text_content_chunks: the chunk strings themselves.
        top_k: number of best-matching chunks to keep.

    Returns:
        A single string with the top_k chunks, each prefixed by a
        "Text Chunk <rank>" header, joined by newlines.
    """
    query_vec = query_embedding[0]
    scored = []
    for chunk, vec in zip(text_content_chunks, text_contents_embeddings):
        scored.append((chunk, cos_sim(vec, query_vec)))
    # Stable sort, best score first (matches sorted(..., reverse=True)).
    scored.sort(key=lambda pair: pair[1], reverse=True)
    labelled = []
    for rank, (chunk, _) in enumerate(scored[:top_k], start=1):
        labelled.append(f"Text Chunk <{rank}>\n{chunk}")
    return "\n".join(labelled)
def similarityChroma(query_embedding, db_client, top_k):
    """Rank Chroma-stored chunks by cosine similarity to the query.

    Loads every document and embedding from "embeddings_collection",
    scores each stored vector against the query, and formats the top_k
    matches the same way `similarity` does.

    Args:
        query_embedding: numpy-compatible query vector (any shape that
            flattens to one row; it is reshaped to (1, -1)).
        db_client: Chroma client holding the collection.
        top_k: number of best-matching chunks to keep.

    Returns:
        A single newline-joined string of the top_k chunks, each with a
        "Text Chunk <rank>" header.
    """
    collection = db_client.get_collection("embeddings_collection")
    records = collection.get(include=["documents", "embeddings"])
    chunks = records["documents"]
    # Cast both sides to float32 so cos_sim compares matching dtypes.
    stored = np.asarray(records["embeddings"], dtype=np.float32)
    query_vec = query_embedding.astype(np.float32).reshape(1, -1)
    scored = [
        (chunk, cos_sim(vec.reshape(1, -1), query_vec)[0][0])
        for chunk, vec in zip(chunks, stored)
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    best = [chunk for chunk, _ in scored[:top_k]]
    return "\n".join(
        f"Text Chunk <{rank}>\n{chunk}" for rank, chunk in enumerate(best, start=1)
    )
# pdf_file = os.path.join(os.getcwd(), "pdfs", "test2.pdf")
# converter = PdfConverter(pdf_file)
# document_text = converter.convert_to_markdown()
# chunk_size, chunk_overlap, top_k = 2000, 200, 5
# query = "what metric used in this paper for performance evaluation?"
# text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
# text_contents_embeddings = contextEmbedding(model, text_content_chunks)
# query_embedding = ragQuery(model, query)
# top_k_matches = similarity(query_embedding, text_contents_embeddings, text_content_chunks, top_k)
# print(top_k_matches)