Mr-Cool committed on
Commit e40a5fc
1 Parent(s): f3a602e

Added files

Files changed (5):
  1. .gitattributes +1 -0
  2. app.py +7 -0
  3. data/nist_ai.pdf +3 -0
  4. functions.py +97 -0
  5. requirements.txt +9 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/nist_ai.pdf filter=lfs diff=lfs merge=lfs -text
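With this rule in place, Git LFS stores the PDF outside the normal object database and commits only a small pointer file; the pointer's contents are visible in the data/nist_ai.pdf hunk below.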
app.py ADDED
@@ -0,0 +1,7 @@
+ import gradio as gr
+ from functions import *
+
+ rag_poc = gr.ChatInterface(get_response)
+
+ if __name__ == "__main__":
+     rag_poc.launch()
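`gr.ChatInterface` invokes `get_response(message, history)` on every user turn. As a quick sanity check, a sketch along these lines exercises the handler without starting the web UI (assuming the Azure OpenAI credentials discussed under functions.py are already set in the environment; the question text is just an example):

```python
# Minimal smoke test for the chat handler, bypassing the Gradio UI.
from functions import get_response

if __name__ == "__main__":
    # ChatInterface would supply the running conversation history; an empty
    # list is fine here because get_response never reads it.
    print(get_response("What does the NIST AI document cover?", history=[]))
```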
data/nist_ai.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b98f5f456157e2de607322a9a2630175f93683754a455c469c0954e4e94a1b1c
+ size 1204825
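Note this hunk is the Git LFS pointer, not the PDF itself; `git lfs pull` replaces it with the real file. A small sketch to confirm a fetched copy matches the recorded digest and size (assuming it is run from the repository root):

```python
import hashlib
import os

# Values copied from the LFS pointer above.
EXPECTED_OID = "b98f5f456157e2de607322a9a2630175f93683754a455c469c0954e4e94a1b1c"
EXPECTED_SIZE = 1204825

with open("data/nist_ai.pdf", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

assert digest == EXPECTED_OID, "content does not match the LFS pointer"
assert os.path.getsize("data/nist_ai.pdf") == EXPECTED_SIZE
```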
functions.py ADDED
@@ -0,0 +1,97 @@
+ from langchain_community.document_loaders import PyMuPDFLoader
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI
+ from operator import itemgetter
+ from langchain_core.runnables import RunnablePassthrough
+ from langchain_qdrant import QdrantVectorStore
+ from qdrant_client import QdrantClient
+ from qdrant_client.http.models import Distance, VectorParams
+ from langchain.prompts import ChatPromptTemplate
+ import tiktoken
+ import os
+
+ ### SETUP FUNCTIONS ###
+ def tiktoken_len(text):
+     """Measure text length in gpt-4o tokens so chunk sizes match the model's tokenizer."""
+     tokens = tiktoken.encoding_for_model("gpt-4o").encode(text)
+     return len(tokens)
+
+ def setup_vector_db():
+     # Resolve paths relative to this file so the loader works no matter
+     # where the app is launched from.
+     current_file_directory = os.path.dirname(os.path.abspath(__file__))
+     os.chdir(current_file_directory)
+
+     # Load the NIST AI document
+     PDF_LINK = "data/nist_ai.pdf"
+     loader = PyMuPDFLoader(file_path=PDF_LINK)
+     nist_doc = loader.load()
+
+     # Split into ~500-token chunks with 100 tokens of overlap
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=500,
+         chunk_overlap=100,
+         length_function=tiktoken_len,
+     )
+     nist_chunks = text_splitter.split_documents(nist_doc)
+
+     embeddings_small = AzureOpenAIEmbeddings(azure_deployment="text-embedding-3-small")
+
+     qdrant_client = QdrantClient(":memory:")  # in-memory Qdrant instance
+
+     # text-embedding-3-small returns 1536-dimensional vectors
+     qdrant_client.create_collection(
+         collection_name="NIST_AI",
+         vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
+     )
+
+     # Wrap the client and collection with the embedding model
+     qdrant_vector_store = QdrantVectorStore(
+         client=qdrant_client,
+         collection_name="NIST_AI",
+         embedding=embeddings_small,
+     )
+
+     qdrant_vector_store.add_documents(nist_chunks)  # embed and index the chunks
+
+     retriever = qdrant_vector_store.as_retriever()
+
+     return retriever
+
+ ### VARIABLES ###
+
+ # Module-level retriever, built once when the module is imported
+ retriever = setup_vector_db()
+ qa_gpt4_llm = AzureChatOpenAI(azure_deployment="gpt-4", temperature=0)  # chat deployment; temperature 0 for deterministic answers
+
+ # Prompt template for the RAG chain
+ rag_template = """
+ You are a helpful assistant that helps users find information and answer their questions.
+ You MUST use ONLY the available context to answer the question.
+ If the information needed to answer the question cannot be found in the provided context, you MUST answer "I don't know."
+
+ Question:
+ {question}
+
+ Context:
+ {context}
+ """
+ prompt = ChatPromptTemplate.from_template(rag_template)
+
+ # Chain: retrieve context for the question, then keep the context and
+ # question alongside the LLM's response in the output dict.
+ retrieval_augmented_qa_chain = (
+     {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
+     | RunnablePassthrough.assign(context=itemgetter("context"))
+     | {"response": prompt | qa_gpt4_llm, "context": itemgetter("context"), "question": itemgetter("question")}
+ )
+
+ ### FUNCTIONS ###
+
+ def get_response(query, history):
+     """Run the RAG chain on the user's query and return the answer text to the UI."""
+     response = retrieval_augmented_qa_chain.invoke({"question": query})
+     return response["response"].content
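Both `AzureOpenAIEmbeddings` and `AzureChatOpenAI` pick up their credentials from the environment, and `setup_vector_db()` runs at import time, so the environment must be configured before `functions` is imported. A rough sketch of a direct chain invocation (the endpoint, key, and API-version values are placeholders; the variable names are the standard ones read by langchain-openai):

```python
import os

# Placeholders; replace with your Azure OpenAI resource details.
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<your-resource>.openai.azure.com/"
os.environ["AZURE_OPENAI_API_KEY"] = "<your-key>"
os.environ["OPENAI_API_VERSION"] = "2024-02-01"

# Import after the environment is set: this builds the in-memory index.
from functions import retrieval_augmented_qa_chain

result = retrieval_augmented_qa_chain.invoke({"question": "What is the AI RMF?"})
print(result["response"].content)  # the model's answer
print(len(result["context"]))      # number of retrieved chunks
```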
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ langchain-community
+ gradio
+ langchain-openai
+ langchain
+ qdrant-client
+ tiktoken
+ langchain-qdrant
+ PyMuPDF
+ langchain_core
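To run the Space locally, install the dependencies with `pip install -r requirements.txt` and start the app with `python app.py`; Gradio serves the chat interface on http://localhost:7860 by default.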