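# Author: mitulagr2 — commit d7e310c ("fix"), 4.03 kB.
# (File-viewer header text — "picture", "raw", "history blame" — kept as a
# comment so the module parses.)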
import os
import logging
from llama_index.core import (
SimpleDirectoryReader,
VectorStoreIndex,
StorageContext,
Settings,
get_response_synthesizer)
from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode, MetadataMode
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from qdrant_client import QdrantClient
# Qdrant connection settings, read once from the environment at import time.
# Each value is None when its environment variable is unset; ChatPDF passes
# them unchanged to QdrantClient(url=..., api_key=...).
QDRANT_API_URL = os.environ.get('QDRANT_API_URL')
QDRANT_API_KEY = os.environ.get('QDRANT_API_KEY')
class ChatPDF:
    """Question answering over a directory of documents using HyDE retrieval.

    ``ingest`` works in four steps:

    * split the documents into sentence-based chunks;
    * embed the chunks with an Ollama embedding model;
    * store them through a Qdrant vector store (collection ``rag_documents``);
    * build a HyDE (hypothetical document embedding) query engine.

    ``ask`` answers questions with that engine. ``clear`` forgets the engine
    and the per-instance bookkeeping lists.
    """

    # Immutable class-level default meaning "nothing ingested yet". The mutable
    # bookkeeping lists (text_chunks, doc_ids, nodes) are created per instance
    # in __init__/clear, so separate ChatPDF objects never share them.
    hyde_query_engine = None

    # Logger shared by all methods. It lives on the class so that methods can
    # log even on instances whose __init__ was not run.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Set up logging, the Qdrant-backed vector store and llama_index Settings.

        Side effects:
            * Calls ``logging.basicConfig(level=logging.INFO)``.
            * Connects a QdrantClient using the module-level QDRANT_API_URL /
              QDRANT_API_KEY values.
            * Assigns ``embed_model``, ``llm`` and ``transformations`` on the
              imported global ``Settings`` object, so these settings are shared
              by everything that reads ``Settings``, not only by this instance.
        """
        logging.basicConfig(level=logging.INFO)

        # Per-instance state; replaced on every successful ingest().
        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []
        self.hyde_query_engine = None

        self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)

        self._logger.info("initializing the vector store related objects")
        client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
        self.vector_store = QdrantVectorStore(
            client=client, collection_name="rag_documents"
        )

        self._logger.info("initializing the OllamaEmbedding")
        self.embed_model = OllamaEmbedding(model_name='mxbai-embed-large')

        self._logger.info("initializing the global settings")
        Settings.embed_model = self.embed_model
        Settings.llm = Ollama(model="qwen:1.8b", request_timeout=1000000)
        Settings.transformations = [self.text_parser]

    def ingest(self, dir_path: str) -> None:
        """Load, chunk, embed and index every document found under *dir_path*.

        Each call processes only the documents it loads itself. Chunks from
        earlier calls are never re-embedded or re-inserted. Results are stored
        on the instance only after every step has succeeded:

        * ``self.text_chunks`` — the text chunks;
        * ``self.doc_ids`` — for each chunk, the index of its source document
          in this call's document list;
        * ``self.nodes`` — the TextNodes built from the chunks;
        * ``self.hyde_query_engine`` — the rebuilt query engine.

        Args:
            dir_path: Directory passed to ``SimpleDirectoryReader(input_dir=...)``.
        """
        docs = SimpleDirectoryReader(input_dir=dir_path).load_data()

        self._logger.info("enumerating docs")
        text_chunks, doc_ids, nodes = self._chunk_documents(docs)

        self._logger.info("enumerating nodes")
        self._embed_nodes(nodes)

        query_engine = self._build_query_engine(nodes)

        self.text_chunks = text_chunks
        self.doc_ids = doc_ids
        self.nodes = nodes
        self.hyde_query_engine = query_engine

    def _chunk_documents(self, docs):
        """Split *docs* into text chunks and wrap each chunk in a TextNode.

        Returns:
            A tuple ``(text_chunks, doc_ids, nodes)``:

            * ``text_chunks[i]`` is the i-th chunk;
            * ``doc_ids[i]`` is the index in *docs* of the document that
              produced that chunk;
            * ``nodes[i]`` is the TextNode built from that chunk, carrying a
              copy of its source document's metadata.
        """
        text_chunks = []
        doc_ids = []
        for doc_idx, doc in enumerate(docs):
            chunks = self.text_parser.split_text(doc.text)
            text_chunks.extend(chunks)
            doc_ids.extend([doc_idx] * len(chunks))

        self._logger.info("enumerating text_chunks")
        nodes = [
            # Copy the metadata dict so a node never aliases its source
            # document's metadata (or the metadata of sibling nodes).
            TextNode(text=chunk, metadata=dict(docs[doc_idx].metadata))
            for chunk, doc_idx in zip(text_chunks, doc_ids)
        ]
        return text_chunks, doc_ids, nodes

    def _embed_nodes(self, nodes):
        """Set ``node.embedding`` in place for every node in *nodes*.

        The embedded text is the node content including all of its metadata
        (``MetadataMode.ALL``).
        """
        for node in nodes:
            node.embedding = self.embed_model.get_text_embedding(
                node.get_content(metadata_mode=MetadataMode.ALL)
            )

    def _build_query_engine(self, nodes):
        """Index *nodes* in the Qdrant vector store and wrap retrieval in HyDE.

        Returns:
            A TransformQueryEngine. It applies HyDEQueryTransform
            (include_original=True) to each query, then delegates to a
            RetrieverQueryEngine that retrieves the top 5 most similar nodes.
        """
        self._logger.info("initializing the storage context")
        storage_context = StorageContext.from_defaults(vector_store=self.vector_store)

        self._logger.info("indexing the nodes in VectorStoreIndex")
        index = VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            transformations=Settings.transformations,
        )

        self._logger.info("initializing the VectorIndexRetriever with top_k as 5")
        vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=5)
        response_synthesizer = get_response_synthesizer()

        self._logger.info("creating the RetrieverQueryEngine instance")
        vector_query_engine = RetrieverQueryEngine(
            retriever=vector_retriever,
            response_synthesizer=response_synthesizer,
        )

        self._logger.info("creating the HyDEQueryTransform instance")
        hyde = HyDEQueryTransform(include_original=True)
        return TransformQueryEngine(vector_query_engine, hyde)

    def ask(self, query: str):
        """Answer *query* using the engine built by the last successful ingest.

        Returns:
            Either the result of ``hyde_query_engine.query(...)``, or the string
            "Please, add a PDF document first." when nothing has been ingested.
        """
        if self.hyde_query_engine is None:
            return "Please, add a PDF document first."
        self._logger.info("retrieving the response to the query")
        response = self.hyde_query_engine.query(str_or_query_bundle=query)
        self._logger.info("response: %s", response)
        return response

    def clear(self):
        """Forget the query engine and reset the bookkeeping lists.

        Only instance attributes are reset. Nothing here touches
        ``self.vector_store``, so vectors already written to the Qdrant
        collection are presumably left in place.
        """
        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []
        self.hyde_query_engine = None