File size: 5,442 Bytes
fbba7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c5aefd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from typing import cast
import os
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_experimental.text_splitter import SemanticChunker
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Qdrant
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from operator import itemgetter
import chainlit as cl
from openai import AsyncOpenAI
from dotenv import load_dotenv

load_dotenv()

# load_dotenv() copies variables from a .env file into os.environ. The
# OpenAI client reads OPENAI_API_KEY from there, so no re-assignment is needed.
# Writing os.getenv(...) back into os.environ would only raise an opaque
# TypeError (os.environ rejects None) when the key is missing; fail with an
# actionable message instead.
if os.getenv("OPENAI_API_KEY") is None:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it or add it to a .env file."
    )

# Sample questions for manually exercising the app. This bare string is never
# referenced at runtime; it only documents expected usage.
"""
"What is the AI Bill of Rights, and how does it affect the development of AI systems in the U.S.?"

"How is the government planning to regulate AI technologies in relation to privacy and data security?"

"What are the key principles outlined in the NIST AI Risk Management Framework?"

"How will the AI Bill of Rights affect businesses developing AI solutions for consumers?"

"What role does the government play in ensuring that AI is developed ethically and responsibly?"

"How might the outcomes of the upcoming elections impact AI regulation and policy?"

"What are the risks associated with using AI in political campaigns and decision-making?"

"How do the NIST guidelines help organizations reduce bias and ensure fairness in AI applications?"

"How are other countries approaching AI regulation compared to the U.S., and what can we learn from them?"

"What challenges do businesses face in complying with government guidelines like the AI Bill of Rights and NIST framework?"

"""
def _split_with_source(documents, name):
    """Split *documents* into overlapping chunks and label each with *name*.

    Args:
        documents: LangChain ``Document`` objects (here, pages from PyMuPDFLoader).
        name: Label written into every chunk's ``metadata["source"]``.

    Returns:
        The list of chunk ``Document`` objects.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", "!", "?"],
    )
    chunks = splitter.split_documents(documents)
    for chunk in chunks:
        # Overwrite whatever "source" the loader recorded with a readable label.
        chunk.metadata["source"] = name
    return chunks


def _build_retriever():
    """Download the two policy PDFs, chunk and embed them, and return a retriever.

    This is entirely synchronous (network downloads plus local model
    inference), so async callers should run it in a worker thread.
    """
    # Local import, mirroring the original's function-scoped model import.
    from langchain_community.embeddings import HuggingFaceEmbeddings

    framework_pages = PyMuPDFLoader(
        file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
    ).load()
    blueprint_pages = PyMuPDFLoader(
        file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf"
    ).load()

    combined_documents = _split_with_source(
        framework_pages, "AI Framework"
    ) + _split_with_source(blueprint_pages, "AI Blueprint")

    # Qdrant.from_documents needs a LangChain Embeddings object (with
    # embed_documents / embed_query). A raw transformers AutoModel exposes
    # neither, so wrap the fine-tuned sentence-embedding model instead.
    embeddings = HuggingFaceEmbeddings(
        model_name="Cheselle/finetuned-arctic-sentence"
    )

    vectorstore = Qdrant.from_documents(
        documents=combined_documents,
        embedding=embeddings,
        location=":memory:",
        collection_name="AI Policy",
    )
    return vectorstore.as_retriever()


@cl.on_chat_start
async def on_chat_start():
    """Initialise per-session state: chat model, document retriever, RAG prompt.

    Stores them in the Chainlit user session under the keys ``"runnable"``,
    ``"retriever"`` and ``"prompt_template"`` for use by the message handler.
    """
    import asyncio

    model = ChatOpenAI(streaming=True)

    # RAG prompt: fixed system persona plus a human turn carrying the
    # retrieved context and the user's question.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You're a very knowledgeable AI engineer who's good at explaining stuff like ELI5."
            ),
            ("human", "{context}\n\nQuestion: {question}")
        ]
    )

    # Downloading and embedding the PDFs is blocking work; run it off the
    # event loop so other chat sessions stay responsive meanwhile.
    retriever = await asyncio.to_thread(_build_retriever)

    cl.user_session.set("runnable", model)
    cl.user_session.set("retriever", retriever)
    cl.user_session.set("prompt_template", prompt)



@cl.on_message
async def on_message(message: cl.Message):
    """Answer *message* with retrieval-augmented generation, streaming the reply.

    Retrieves document chunks relevant to the message text and fills the
    session's RAG prompt with them. It then streams the chat model's answer
    into a Chainlit message. If nothing is retrieved, it sends an apology and
    returns early.
    """
    import asyncio

    # Objects stored by on_chat_start for this session.
    model = cast(ChatOpenAI, cl.user_session.get("runnable"))
    retriever = cl.user_session.get("retriever")
    prompt_template = cl.user_session.get("prompt_template")

    question = message.content
    print(f"Received message: {question}")

    # `invoke` replaces the deprecated `get_relevant_documents`. Running it in
    # a worker thread keeps query embedding + vector search off the event loop.
    relevant_docs = await asyncio.to_thread(retriever.invoke, question)
    print(f"Retrieved {len(relevant_docs)} documents.")

    if not relevant_docs:
        print("No relevant documents found.")
        await cl.Message(content="Sorry, I couldn't find any relevant documents.").send()
        return

    context = "\n\n".join(doc.page_content for doc in relevant_docs)
    print(f"Context: {context}")

    # format_prompt returns a ChatPromptValue that keeps the system/human
    # roles when handed to the chat model. The old `.format()` produced one
    # flat string, which the model received as a single human message.
    prompt_value = prompt_template.format_prompt(context=context, question=question)
    print(f"Final prompt: {prompt_value.to_string()}")

    msg = cl.Message(content="")

    # Forward each streamed AIMessageChunk's text to the UI as it arrives.
    async for chunk in model.astream(
        prompt_value,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk.content)

    await msg.send()

if __name__ == "__main__":
    # This module defines no `app` object, so `app.run()` raised NameError.
    # Chainlit apps are normally started with `chainlit run <file>`;
    # run_chainlit launches the same server programmatically for this file.
    from chainlit.cli import run_chainlit

    run_chainlit(__file__)