import os
import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from modules.pdfExtractor import PdfConverter
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
# One-time setup: download the embedding model and save a local copy for offline reuse.
# model = SentenceTransformer(
#     "thenlper/gte-base",  # switch to en/zh for English or Chinese
#     trust_remote_code=True
# )
# model.save(os.path.join(os.getcwd(), "embeddingModel"))
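# To reuse the locally saved copy on later runs (assuming the save above was executed once),
# SentenceTransformer also accepts a local directory path:
# model = SentenceTransformer(os.path.join(os.getcwd(), "embeddingModel"))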
def contextChunks(document_text, chunk_size, chunk_overlap):
    # Split the raw document text into overlapping chunks ready for embedding.
    document = Document(page_content=document_text)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    text_chunks = text_splitter.split_documents([document])
    return [chunk.page_content for chunk in text_chunks]
def contextEmbedding(model, text_content_chunks):
    # Embed each chunk individually; each entry has shape (1, dim).
    return [model.encode([text]) for text in text_content_chunks]
def contextEmbeddingChroma(model, text_content_chunks, db_client, db_path):
    # Embed each chunk (flattened to 1-D vectors) and persist chunks plus vectors in Chroma.
    # Note: db_path is unused here; the client passed in already owns the storage location.
    text_contents_embeddings = [model.encode([text])[0] for text in text_content_chunks]
    ids = [f"id_{i}" for i in range(len(text_content_chunks))]
    collection = db_client.get_or_create_collection("embeddings_collection")
    collection.add(
        documents=text_content_chunks,
        embeddings=text_contents_embeddings,
        ids=ids,  # Chroma requires a unique ID per record
    )
    return text_contents_embeddings
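# The db_client these helpers expect can be built with Chroma's persistent client
# (a minimal sketch, assuming chromadb >= 0.4; the path value is illustrative):
# import chromadb
# db_client = chromadb.PersistentClient(path=db_path)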
def retrieveEmbeddingsChroma(db_client):
    # Fetch stored chunks and vectors back from Chroma. Embeddings must be requested
    # explicitly: collection.get() omits them from its default `include` list.
    collection = db_client.get_collection("embeddings_collection")
    records = collection.get(include=["documents", "embeddings"])
    embeddings = []
    text_chunks = []
    if records and records["documents"] is not None and records["embeddings"] is not None:
        text_chunks = records["documents"]
        embeddings = records["embeddings"]
    else:
        print("No documents or embeddings found in the collection.")
    return embeddings, text_chunks
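# Example (hypothetical, given a db_client as above): the retrieved pair plugs straight
# back into similarity(), since cos_sim accepts plain float lists as well as arrays:
# embeddings, text_chunks = retrieveEmbeddingsChroma(db_client)
# top_k_matches = similarity(query_embedding, embeddings, text_chunks, top_k)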
def ragQuery(model, query):
    # Embed the user query; returns an array of shape (1, dim).
    return model.encode([query])
def similarity(query_embedding, text_contents_embeddings, text_content_chunks, top_k):
    # Rank chunks by cosine similarity to the query and return the top_k as one formatted string.
    similarities = [(text, cos_sim(embedding, query_embedding[0]))
                    for text, embedding in zip(text_content_chunks, text_contents_embeddings)]
    similarities_sorted = sorted(similarities, key=lambda x: x[1], reverse=True)
    top_k_texts = [text for text, _ in similarities_sorted[:top_k]]
    return "\n".join(f"Text Chunk <{i + 1}>\n{element}" for i, element in enumerate(top_k_texts))
def similarityChroma(query_embedding, db_client, top_k):
    # Same ranking as similarity(), but chunks and vectors are read back from Chroma first.
    collection = db_client.get_collection("embeddings_collection")
    results = collection.get(include=["documents", "embeddings"])
    text_content_chunks = results["documents"]
    text_contents_embeddings = np.array(results["embeddings"], dtype=np.float32)
    query_embedding = query_embedding.astype(np.float32)
    similarities = [
        (text, cos_sim(embedding.reshape(1, -1), query_embedding.reshape(1, -1))[0][0])
        for text, embedding in zip(text_content_chunks, text_contents_embeddings)
    ]
    similarities_sorted = sorted(similarities, key=lambda x: x[1], reverse=True)
    top_k_texts = [text for text, _ in similarities_sorted[:top_k]]
    return "\n".join(f"Text Chunk <{i + 1}>\n{element}" for i, element in enumerate(top_k_texts))
# Example usage (in-memory pipeline, no Chroma):
# pdf_file = os.path.join(os.getcwd(), "pdfs", "test2.pdf")
# converter = PdfConverter(pdf_file)
# document_text = converter.convert_to_markdown()
# chunk_size, chunk_overlap, top_k = 2000, 200, 5
# query = "What metric is used in this paper for performance evaluation?"
# text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
# text_contents_embeddings = contextEmbedding(model, text_content_chunks)
# query_embedding = ragQuery(model, query)
# top_k_matches = similarity(query_embedding, text_contents_embeddings, text_content_chunks, top_k)
# print(top_k_matches)  # similarity() returns one formatted string, so print it whole
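# Hypothetical Chroma-backed variant of the demo above (client and path names are
# illustrative, assuming chromadb >= 0.4):
# import chromadb
# db_path = os.path.join(os.getcwd(), "chroma_db")
# db_client = chromadb.PersistentClient(path=db_path)
# contextEmbeddingChroma(model, text_content_chunks, db_client, db_path)
# top_k_matches = similarityChroma(query_embedding, db_client, top_k)
# print(top_k_matches)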