File size: 4,046 Bytes
1e61831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import re

from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings

from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import StrOutputParser

from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.vectorstores import Qdrant

from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_core.documents import Document

from operator import itemgetter
import os
from dotenv import load_dotenv
import chainlit as cl

# Read settings from a local .env file before any API client is constructed.
load_dotenv()

# Locations of the two PDFs that make up the knowledge base.
_AI_FRAMEWORK_PDF_URL = "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
_AI_BLUEPRINT_PDF_URL = "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf"

# Each loader call returns the parsed documents for one PDF.
ai_framework_document = PyMuPDFLoader(file_path=_AI_FRAMEWORK_PDF_URL).load()
ai_blueprint_document = PyMuPDFLoader(file_path=_AI_BLUEPRINT_PDF_URL).load()


def metadata_generator(document, name, chunk_size=500, chunk_overlap=100):
    """Split loaded documents into chunks and label each chunk with its source.

    Args:
        document: Documents to split, in the form accepted by the splitter's
            ``split_documents`` (here: the result of a loader's ``load()``).
        name: Label written to ``metadata["source"]`` of every chunk,
            replacing whatever value was there before.
        chunk_size: Chunk size handed to the text splitter. Defaults to 500,
            the value this function previously hard-coded.
        chunk_overlap: Overlap between consecutive chunks handed to the text
            splitter. Defaults to 100, the previously hard-coded value.

    Returns:
        The list of chunks produced by the splitter, each with
        ``metadata["source"]`` set to ``name``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        # Try paragraph breaks first, then line breaks, then sentence enders.
        separators=["\n\n", "\n", ".", "!", "?"],
    )
    chunks = splitter.split_documents(document)
    for chunk in chunks:
        # Overwrite the source with a short human-readable label.
        chunk.metadata["source"] = name
    return chunks

# Chunk both source documents, tagging each chunk with a short source label,
# then pool the chunks into a single corpus.
recursive_framework_document = metadata_generator(ai_framework_document, "AI Framework")
recursive_blueprint_document = metadata_generator(ai_blueprint_document, "AI Blueprint")
combined_documents = recursive_framework_document + recursive_blueprint_document

# Embedding model passed to the vector store to vectorize the chunks.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

# Index the chunks in a Qdrant collection held in memory (location=":memory:"),
# so the index is rebuilt every time this module is executed.
vectorstore = Qdrant.from_documents(
    documents=combined_documents,
    embedding=embeddings,
    location=":memory:",
    collection_name="ai_policy"
)
# Retriever over the collection, created with the default search settings.
alt_retriever = vectorstore.as_retriever()

## Generation LLM
# Chat model that generates the final answer from the formatted prompt.
llm = ChatOpenAI(model="gpt-4o-mini")

# Prompt text with two placeholders, {context} and {question}, which the
# chain fills in before the prompt is sent to the model.
RAG_PROMPT = """\
You are an AI Policy Expert. 
Given a provided context and question, you must answer the question based only on context. 
Think through your answer carefully and step by step. 

Context: {context}
Question: {question}
"""

# Chat prompt template built from the text above.
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

# RAG pipeline: retrieve context for the question, then generate an answer.
# Invoke with {"question": "<user question>"}. The result is a dict with
# "response" (the chat model's output message) and "context" (what the
# retriever returned for the question).
retrieval_augmented_qa_chain = (
    # Step 1 - build {"context", "question"} from the input dict:
    #   "context"  : the value of the "question" key piped into alt_retriever
    #   "question" : the value of the "question" key, passed through unchanged
    {"context": itemgetter("question") | alt_retriever, "question": itemgetter("question")}
    # Step 2 - re-assign "context" to its own current value; the dict's
    #          contents are the same after this step as before it.
    | RunnablePassthrough.assign(context=itemgetter("context"))
    # Step 3 - produce the final output dict:
    #   "response" : the prompt formatted with "context" and "question", piped into the LLM
    #   "context"  : the "context" value carried over from the previous step
    | {"response": rag_prompt | llm, "context": itemgetter("context")}
)

# Example usage:
# retrieval_augmented_qa_chain.invoke({"question" : "What is the AI framework all about?"})

@cl.on_message
async def handle_message(message: cl.Message) -> None:
    """Answer an incoming chat message using the RAG chain.

    The message text is used as the question. The chain's "response" entry
    (the chat model's output message) is sent back to the user as a new
    Chainlit message.

    Args:
        message: The incoming Chainlit message; only ``message.content`` is read.

    Any exception raised while answering is printed and reported to the user
    as a chat message instead of propagating.
    """
    try:
        # Await the chain's async entry point so retrieval and the LLM call
        # do not block the event loop this handler runs on.
        result = await retrieval_augmented_qa_chain.ainvoke(
            {"question": message.content}
        )

        # Reply with the text of the model's output message.
        await cl.Message(content=result["response"].content).send()

    except Exception as e:
        # Top-level boundary for this handler: log first so the error is
        # recorded even if sending the notice to the user fails as well.
        print(f"Error occurred: {e}")
        await cl.Message(content=f"An error occurred: {str(e)}").send()

# Run the Chainlit server when this file is executed directly as a script.
if __name__ == "__main__":
    try:
        # The chainlit package exposes no top-level `run()`; the CLI helper
        # `run_chainlit` starts the server for the given target file.
        # NOTE(review): the server loads the target file itself, so the
        # module-level setup above presumably runs a second time - confirm.
        from chainlit.cli import run_chainlit

        run_chainlit(__file__)
    except Exception as e:
        print(f"Server error occurred: {e}")