# rag.py — Hugging Face file view by mitulagr2
# Commit bc8f854: "Update rag.py" (4.24 kB)
import os
import logging
from llama_index.core import (
SimpleDirectoryReader,
VectorStoreIndex,
StorageContext,
Settings,
get_response_synthesizer)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode, MetadataMode
from llama_index.core.vector_stores import VectorStoreQuery
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from llama_index.readers.file.docs.base import DocxReader, HWPReader, PDFReader
# Application-wide logging: INFO level so the progress messages emitted by
# ChatPDF (vector store setup, indexing, querying) are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory handed to QdrantClient(path=...) for the local vector database.
store_dir = os.path.expanduser("~/wtp_be_store/")
class ChatPDF:
    """Retrieval-augmented question answering over a directory of documents.

    Documents are split with a ``SentenceSplitter`` (512-token chunks, 24
    overlap), embedded with ``FastEmbedEmbedding``, written to the Qdrant
    collection ``rag_documents`` under ``store_dir``, and answered by a
    Qwen2-0.5B-Instruct GGUF model run through ``LlamaCPP``.

    Usage: call ``ingest(files_dir)`` one or more times, then ``ask(query)``.
    ``clear()`` removes the nodes this instance wrote to the vector store.
    """

    def __init__(self):
        """Build the splitter, vector store, embedder and LLM.

        The splitter, embedder and LLM are also installed into the global
        llama_index ``Settings``.
        """
        # Per-instance bookkeeping. These used to be class attributes, so all
        # ChatPDF instances shared (and mutated) the same lists.
        # Number of Document objects ingested. NOTE(review): the directory
        # reader may yield several Documents per file (e.g. per PDF page).
        self.pdf_count = 0
        self.text_chunks = []  # every chunk produced so far, in order
        # Per chunk: index of its source doc within the ingest() call that
        # produced it (not a global document index).
        self.doc_ids = []
        self.nodes = []  # every node this instance wrote to the vector store

        self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=24)

        logger.info("initializing the vector store related objects")
        # Local-mode Qdrant backed by files under store_dir.
        self.client = QdrantClient(path=store_dir)
        self.vector_store = QdrantVectorStore(
            client=self.client,
            collection_name="rag_documents",
            # enable_hybrid=True
        )

        logger.info("initializing the FastEmbedEmbedding")
        self.embed_model = FastEmbedEmbedding(
            # model_name="BAAI/bge-small-en"
        )

        # model_url points at a GGUF file on Hugging Face; LlamaCPP presumably
        # downloads/caches it on first use (TODO confirm cache location).
        llm = LlamaCPP(
            model_url="https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-fp16.gguf",
            temperature=0.1,
            max_new_tokens=256,
            generate_kwargs={"max_tokens": 256, "temperature": 0.1, "top_k": 3},
            # messages_to_prompt=self.messages_to_prompt,
            # completion_to_prompt=self.completion_to_prompt,
            verbose=True,
        )
        # tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
        # tokenizer.save_pretrained("./models/tokenizer/")

        logger.info("initializing the global settings")
        Settings.text_splitter = self.text_parser
        Settings.embed_model = self.embed_model
        Settings.llm = llm
        # Settings.tokenizer = tokenizer
        Settings.transformations = [self.text_parser]

    def ingest(self, files_dir: str):
        """Load, chunk, embed and index every document under *files_dir*.

        Only the chunks produced by this call are embedded and written to the
        vector store. Nodes from earlier calls are already stored, so
        re-inserting them (as the previous version did) only created
        duplicates under new ids.

        Afterwards ``self.query_engine`` is a streaming engine with
        ``similarity_top_k=3``. QdrantVectorStore stores node text, so
        retrieval runs against the whole collection, not just this call's
        nodes.

        Args:
            files_dir: directory passed to ``SimpleDirectoryReader``.
        """
        docs = SimpleDirectoryReader(input_dir=files_dir).load_data()

        logger.info("enumerating docs")
        self.pdf_count += len(docs)
        new_chunks = []
        new_doc_ids = []
        for doc_idx, doc in enumerate(docs):
            curr_text_chunks = self.text_parser.split_text(doc.text)
            new_chunks.extend(curr_text_chunks)
            new_doc_ids.extend([doc_idx] * len(curr_text_chunks))
        self.text_chunks.extend(new_chunks)
        self.doc_ids.extend(new_doc_ids)

        logger.info("enumerating text_chunks")
        new_nodes = []
        for text_chunk in new_chunks:
            node = TextNode(text=text_chunk)
            # To attach source metadata, use the matching entry of new_doc_ids:
            # node.metadata = docs[new_doc_ids[i]].metadata
            # Skip chunks whose embeddable content is empty.
            if node.get_content(metadata_mode=MetadataMode.EMBED):
                new_nodes.append(node)

        logger.info("enumerating nodes")
        # One batched call instead of one embedding request per node.
        embeddings = self.embed_model.get_text_embedding_batch(
            [node.get_content(metadata_mode=MetadataMode.ALL) for node in new_nodes]
        )
        for node, embedding in zip(new_nodes, embeddings):
            node.embedding = embedding
        self.nodes.extend(new_nodes)

        logger.info("initializing the storage context")
        storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
        logger.info("indexing the nodes in VectorStoreIndex")
        # Nodes already carry embeddings, so the index only writes them out.
        index = VectorStoreIndex(
            nodes=new_nodes,
            storage_context=storage_context,
            transformations=Settings.transformations,
        )
        self.query_engine = index.as_query_engine(
            streaming=True,
            similarity_top_k=3,
        )

    def ask(self, query: str):
        """Answer *query* with the engine built by the last ``ingest`` call.

        Returns the streaming response produced by the query engine (it was
        built with ``streaming=True``). Raises AttributeError if ``ingest``
        has never been called, because ``query_engine`` is only set there.
        """
        logger.info("retrieving the response to the query")
        streaming_response = self.query_engine.query(query)
        return streaming_response

    def clear(self):
        """Delete this instance's nodes from the vector store and reset state.

        Only nodes tracked in ``self.nodes`` are deleted. Anything else in the
        Qdrant collection (other instances, earlier runs) is left untouched,
        as is ``self.query_engine``.
        """
        if self.nodes:
            # delete_nodes expects node ids, not node objects.
            self.vector_store.delete_nodes([node.node_id for node in self.nodes])
        self.pdf_count = 0
        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []