File size: 4,150 Bytes
9151071
 
1ff6584
 
683c59a
1ff6584
 
 
 
 
b5f36b8
 
 
794ae55
 
63a6e05
5e8fd8b
9151071
 
 
bdc84e2
 
 
5e8fd8b
4c88907
 
b5f36b8
3e1829f
1ff6584
bdc84e2
b5f36b8
3e1829f
9151071
 
 
 
 
1ff6584
bdc84e2
9151071
 
 
b5f36b8
 
794ae55
b5f36b8
 
4929aba
b5f36b8
dbeb658
b5808ba
 
b5f36b8
 
 
 
 
 
bdc84e2
881c0e5
5c1d000
b5f36b8
 
881c0e5
5e8fd8b
8b90c15
5c1d000
 
 
5e8fd8b
8b90c15
b5f36b8
bdc84e2
1ff6584
4c88907
881c0e5
1ff6584
 
 
bdc84e2
1ff6584
 
 
 
 
 
bdc84e2
1ff6584
5c1d000
1ff6584
 
 
 
bdc84e2
5c1d000
bdc84e2
1ff6584
 
 
 
f1cf709
5e8fd8b
2d12855
 
794ae55
2d12855
b5f36b8
d04ea2c
bdc84e2
4c88907
 
 
2d12855
d04ea2c
5e8fd8b
 
4c88907
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import logging
from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    get_response_synthesizer)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode, MetadataMode
from llama_index.core.vector_stores import VectorStoreQuery
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import QdrantClient


# Qdrant connection settings taken from the environment; each is None when
# the corresponding variable is not set.
QDRANT_API_URL, QDRANT_API_KEY = (
    os.environ.get(var_name) for var_name in ("QDRANT_API_URL", "QDRANT_API_KEY")
)

# Module-wide logging: INFO level on the root logger, plus this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ChatPDF:
    """Retrieval-augmented chat over the documents in a directory.

    ``ingest`` chunks, embeds and stores documents in the "rag_documents"
    Qdrant collection and builds a streaming query engine over it; ``ask``
    answers a question using that engine.
    """

    # Number of documents loaded by successful ``ingest`` calls since the last
    # ``clear``. NOTE(review): this counts documents returned by the reader,
    # which is not necessarily one per PDF file — verify against the reader.
    pdf_count = 0
    # Streaming query engine built by ``ingest``; None until then.
    query_engine = None

    def __init__(self):
        """Create the vector store, embedding model and LLM, and register them
        in llama_index's global ``Settings``."""
        self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=24)

        logger.info("initializing the vector store related objects")
        # For a local instance instead: QdrantClient(host="localhost", port=6333)
        client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
        self.vector_store = QdrantVectorStore(
            client=client,
            collection_name="rag_documents",
            # enable_hybrid=True
        )

        logger.info("initializing the FastEmbedEmbedding")
        # Uses the library's default model; pass e.g.
        # model_name="BAAI/bge-small-en" to pin a specific one.
        self.embed_model = FastEmbedEmbedding()

        llm = LlamaCPP(
            model_url="https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q4_k_m.gguf",
            temperature=0.1,
            max_new_tokens=256,
            # NOTE: the model reportedly supports up to 32k tokens; a much
            # smaller window is configured here.
            context_window=3900,
            # model_kwargs={"n_gpu_layers": -1},
            verbose=True,
        )

        logger.info("initializing the global settings")
        Settings.text_splitter = self.text_parser
        Settings.embed_model = self.embed_model
        Settings.llm = llm
        Settings.transformations = [self.text_parser]

    def ingest(self, files_dir: str):
        """Load, chunk, embed and index every document under ``files_dir``.

        The nodes are written to the Qdrant vector store and
        ``self.query_engine`` is rebuilt (streaming, top-6 similarity).
        ``pdf_count`` is only increased once indexing has succeeded.

        Args:
            files_dir: directory passed to ``SimpleDirectoryReader``.
        """
        docs = SimpleDirectoryReader(input_dir=files_dir).load_data()

        nodes = self._build_nodes(docs)
        self._embed_nodes(nodes)

        logger.info("initializing the storage context")
        storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
        logger.info("indexing the nodes in VectorStoreIndex")
        index = VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            transformations=Settings.transformations,
        )

        self.query_engine = index.as_query_engine(
            streaming=True,
            similarity_top_k=6,
        )
        # Count only after success so a failed ingest cannot make ``ask``
        # believe there is a query engine to use.
        self.pdf_count += len(docs)

    def _build_nodes(self, docs):
        """Split each document into chunks and wrap each chunk in a TextNode
        carrying its own copy of the source document's metadata."""
        logger.info("enumerating docs")
        nodes = []
        for doc in docs:
            for chunk in self.text_parser.split_text(doc.text):
                # Copy the dict so nodes never share (and mutate) one object.
                nodes.append(TextNode(text=chunk, metadata=dict(doc.metadata)))
        return nodes

    def _embed_nodes(self, nodes):
        """Attach to each node an embedding of its text plus all metadata."""
        logger.info("enumerating nodes")
        texts = [node.get_content(metadata_mode=MetadataMode.ALL) for node in nodes]
        embeddings = self.embed_model.get_text_embedding_batch(texts)
        for node, embedding in zip(nodes, embeddings):
            node.embedding = embedding

    def ask(self, query: str):
        """Answer ``query`` using the documents ingested so far.

        Returns:
            The streaming response's text generator, or the plain string
            "Please, add a PDF document first." when nothing is ingested.
        """
        logger.info("retrieving the response to the query")
        if self.pdf_count <= 0 or self.query_engine is None:
            return "Please, add a PDF document first."

        streaming_response = self.query_engine.query(query)
        return streaming_response.response_gen

    def clear(self):
        """Forget ingested documents so ``ask`` refuses until the next ``ingest``.

        NOTE(review): vectors already stored in the "rag_documents" Qdrant
        collection are not deleted; a later ``ingest`` indexes over the same
        collection and may still retrieve them.
        """
        self.pdf_count = 0
        self.query_engine = None