Cheselle committed on
Commit
1e61831
1 Parent(s): c75b207

Added Base RAG

Browse files
Files changed (3) hide show
  1. Dockerfile +11 -0
  2. app.py +107 -0
  3. requirements.txt +96 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Image for serving the Chainlit app on port 7860.
FROM python:3.9

# Run as an unprivileged user (UID 1000) rather than root.
RUN useradd -m -u 1000 user
USER user

# pip installs performed by this user place console scripts under
# ~/.local/bin, so that directory is put on PATH.
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

WORKDIR $HOME/app

# Copy only the dependency manifest first so the install layer can be reused
# when just the application source changes. Dockerfile instructions do not
# expand "~", so the destination is spelled out with $HOME.
COPY --chown=user ./requirements.txt $HOME/app/requirements.txt
RUN pip install -r requirements.txt

# Copy the application once, owned by the runtime user.
COPY --chown=user . $HOME/app

CMD ["chainlit", "run", "app.py", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from langchain_openai import OpenAIEmbeddings
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_openai.embeddings import OpenAIEmbeddings
6
+
7
+ from langchain.prompts import ChatPromptTemplate
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain.schema import StrOutputParser
10
+
11
+ from langchain_community.document_loaders import PyMuPDFLoader
12
+ from langchain_community.vectorstores import Qdrant
13
+
14
+ from langchain_core.runnables import RunnablePassthrough, RunnableParallel
15
+ from langchain_core.documents import Document
16
+
17
+ from operator import itemgetter
18
+ import os
19
+ from dotenv import load_dotenv
20
+ import chainlit as cl
21
+
22
# Pull settings from a local .env file into the process environment.
load_dotenv()


# The two PDFs that make up the knowledge base are fetched by URL and parsed
# when this module is imported.
framework_loader = PyMuPDFLoader(
    file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
)
ai_framework_document = framework_loader.load()

blueprint_loader = PyMuPDFLoader(
    file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf"
)
ai_blueprint_document = blueprint_loader.load()
27
+
28
+
29
def metadata_generator(document, name: str, *, chunk_size: int = 500, chunk_overlap: int = 100):
    """Split loaded documents into overlapping chunks tagged with a source name.

    Args:
        document: The loaded documents to split, in whatever form
            ``RecursiveCharacterTextSplitter.split_documents`` accepts.
        name: Label stored in every chunk's ``metadata["source"]``. Any value
            already stored under that key is overwritten.
        chunk_size: Chunk size passed to the splitter (defaults to 500).
        chunk_overlap: Overlap between neighbouring chunks passed to the
            splitter (defaults to 100).

    Returns:
        The chunks produced by the splitter, each labelled with ``name``.
    """
    fixed_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        # Prefer paragraph breaks, then line breaks, then sentence punctuation.
        separators=["\n\n", "\n", ".", "!", "?"],
    )
    collection = fixed_text_splitter.split_documents(document)
    for doc in collection:
        doc.metadata["source"] = name
    return collection
39
+
40
# Chunk each source on its own so every chunk carries the matching "source" label.
recursive_framework_document = metadata_generator(ai_framework_document, "AI Framework")
recursive_blueprint_document = metadata_generator(ai_blueprint_document, "AI Blueprint")
combined_documents = [*recursive_framework_document, *recursive_blueprint_document]

# Embedding model used to index the chunks.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

# Index all chunks in a Qdrant collection that lives in process memory.
qdrant_options = {
    "location": ":memory:",
    "collection_name": "ai_policy",
}
vectorstore = Qdrant.from_documents(
    documents=combined_documents,
    embedding=embeddings,
    **qdrant_options,
)

# Retriever over the collection, using the vector store's default settings.
alt_retriever = vectorstore.as_retriever()
53
+
54
# Chat model that writes the final answer.
llm = ChatOpenAI(model="gpt-4o-mini")

# Prompt template; {context} and {question} are filled in by the chain.
RAG_PROMPT = (
    "You are an AI Policy Expert.\n"
    "Given a provided context and question, you must answer the question based only on context.\n"
    "Think through your answer carefully and step by step.\n"
    "\n"
    "Context: {context}\n"
    "Question: {question}\n"
)

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
67
+
68
# The chain is invoked with {"question": "<user question>"} and yields
# {"response": <LLM message>, "context": <retrieved documents>}.

# Stage 1: send the question to alt_retriever to obtain "context", and carry
# the question itself through unchanged.
retrieve_stage = {
    "context": itemgetter("question") | alt_retriever,
    "question": itemgetter("question"),
}

# Stage 2: re-assign "context" from its own value, leaving the mapping as is.
carry_context_stage = RunnablePassthrough.assign(context=itemgetter("context"))

# Stage 3: format the prompt with "context" and "question", run the LLM, and
# return the reply under "response" next to the retrieved "context".
answer_stage = {
    "response": rag_prompt | llm,
    "context": itemgetter("context"),
}

retrieval_augmented_qa_chain = retrieve_stage | carry_context_stage | answer_stage

# Example call:
# retrieval_augmented_qa_chain.invoke({"question" : "What is the AI framework all about?"})
83
+
84
@cl.on_message
async def handle_message(message: cl.Message):
    """Answer an incoming chat message using the RAG chain.

    The message text is passed to ``retrieval_augmented_qa_chain`` as the
    question, and the content of the chain's ``"response"`` is sent back to
    the user. Any exception is reported to the user in the chat and printed
    to stdout instead of propagating.

    Args:
        message: The incoming Chainlit message; its ``content`` is the question.
    """
    try:
        # Await the chain's async entry point so the event loop is not blocked
        # while retrieval and generation are running.
        result = await retrieval_augmented_qa_chain.ainvoke(
            {"question": message.content}
        )

        # Reply with the text of the LLM's answer.
        await cl.Message(content=result["response"].content).send()

    except Exception as e:
        # Top-level boundary for the handler: tell the user, then log.
        await cl.Message(content=f"An error occurred: {str(e)}").send()
        print(f"Error occurred: {e}")
101
+
102
# Start the Chainlit server when this file is executed directly
# (`python app.py`). NOTE(review): the container starts the app with
# `chainlit run app.py`, which presumably does not execute this guard — confirm.
if __name__ == "__main__":
    # Imported here because the launcher is only needed for direct execution.
    from chainlit.cli import run_chainlit

    try:
        run_chainlit(__file__)
    except Exception as e:
        print(f"Server error occurred: {e}")
requirements.txt ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohappyeyeballs==2.4.0
3
+ aiohttp==3.10.5
4
+ aiosignal==1.3.1
5
+ annotated-types==0.7.0
6
+ anyio==3.7.1
7
+ asyncer==0.0.2
8
+ attrs==24.2.0
9
+ bidict==0.23.1
10
+ certifi==2024.8.30
11
+ chainlit==0.7.700
12
+ charset-normalizer==3.3.2
13
+ click==8.1.7
14
+ dataclasses-json==0.5.14
15
+ Deprecated==1.2.14
16
+ distro==1.9.0
17
+ fastapi==0.100.1
18
+ fastapi-socketio==0.0.10
19
+ filetype==1.2.0
20
+ frozenlist==1.4.1
21
+ googleapis-common-protos==1.65.0
22
+ grpcio==1.66.1
23
+ grpcio-tools==1.62.3
24
+ h11==0.14.0
25
+ h2==4.1.0
26
+ hpack==4.0.0
27
+ httpcore==0.17.3
28
+ httpx==0.24.1
29
+ hyperframe==6.0.1
30
+ idna==3.10
31
+ importlib_metadata==8.4.0
32
+ jiter==0.5.0
33
+ jsonpatch==1.33
34
+ jsonpointer==3.0.0
35
+ langchain==0.2.16
36
+ langchain-community==0.2.17
37
+ langchain-core==0.2.41
38
+ langchain-experimental==0.0.65
39
+ langchain-openai==0.1.25
40
+ langchain-qdrant==0.1.4
41
+ langchain-text-splitters==0.2.4
42
+ langsmith==0.1.125
43
+ Lazify==0.4.0
44
+ marshmallow==3.22.0
45
+ multidict==6.1.0
46
+ mypy-extensions==1.0.0
47
+ nest-asyncio==1.6.0
48
+ numpy==1.26.4
49
+ openai==1.46.1
50
+ opentelemetry-api==1.27.0
51
+ opentelemetry-exporter-otlp==1.27.0
52
+ opentelemetry-exporter-otlp-proto-common==1.27.0
53
+ opentelemetry-exporter-otlp-proto-grpc==1.27.0
54
+ opentelemetry-exporter-otlp-proto-http==1.27.0
55
+ opentelemetry-instrumentation==0.48b0
56
+ opentelemetry-proto==1.27.0
57
+ opentelemetry-sdk==1.27.0
58
+ opentelemetry-semantic-conventions==0.48b0
59
+ orjson==3.10.7
60
+ packaging==23.2
61
+ portalocker==2.10.1
62
+ protobuf==4.25.5
63
+ pydantic==2.9.2
64
+ pydantic_core==2.23.4
65
+ PyJWT==2.9.0
66
+ PyMuPDF==1.24.10
67
+ PyMuPDFb==1.24.10
68
+ python-dotenv==1.0.1
69
+ python-engineio==4.9.1
70
+ python-graphql-client==0.4.3
71
+ python-multipart==0.0.6
72
+ python-socketio==5.11.4
73
+ PyYAML==6.0.2
74
+ qdrant-client==1.11.2
75
+ regex==2024.9.11
76
+ requests==2.32.3
77
+ simple-websocket==1.0.0
78
+ sniffio==1.3.1
79
+ SQLAlchemy==2.0.35
80
+ starlette==0.27.0
81
+ syncer==2.0.3
82
+ tenacity==8.5.0
83
+ tiktoken==0.7.0
84
+ tomli==2.0.1
85
+ tqdm==4.66.5
86
+ typing-inspect==0.9.0
87
+ typing_extensions==4.12.2
88
+ uptrace==1.26.0
89
+ urllib3==2.2.3
90
+ uvicorn==0.23.2
91
+ watchfiles==0.20.0
92
+ websockets==13.0.1
93
+ wrapt==1.16.0
94
+ wsproto==1.2.0
95
+ yarl==1.11.1
96
+ zipp==3.20.2