tmlinhdinh commited on
Commit
370ed2e
0 Parent(s):

deploy RAG

Browse files
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. Dockerfile +11 -0
  3. app.py +176 -0
  4. requirements.txt +132 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+ RUN useradd -m -u 1000 user
3
+ USER user
4
+ ENV HOME=/home/user \
5
+ PATH=/home/user/.local/bin:$PATH
6
+ WORKDIR $HOME/app
7
+ COPY --chown=user . $HOME/app
8
+ COPY ./requirements.txt ~/app/requirements.txt
9
+ RUN pip install -r requirements.txt
10
+ COPY . .
11
+ CMD ["chainlit", "run", "app.py", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Import Section ###
2
+ import uuid
3
+ from operator import itemgetter
4
+
5
+ from langchain_core.prompts import ChatPromptTemplate
6
+ from langchain_core.globals import set_llm_cache
7
+ from langchain_core.caches import InMemoryCache
8
+
9
+ from langchain_community.document_loaders import PyMuPDFLoader
10
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
11
+ from langchain.storage import LocalFileStore
12
+ from langchain.embeddings import CacheBackedEmbeddings
13
+ from langchain.schema import StrOutputParser
14
+
15
+ from langchain_openai import ChatOpenAI
16
+ from langchain_openai.embeddings import OpenAIEmbeddings
17
+
18
+ from qdrant_client import QdrantClient
19
+ from qdrant_client.http.models import Distance, VectorParams
20
+
21
+ from langchain_qdrant import QdrantVectorStore
22
+
23
+ import chainlit as cl
24
+ from chainlit.types import AskFileResponse
25
+
26
+
27
+ ### Global Section ###
28
+ set_llm_cache(InMemoryCache())
29
+
30
+ rag_system_prompt_template = """\
31
+ You are a helpful assistant that uses the provided context to answer questions. Never reference this prompt, or the existance of context.
32
+ """
33
+
34
+ rag_message_list = [
35
+ {"role" : "system", "content" : rag_system_prompt_template},
36
+ ]
37
+
38
+ rag_user_prompt_template = """\
39
+ Question:
40
+ {question}
41
+ Context:
42
+ {context}
43
+ """
44
+
45
+ chat_prompt = ChatPromptTemplate.from_messages([
46
+ ("system", rag_system_prompt_template),
47
+ ("human", rag_user_prompt_template)
48
+ ])
49
+
50
+
51
+ class VectorDatabase:
52
+ def __init__(self, embeddings: OpenAIEmbeddings()) -> None:
53
+ self.embeddings = embeddings
54
+
55
+ async def build_retriever(self, docs) -> None:
56
+ collection_name = f"pdf_to_parse_{uuid.uuid4()}"
57
+ client = QdrantClient(":memory:")
58
+ client.create_collection(
59
+ collection_name=collection_name,
60
+ vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
61
+ )
62
+
63
+ # Adding cache!
64
+ store = LocalFileStore("./cache/")
65
+ cached_embedder = CacheBackedEmbeddings.from_bytes_store(
66
+ self.embeddings, store, namespace=self.embeddings.model
67
+ )
68
+
69
+ # Typical QDrant Vector Store Set-up
70
+ vectorstore = QdrantVectorStore(
71
+ client=client,
72
+ collection_name=collection_name,
73
+ embedding=cached_embedder)
74
+ vectorstore.add_documents(docs)
75
+
76
+ return vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3})
77
+
78
+
79
+ class RetrievalAugmentedQAPipeline:
80
+ def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase) -> None:
81
+ self.llm = llm
82
+ self.retriever = vector_db_retriever
83
+
84
+ async def arun_pipeline(self, user_query: str):
85
+ retrieval_augmented_qa_chain = (
86
+ {"context": itemgetter("question") | self.retriever, "question": itemgetter("question")}
87
+ | chat_prompt | self.llm | StrOutputParser()
88
+ )
89
+
90
+ async def generate_response():
91
+ async for chunk in retrieval_augmented_qa_chain.astream({"question": user_query}):
92
+ yield chunk
93
+
94
+ return {"response": generate_response()}
95
+
96
+
97
+ def process_pdf_file(file: AskFileResponse):
98
+ import tempfile
99
+ with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".pdf") as temp_file:
100
+ temp_file_path = temp_file.name
101
+ temp_file.write(file.content)
102
+
103
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
104
+ Loader = PyMuPDFLoader
105
+ loader = Loader(temp_file_path)
106
+ documents = loader.load()
107
+ docs = text_splitter.split_documents(documents)
108
+ for i, doc in enumerate(docs):
109
+ doc.metadata["source"] = f"source_{i}"
110
+
111
+ return docs
112
+
113
+
114
+ ### On Chat Start (Session Start) Section ###
115
+ @cl.on_chat_start
116
+ async def on_chat_start():
117
+ """ SESSION SPECIFIC CODE HERE """
118
+ files = None
119
+ # Wait for the user to upload a file
120
+ while files == None:
121
+ files = await cl.AskFileMessage(
122
+ content="Please upload a pdf file to begin!",
123
+ accept=["pdf"],
124
+ max_size_mb=2,
125
+ timeout=180,
126
+ ).send()
127
+ file = files[0]
128
+
129
+ msg = cl.Message(
130
+ content=f"Processing `{file.name}`...", disable_human_feedback=True
131
+ )
132
+ await msg.send()
133
+ docs = process_pdf_file(file)
134
+ print(f"Processing {len(docs)} text chunks")
135
+
136
+ # Create a dict vector store
137
+ vector_db = VectorDatabase(embeddings=OpenAIEmbeddings(model="text-embedding-3-small"))
138
+ vector_db = await vector_db.build_retriever(docs)
139
+
140
+ # Create a chain
141
+ retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
142
+ llm=ChatOpenAI(model="gpt-4o-mini"),
143
+ vector_db_retriever=vector_db
144
+ )
145
+
146
+ # Let the user know that the system is ready
147
+ msg.content = f"Processing `{file.name}` done. You can now ask questions!"
148
+ await msg.update()
149
+
150
+ cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
151
+
152
+
153
+ ### Rename Chains ###
154
+ @cl.author_rename
155
+ def rename(orig_author: str):
156
+ """ RENAME CODE HERE """
157
+ rename_dict = {"LLMMathChain": "Albert Einstein", "Chatbot": "Assistant"}
158
+ return rename_dict.get(orig_author, orig_author)
159
+
160
+
161
+ ### On Message Section ###
162
+ @cl.on_message
163
+ async def main(message: cl.Message):
164
+ """
165
+ MESSAGE CODE HERE
166
+ """
167
+ chain = cl.user_session.get("chain")
168
+
169
+ msg = cl.Message(content="")
170
+ result = await chain.arun_pipeline(message.content)
171
+
172
+ async for stream_resp in result["response"]:
173
+ await msg.stream_token(stream_resp)
174
+
175
+ await msg.send()
176
+
requirements.txt ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohappyeyeballs==2.4.3
3
+ aiohttp==3.10.8
4
+ aiosignal==1.3.1
5
+ annotated-types==0.7.0
6
+ anyio==3.7.1
7
+ async-timeout==4.0.3
8
+ asyncer==0.0.2
9
+ attrs==24.2.0
10
+ bidict==0.23.1
11
+ certifi==2024.8.30
12
+ chainlit==0.7.700
13
+ charset-normalizer==3.3.2
14
+ click==8.1.7
15
+ dataclasses-json==0.5.14
16
+ Deprecated==1.2.14
17
+ distro==1.9.0
18
+ exceptiongroup==1.2.2
19
+ faiss-cpu==1.8.0.post1
20
+ fastapi==0.100.1
21
+ fastapi-socketio==0.0.10
22
+ filelock==3.16.1
23
+ filetype==1.2.0
24
+ frozenlist==1.4.1
25
+ fsspec==2024.9.0
26
+ googleapis-common-protos==1.65.0
27
+ greenlet==3.1.1
28
+ grpcio==1.66.2
29
+ grpcio-tools==1.62.3
30
+ h11==0.14.0
31
+ h2==4.1.0
32
+ hpack==4.0.0
33
+ httpcore==0.17.3
34
+ httpx==0.24.1
35
+ huggingface-hub==0.25.1
36
+ hyperframe==6.0.1
37
+ idna==3.10
38
+ importlib_metadata==8.4.0
39
+ Jinja2==3.1.4
40
+ jiter==0.5.0
41
+ joblib==1.4.2
42
+ jsonpatch==1.33
43
+ jsonpointer==3.0.0
44
+ langchain==0.3.0
45
+ langchain-community==0.3.0
46
+ langchain-core==0.3.1
47
+ langchain-huggingface==0.1.0
48
+ langchain-openai==0.2.0
49
+ langchain-qdrant==0.1.4
50
+ langchain-text-splitters==0.3.0
51
+ langsmith==0.1.121
52
+ Lazify==0.4.0
53
+ MarkupSafe==2.1.5
54
+ marshmallow==3.22.0
55
+ mpmath==1.3.0
56
+ multidict==6.1.0
57
+ mypy-extensions==1.0.0
58
+ nest-asyncio==1.6.0
59
+ networkx==3.2.1
60
+ numpy==1.26.4
61
+ nvidia-cublas-cu12==12.1.3.1
62
+ nvidia-cuda-cupti-cu12==12.1.105
63
+ nvidia-cuda-nvrtc-cu12==12.1.105
64
+ nvidia-cuda-runtime-cu12==12.1.105
65
+ nvidia-cudnn-cu12==9.1.0.70
66
+ nvidia-cufft-cu12==11.0.2.54
67
+ nvidia-curand-cu12==10.3.2.106
68
+ nvidia-cusolver-cu12==11.4.5.107
69
+ nvidia-cusparse-cu12==12.1.0.106
70
+ nvidia-nccl-cu12==2.20.5
71
+ nvidia-nvjitlink-cu12==12.6.77
72
+ nvidia-nvtx-cu12==12.1.105
73
+ openai==1.51.0
74
+ opentelemetry-api==1.27.0
75
+ opentelemetry-exporter-otlp==1.27.0
76
+ opentelemetry-exporter-otlp-proto-common==1.27.0
77
+ opentelemetry-exporter-otlp-proto-grpc==1.27.0
78
+ opentelemetry-exporter-otlp-proto-http==1.27.0
79
+ opentelemetry-instrumentation==0.48b0
80
+ opentelemetry-proto==1.27.0
81
+ opentelemetry-sdk==1.27.0
82
+ opentelemetry-semantic-conventions==0.48b0
83
+ orjson==3.10.7
84
+ packaging==23.2
85
+ pillow==10.4.0
86
+ portalocker==2.10.1
87
+ protobuf==4.25.5
88
+ pydantic==2.9.2
89
+ pydantic-settings==2.5.2
90
+ pydantic_core==2.23.4
91
+ PyJWT==2.9.0
92
+ PyMuPDF==1.24.10
93
+ PyMuPDFb==1.24.10
94
+ python-dotenv==1.0.1
95
+ python-engineio==4.9.1
96
+ python-graphql-client==0.4.3
97
+ python-multipart==0.0.6
98
+ python-socketio==5.11.4
99
+ PyYAML==6.0.2
100
+ qdrant-client==1.11.2
101
+ regex==2024.9.11
102
+ requests==2.32.3
103
+ safetensors==0.4.5
104
+ scikit-learn==1.5.2
105
+ scipy==1.13.1
106
+ sentence-transformers==3.1.1
107
+ simple-websocket==1.0.0
108
+ sniffio==1.3.1
109
+ SQLAlchemy==2.0.35
110
+ starlette==0.27.0
111
+ sympy==1.13.3
112
+ syncer==2.0.3
113
+ tenacity==8.5.0
114
+ threadpoolctl==3.5.0
115
+ tiktoken==0.7.0
116
+ tokenizers==0.20.0
117
+ tomli==2.0.1
118
+ # torch==2.4.1
119
+ tqdm==4.66.5
120
+ transformers==4.45.1
121
+ triton==3.0.0
122
+ typing-inspect==0.9.0
123
+ typing_extensions==4.12.2
124
+ uptrace==1.26.0
125
+ urllib3==2.2.3
126
+ uvicorn==0.23.2
127
+ watchfiles==0.20.0
128
+ websockets==13.1
129
+ wrapt==1.16.0
130
+ wsproto==1.2.0
131
+ yarl==1.13.1
132
+ zipp==3.20.2