import os
import logging

from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    get_response_synthesizer)
from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode, MetadataMode
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from qdrant_client import QdrantClient

QDRANT_API_URL = os.getenv("QDRANT_API_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

# Module-level logging so every method (not only __init__) can log.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ChatPDF:
    """RAG over local PDF documents using LlamaIndex, Qdrant, Ollama and HyDE."""

    def __init__(self):
        # Per-instance state (instance attributes rather than mutable class
        # attributes, so separate ChatPDF objects do not share data).
        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []
        self.hyde_query_engine = None

        # Sentence-aware chunking of the source documents.
        self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)

        logger.info("initializing the vector store related objects")
        client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
        self.vector_store = QdrantVectorStore(client=client, collection_name="rag_documents")

        logger.info("initializing the OllamaEmbedding")
        self.embed_model = OllamaEmbedding(model_name="mxbai-embed-large")

        logger.info("initializing the global settings")
        Settings.embed_model = self.embed_model
        Settings.llm = Ollama(model="qwen:1.8b", request_timeout=1000000)
        Settings.transformations = [self.text_parser]

    def ingest(self, dir_path: str):
        # Load every readable file in the directory into LlamaIndex documents.
        docs = SimpleDirectoryReader(input_dir=dir_path).load_data()

        logger.info("enumerating docs")
        for doc_idx, doc in enumerate(docs):
            # Split each document into chunks and remember which document
            # every chunk came from.
            curr_text_chunks = self.text_parser.split_text(doc.text)
            self.text_chunks.extend(curr_text_chunks)
            self.doc_ids.extend([doc_idx] * len(curr_text_chunks))

        logger.info("enumerating text_chunks")
        for idx, text_chunk in enumerate(self.text_chunks):
            # Wrap each chunk in a TextNode and copy the source document's metadata.
            node = TextNode(text=text_chunk)
            src_doc = docs[self.doc_ids[idx]]
            node.metadata = src_doc.metadata
            self.nodes.append(node)

        logger.info("enumerating nodes")
        for node in self.nodes:
            # Embed the node text (including metadata) with the Ollama embedding model.
            node_embedding = self.embed_model.get_text_embedding(
                node.get_content(metadata_mode=MetadataMode.ALL)
            )
            node.embedding = node_embedding

        logger.info("initializing the storage context")
        storage_context = StorageContext.from_defaults(vector_store=self.vector_store)

        logger.info("indexing the nodes in VectorStoreIndex")
        index = VectorStoreIndex(
            nodes=self.nodes,
            storage_context=storage_context,
            transformations=Settings.transformations,
        )

        logger.info("initializing the VectorIndexRetriever with top_k as 5")
        vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=5)
        response_synthesizer = get_response_synthesizer()
        logger.info("creating the RetrieverQueryEngine instance")
        vector_query_engine = RetrieverQueryEngine(
            retriever=vector_retriever,
            response_synthesizer=response_synthesizer,
        )
        logger.info("creating the HyDEQueryTransform instance")
        hyde = HyDEQueryTransform(include_original=True)
        self.hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)

    def ask(self, query: str):
        if not self.hyde_query_engine:
            return "Please, add a PDF document first."

        logger.info("retrieving the response to the query")
        response = self.hyde_query_engine.query(query)
        print(response)
        return response

    def clear(self):
        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []
        self.hyde_query_engine = None
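

# Minimal usage sketch; assumes the QDRANT_API_URL/QDRANT_API_KEY environment
# variables are set, an Ollama server exposing the "mxbai-embed-large" and
# "qwen:1.8b" models is reachable, and "./docs" is a placeholder directory
# containing the PDFs to index.
if __name__ == "__main__":
    chat = ChatPDF()
    chat.ingest("./docs")
    chat.ask("What are these documents about?")  # ask() prints and returns the answer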