Spaces:
Runtime error
Runtime error
File size: 4,025 Bytes
6923d59 63a6e05 1ff6584 63a6e05 5e8fd8b 1ff6584 d7e310c 5e8fd8b 1ff6584 0ee737b 1ff6584 63a6e05 1ff6584 8e67761 1ff6584 5e8fd8b 1ff6584 5e8fd8b 1ff6584 f1cf709 5e8fd8b 1ff6584 5e8fd8b 1ff6584 5e8fd8b 1ff6584 5e8fd8b 1ff6584 d7e310c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
import logging
from llama_index.core import (
SimpleDirectoryReader,
VectorStoreIndex,
StorageContext,
Settings,
get_response_synthesizer)
from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode, MetadataMode
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from qdrant_client import QdrantClient
# Qdrant connection settings, read from the process environment at import
# time. Each value is None when the corresponding variable is not set.
QDRANT_API_URL = os.environ.get('QDRANT_API_URL')
QDRANT_API_KEY = os.environ.get('QDRANT_API_KEY')
class ChatPDF:
    """Question answering over a directory of documents.

    Documents are split into chunks, embedded with an Ollama embedding
    model, stored in the Qdrant collection ``rag_documents`` and queried
    through a HyDE-transformed retriever query engine.
    """

    # Shared logger; a class attribute so every method can reach it.
    _logger = logging.getLogger(__name__)

    # Immutable class-level default: ask() reports "no document" until
    # ingest() has built an engine on the instance.
    hyde_query_engine = None

    def __init__(self):
        """Create the Qdrant/Ollama clients and configure global Settings.

        Side effects: calls ``logging.basicConfig(level=logging.INFO)`` and
        overwrites ``Settings.embed_model``, ``Settings.llm`` and
        ``Settings.transformations``.
        """
        logging.basicConfig(level=logging.INFO)

        # Per-instance bookkeeping. These must not be class-level lists,
        # otherwise every ChatPDF instance would share the same objects.
        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []
        self.hyde_query_engine = None

        # Kept on the instance because ingest() needs them later.
        self._text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)

        self._logger.info("initializing the vector store related objects")
        client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
        self._vector_store = QdrantVectorStore(
            client=client, collection_name="rag_documents"
        )

        self._logger.info("initializing the OllamaEmbedding")
        self._embed_model = OllamaEmbedding(model_name='mxbai-embed-large')

        self._logger.info("initializing the global settings")
        Settings.embed_model = self._embed_model
        Settings.llm = Ollama(model="qwen:1.8b", request_timeout=1000000)
        Settings.transformations = [self._text_parser]

    def ingest(self, dir_path: str) -> None:
        """Load, chunk, embed and index every document under ``dir_path``.

        Only the nodes produced by this call are embedded and handed to
        the index; chunks, ids and nodes are also appended to
        ``text_chunks``, ``doc_ids`` and ``nodes``. On success
        ``hyde_query_engine`` is replaced with a new engine.

        Args:
            dir_path: Directory passed to ``SimpleDirectoryReader``.
        """
        docs = SimpleDirectoryReader(input_dir=dir_path).load_data()

        # Continue numbering after the last recorded id so ids appended by
        # separate ingest() calls do not collide in self.doc_ids.
        id_offset = self.doc_ids[-1] + 1 if self.doc_ids else 0

        self._logger.info("enumerating docs")
        new_chunks = []
        new_doc_idxs = []  # position of each chunk's source within `docs`
        for doc_idx, doc in enumerate(docs):
            curr_text_chunks = self._text_parser.split_text(doc.text)
            new_chunks.extend(curr_text_chunks)
            new_doc_idxs.extend([doc_idx] * len(curr_text_chunks))

        self._logger.info("enumerating text_chunks")
        new_nodes = []
        for text_chunk, doc_idx in zip(new_chunks, new_doc_idxs):
            node = TextNode(text=text_chunk)
            # Copy so chunks of one document do not alias a single dict.
            node.metadata = dict(docs[doc_idx].metadata)
            new_nodes.append(node)

        self._logger.info("enumerating nodes")
        for node in new_nodes:
            node.embedding = self._embed_model.get_text_embedding(
                node.get_content(metadata_mode=MetadataMode.ALL)
            )

        self._logger.info("initializing the storage context")
        storage_context = StorageContext.from_defaults(
            vector_store=self._vector_store
        )

        self._logger.info("indexing the nodes in VectorStoreIndex")
        index = VectorStoreIndex(
            nodes=new_nodes,
            storage_context=storage_context,
            transformations=Settings.transformations,
        )

        # Record the new material only once indexing has succeeded.
        self.text_chunks.extend(new_chunks)
        self.doc_ids.extend(idx + id_offset for idx in new_doc_idxs)
        self.nodes.extend(new_nodes)

        self._logger.info("initializing the VectorIndexRetriever with top_k as 5")
        vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=5)
        response_synthesizer = get_response_synthesizer()

        self._logger.info("creating the RetrieverQueryEngine instance")
        vector_query_engine = RetrieverQueryEngine(
            retriever=vector_retriever,
            response_synthesizer=response_synthesizer,
        )

        self._logger.info("creating the HyDEQueryTransform instance")
        hyde = HyDEQueryTransform(include_original=True)
        self.hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)

    def ask(self, query: str):
        """Answer ``query`` with the current query engine.

        Returns:
            The engine's response (also printed to stdout), or a fixed
            hint string when no engine has been built yet.
        """
        if not self.hyde_query_engine:
            return "Please, add a PDF document first."

        self._logger.info("retrieving the response to the query")
        response = self.hyde_query_engine.query(str_or_query_bundle=query)
        print(response)
        return response

    def clear(self) -> None:
        """Reset the in-memory bookkeeping and drop the query engine.

        Only instance attributes are reset; nothing is deleted from the
        vector store here.
        """
        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []
        self.hyde_query_engine = None