mitulagr2 committed
Commit b5f36b8
1 Parent: 5c1d000

Migrate to Llama.cpp

Files changed (2):
1. app/rag.py (+92 -40)
2. requirements.txt (+5 -2)
app/rag.py CHANGED
@@ -1,59 +1,100 @@
-import os
-import logging
-
 from llama_index.core import (
     SimpleDirectoryReader,
-    VectorStoreIndex,
+    # VectorStoreIndex,
     StorageContext,
     Settings,
     get_response_synthesizer)
 from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
 from llama_index.core.node_parser import SentenceSplitter
 from llama_index.core.schema import TextNode, MetadataMode
-from llama_index.vector_stores.qdrant import QdrantVectorStore
-from llama_index.embeddings.ollama import OllamaEmbedding
-from llama_index.llms.ollama import Ollama
 from llama_index.core.retrievers import VectorIndexRetriever
-from llama_index.core.indices.query.query_transform import HyDEQueryTransform
+# from llama_index.core.indices.query.query_transform import HyDEQueryTransform
+
+from llama_index.core.response_synthesizers import ResponseMode
+# from transformers import AutoTokenizer
+from llama_index.core.vector_stores import VectorStoreQuery
+from llama_index.core.indices.vector_store.base import VectorStoreIndex
+from llama_index.vector_stores.qdrant import QdrantVectorStore
 from qdrant_client import QdrantClient
+import logging
+
+from llama_index.llms.llama_cpp import LlamaCPP
+from llama_index.embeddings.fastembed import FastEmbedEmbedding
 
-QDRANT_API_URL = os.getenv('QDRANT_API_URL')
-QDRANT_API_KEY = os.getenv('QDRANT_API_KEY')
 
 
 class ChatPDF:
-    hyde_query_engine = None
-    text_parser = None
-    vector_store = None
-    embed_model = None
-    logger = None
+    logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger(__name__)
+    query_engine = None
 
-    def __init__(self):
-        logging.basicConfig(level=logging.INFO)
-        self.logger = logging.getLogger(__name__)
+    # model_url = "https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q8_0.gguf"
+    model_url = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf"
+
+    def messages_to_prompt(messages):
+        prompt = ""
+        for message in messages:
+            if message.role == 'system':
+                prompt += f"<|system|>\n{message.content}</s>\n"
+            elif message.role == 'user':
+                prompt += f"<|user|>\n{message.content}</s>\n"
+            elif message.role == 'assistant':
+                prompt += f"<|assistant|>\n{message.content}</s>\n"
+
+        if not prompt.startswith("<|system|>\n"):
+            prompt = "<|system|>\n</s>\n" + prompt
+
+        prompt = prompt + "<|assistant|>\n"
+
+        return prompt
 
-        self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)
+    def completion_to_prompt(completion):
+        return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"
+
+
+    def __init__(self):
+        text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=20)
 
         self.logger.info("initializing the vector store related objects")
-        client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
+        # client = QdrantClient(host="localhost", port=6333)
+        client = QdrantClient(":memory:")
         self.vector_store = QdrantVectorStore(client=client, collection_name="rag_documents")
 
-        self.logger.info("initializing the OllamaEmbedding")
-        self.embed_model = OllamaEmbedding(model_name='mxbai-embed-large')
+        self.logger.info("initializing the FastEmbedEmbedding")
+        self.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en")
+
+        llm = LlamaCPP(
+            # model_url=self.model_url,
+            temperature=0.1,
+            max_new_tokens=256,
+            context_window=3900,
+            # generate_kwargs={},
+            model_kwargs={"n_gpu_layers": -1},
+            messages_to_prompt=self.messages_to_prompt,
+            completion_to_prompt=self.completion_to_prompt,
+            verbose=True,
+        )
+
+        # tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+        # tokenizer.save_pretrained("./models/tokenizer/")
+
         self.logger.info("initializing the global settings")
+        Settings.text_splitter = text_parser
         Settings.embed_model = self.embed_model
-        Settings.llm = Ollama(model="qwen:1.8b", request_timeout=1000000)
-        Settings.transformations = [self.text_parser]
+        Settings.llm = llm
+        # Settings.tokenzier = tokenizer
+        Settings.transformations = [text_parser]
 
-    def ingest(self, dir_path: str):
-        docs = SimpleDirectoryReader(input_dir=dir_path).load_data()
+    def ingest(self, pdf_file_path: str):
         text_chunks = []
         doc_ids = []
         nodes = []
 
+        docs = SimpleDirectoryReader(input_dir="files").load_data()
+
        self.logger.info("enumerating docs")
         for doc_idx, doc in enumerate(docs):
-            curr_text_chunks = self.text_parser.split_text(doc.text)
+            curr_text_chunks = text_parser.split_text(doc.text)
             text_chunks.extend(curr_text_chunks)
             doc_ids.extend([doc_idx] * len(curr_text_chunks))
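A note on the prompt helpers in the hunk above: `messages_to_prompt` and `completion_to_prompt` are defined in the class body without `self`, yet handed to `LlamaCPP` as `self.messages_to_prompt` and `self.completion_to_prompt`. Accessed through the instance they become bound methods, so Python injects the `ChatPDF` instance as their first argument and the call fails with a `TypeError` at generation time. A minimal sketch of the usual arrangement, with the helpers at module level (logic and tags copied from the commit; the `</s>`-delimited tags follow the Zephyr convention, so they are worth checking against Phi-3's chat template, which uses `<|end|>` markers):

```python
# Sketch (not part of the commit): module-level prompt helpers so that no
# implicit `self` is bound when they are handed to LlamaCPP.
def messages_to_prompt(messages):
    # Render LlamaIndex ChatMessage objects into the chat template.
    prompt = ""
    for message in messages:
        if message.role == "system":
            prompt += f"<|system|>\n{message.content}</s>\n"
        elif message.role == "user":
            prompt += f"<|user|>\n{message.content}</s>\n"
        elif message.role == "assistant":
            prompt += f"<|assistant|>\n{message.content}</s>\n"
    # Guarantee a leading (possibly empty) system block.
    if not prompt.startswith("<|system|>\n"):
        prompt = "<|system|>\n</s>\n" + prompt
    # End with the assistant tag so the model continues as the assistant.
    return prompt + "<|assistant|>\n"


def completion_to_prompt(completion):
    return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"
```

With module-level helpers, the constructor can pass `messages_to_prompt=messages_to_prompt` and `completion_to_prompt=completion_to_prompt` directly; declaring them as `@staticmethod` inside the class would work as well.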
@@ -80,26 +121,37 @@ class ChatPDF:
             transformations=Settings.transformations,
         )
 
-        self.logger.info("initializing the VectorIndexRetriever with top_k as 5")
-        vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=5)
-        response_synthesizer = get_response_synthesizer()
-        self.logger.info("creating the RetrieverQueryEngine instance")
-        vector_query_engine = RetrieverQueryEngine(
-            retriever=vector_retriever,
+        self.logger.info("configure retriever")
+        retriever = VectorIndexRetriever(
+            index=index,
+            similarity_top_k=6,
+            vector_store_query_mode="hybrid"
+        )
+
+        self.logger.info("configure response synthesizer")
+        response_synthesizer = get_response_synthesizer(
+            # streaming=True,
+            response_mode=ResponseMode.COMPACT,
+        )
+
+        self.logger.info("assemble query engine")
+        self.query_engine = RetrieverQueryEngine(
+            retriever=retriever,
             response_synthesizer=response_synthesizer,
         )
-        self.logger.info("creating the HyDEQueryTransform instance")
-        hyde = HyDEQueryTransform(include_original=True)
-        self.hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)
+
+        # self.logger.info("creating the HyDEQueryTransform instance")
+        # hyde = HyDEQueryTransform(include_original=True)
+        # self.hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)
 
     def ask(self, query: str):
-        if not self.hyde_query_engine:
+        if not self.query_engine:
             return "Please, add a PDF document first."
 
         self.logger.info("retrieving the response to the query")
-        response = self.hyde_query_engine.query(str_or_query_bundle=query)
-        self.logger.info(response)
+        response = self.query_engine.query(str_or_query_bundle=query)
+        print(response)
         return response
 
     def clear(self):
-        self.hyde_query_engine = None
+        self.query_engine = None
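The new retriever requests `vector_store_query_mode="hybrid"`, but Qdrant only serves hybrid queries when the collection is created with a sparse vector field, and the store in the first hunk is built with dense vectors only, so the query would fail at runtime. A sketch of a store construction that supports hybrid mode, assuming `enable_hybrid` (QdrantVectorStore's flag for this; the sparse model then defaults to a fastembed SPLADE checkpoint, which is not named anywhere in the commit):

```python
from qdrant_client import QdrantClient
from llama_index.vector_stores.qdrant import QdrantVectorStore

# Sketch: Qdrant store with hybrid (dense + sparse) search enabled.
client = QdrantClient(":memory:")
vector_store = QdrantVectorStore(
    client=client,
    collection_name="rag_documents",
    enable_hybrid=True,  # provisions a sparse vector field for hybrid queries
)
```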
 
 
 
 
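Separately, in the `LlamaCPP(...)` call of the first hunk, `model_url=self.model_url` is commented out and no `model_path` is supplied, so the wrapper falls back to downloading its built-in default GGUF (a Llama 2 chat model in the versions I have seen) rather than the Phi-3 file named in `model_url`. A sketch with the model pinned explicitly; the local path is illustrative, and the helper names refer to the module-level sketch above:

```python
from llama_index.llms.llama_cpp import LlamaCPP

llm = LlamaCPP(
    # Either pin the download URL from the commit...
    model_url="https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf",
    # ...or point at an already-downloaded file (illustrative path):
    # model_path="./models/Phi-3-mini-4k-instruct-q4.gguf",
    temperature=0.1,
    max_new_tokens=256,
    context_window=3900,  # leave headroom below the model's 4k window
    model_kwargs={"n_gpu_layers": -1},  # offload all layers when built with GPU support
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
    verbose=True,
)
```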
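Two notes on `ingest()`: the `pdf_file_path` argument is accepted but unused (the reader always scans the `files` directory), and `text_parser` is a local variable of `__init__`, so `text_parser.split_text(...)` raises a NameError at runtime; keeping the splitter on `self`, as the pre-migration code did, avoids that. The middle of the method is unchanged and therefore elided from the diff; based on the visible imports and the closing `VectorStoreIndex(...)` call, it presumably follows the standard LlamaIndex node pipeline, roughly:

```python
# Hypothetical reconstruction of the unchanged middle of ingest(); the diff
# omits it, so treat this as the typical pattern, not the repository's code.
for idx, text_chunk in enumerate(text_chunks):
    node = TextNode(text=text_chunk)
    node.metadata = docs[doc_ids[idx]].metadata  # carry source metadata
    nodes.append(node)

for node in nodes:
    # embed each node with the FastEmbed model configured in Settings
    node.embedding = Settings.embed_model.get_text_embedding(
        node.get_content(metadata_mode=MetadataMode.ALL)
    )

self.vector_store.add(nodes)
storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
index = VectorStoreIndex(
    nodes=nodes,
    storage_context=storage_context,
    transformations=Settings.transformations,
)
```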
requirements.txt CHANGED
@@ -1,5 +1,8 @@
 fastapi
 llama-index
 llama-index-vector-stores-qdrant
-llama-index-embeddings-ollama
-llama-index-llms-ollama
+qdrant-client
+python-dotenv
+llama-index-llms-llama-cpp
+llama-index-embeddings-fastembed
+fastembed==0.2.7
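For completeness, a minimal smoke test of the migrated class; the module path and file layout here are assumptions (PDFs must already sit in the `files` directory that `ingest()` scans):

```python
# Hypothetical usage sketch, not part of the commit.
from app.rag import ChatPDF

chat = ChatPDF()
chat.ingest("files/sample.pdf")  # argument currently ignored; "files" dir is scanned
print(chat.ask("What is this document about?"))
chat.clear()  # drops the query engine until the next ingest
```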