# FineTunedRAG / app.py
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from typing import cast
import os
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_experimental.text_splitter import SemanticChunker
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Qdrant
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from operator import itemgetter
import chainlit as cl
from openai import AsyncOpenAI
from dotenv import load_dotenv
load_dotenv()
# Set up API key for OpenAI
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
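# Sample questions this RAG assistant is intended to handle: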
"""
"What is the AI Bill of Rights, and how does it affect the development of AI systems in the U.S.?"
"How is the government planning to regulate AI technologies in relation to privacy and data security?"
"What are the key principles outlined in the NIST AI Risk Management Framework?"
"How will the AI Bill of Rights affect businesses developing AI solutions for consumers?"
"What role does the government play in ensuring that AI is developed ethically and responsibly?"
"How might the outcomes of the upcoming elections impact AI regulation and policy?"
"What are the risks associated with using AI in political campaigns and decision-making?"
"How do the NIST guidelines help organizations reduce bias and ensure fairness in AI applications?"
"How are other countries approaching AI regulation compared to the U.S., and what can we learn from them?"
"What challenges do businesses face in complying with government guidelines like the AI Bill of Rights and NIST framework?"
"""
@cl.on_chat_start
async def on_chat_start():
    model = ChatOpenAI(streaming=True)

    # Define RAG prompt template
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You're a very knowledgeable AI engineer who's good at explaining stuff like ELI5."
            ),
            ("human", "{context}\n\nQuestion: {question}")
        ]
    )
    # Load documents and create retriever
    ai_framework_document = PyMuPDFLoader(file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf").load()
    ai_blueprint_document = PyMuPDFLoader(file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()
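    # Source PDFs: NIST AI 600-1 (the Generative AI profile of the AI Risk Management Framework)
    # and the White House Blueprint for an AI Bill of Rights.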
    def metadata_generator(document, name):
        fixed_text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            separators=["\n\n", "\n", ".", "!", "?"]
        )
        collection = fixed_text_splitter.split_documents(document)
        for doc in collection:
            doc.metadata["source"] = name
        return collection
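    # Chunk both documents and tag every chunk with its source name.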
    recursive_framework_document = metadata_generator(ai_framework_document, "AI Framework")
    recursive_blueprint_document = metadata_generator(ai_blueprint_document, "AI Blueprint")
    combined_documents = recursive_framework_document + recursive_blueprint_document

    # Wrap the fine-tuned sentence-embedding model in a LangChain embeddings class;
    # a raw transformers AutoModel is not compatible with Qdrant.from_documents.
    # HuggingFaceEmbeddings loads the model via sentence-transformers.
    from langchain_community.embeddings import HuggingFaceEmbeddings
    embeddings = HuggingFaceEmbeddings(model_name="Cheselle/finetuned-arctic-sentence")
    # Vector store and retriever
    vectorstore = Qdrant.from_documents(
        documents=combined_documents,
        embedding=embeddings,
        location=":memory:",
        collection_name="AI Policy"
    )
    retriever = vectorstore.as_retriever()
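    # as_retriever() defaults to plain similarity search over the embedded chunks.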
    # Store the model, retriever, and prompt in the session for reuse
    cl.user_session.set("runnable", model)
    cl.user_session.set("retriever", retriever)
    cl.user_session.set("prompt_template", prompt)
@cl.on_message
async def on_message(message: cl.Message):
    # Get the stored model, retriever, and prompt
    model = cast(ChatOpenAI, cl.user_session.get("runnable"))
    retriever = cl.user_session.get("retriever")  # Get the retriever from the session
    prompt_template = cl.user_session.get("prompt_template")  # Get the RAG prompt template

    # Log the message content
    print(f"Received message: {message.content}")

    # Retrieve relevant context from documents based on the user's message
    relevant_docs = retriever.get_relevant_documents(message.content)
    print(f"Retrieved {len(relevant_docs)} documents.")
    if not relevant_docs:
        print("No relevant documents found.")
        await cl.Message(content="Sorry, I couldn't find any relevant documents.").send()
        return

    context = "\n\n".join([doc.page_content for doc in relevant_docs])
    # Log the context to check
    print(f"Context: {context}")
    # Construct the final RAG prompt as chat messages so the system role is preserved
    final_prompt = prompt_template.format_messages(context=context, question=message.content)
    print(f"Final prompt: {final_prompt}")
    # Initialize a streaming message
    msg = cl.Message(content="")

    # Stream the response from the model
    async for chunk in model.astream(
        final_prompt,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        # Extract the content from each AIMessageChunk and append it to the message
        await msg.stream_token(chunk.content)

    await msg.send()
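# Run locally with the Chainlit CLI, e.g.:
#   chainlit run app.py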