# CoolApp / app.py
import os
from chainlit.types import AskFileResponse
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
    AssistantRolePrompt,
)
from aimakerspace.openai_utils.embedding import EmbeddingModel
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
import chainlit as cl
from richard.text_utils import FileLoader
from richard.pipeline import RetrievalAugmentedQAPipeline
# from richard.vector_database import QdrantDatabase
from qdrant_client import QdrantClient
from langchain.vectorstores import Qdrant
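
# Prompt templates for the RAG pipeline: the system prompt restricts answers to the
# uploaded document, and the user prompt injects the retrieved context and the question.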
system_template = """\
Use the following context to answer a user's question.
If you cannot find the answer in the context, say you don't know the answer.
The context contains the text from a document. Refer to it as the document, not the context.
"""
system_role_prompt = SystemRolePrompt(system_template)
user_prompt_template = """\
Context:
{context}
Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)
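
# Helper to load and chunk the uploaded file using the project's FileLoader.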
def process_file(file: AskFileResponse):
    file_loader = FileLoader()
    return file_loader.load_file(file)
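
# Session start: ask whether to use Qdrant, collect an uploaded document, build the
# vector store, and assemble the retrieval-augmented QA pipeline for this session.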
@cl.on_chat_start
async def on_chat_start():
    res = await cl.AskActionMessage(
        content="Do you want to use Qdrant?",
        actions=[
            cl.Action(name="yes", value="yes", label="✅ Yes"),
            cl.Action(name="no", value="no", label="❌ No"),
        ],
    ).send()
    use_qdrant = False
    use_qdrant_type = "Local"
    if res and res.get("value") == "yes":
        use_qdrant = True
        local_res = await cl.AskActionMessage(
            content="Do you want to use local or cloud?",
            actions=[
                cl.Action(name="Local", value="Local", label="✅ Local"),
                cl.Action(name="Cloud", value="Cloud", label="❌ Cloud"),
            ],
        ).send()
        if local_res and local_res.get("value") == "Cloud":
            use_qdrant_type = "Cloud"
            msg = cl.Message(
                content="Sorry - the Qdrant processing has been temporarily disconnected"
            )
            await msg.send()
            use_qdrant = False
    files = None
    # Wait for the user to upload a file
    while not files:
        files = await cl.AskFileMessage(
            content="Please upload a .txt or .pdf file to begin processing!",
            accept=["text/plain", "application/pdf"],
            max_size_mb=2,
            timeout=180,
        ).send()
    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    # Load and chunk the file
    texts = process_file(file)
    msg = cl.Message(
        content=f"Resulted in {len(texts)} chunks", disable_human_feedback=True
    )
    await msg.send()
    # Decide whether to use the dict vector store or the Qdrant vector store
    from qdrant_client.models import VectorParams

    if not use_qdrant:
        # Create a dict vector store
        vector_db = VectorDatabase()
        vector_db = await vector_db.abuild_from_list(texts)
    else:
        embedding_model = EmbeddingModel()
        if use_qdrant_type == "Local":
            print("Using qdrant local")
            qdrant_client = QdrantClient(location=":memory:")
            vector_params = VectorParams(
                size=1536,          # vector size
                distance="Cosine"   # distance metric
            )
            qdrant_client.recreate_collection(
                collection_name="my_collection",
                vectors_config={"default": vector_params},
            )
            from richard.vector_database import QdrantDatabase
            vector_db = QdrantDatabase(
                qdrant_client=qdrant_client,
                collection_name="my_collection",
                embedding_model=embedding_model
            )
            vector_db = await vector_db.abuild_from_list(texts)
    msg = cl.Message(
        content="The Vector store has been created", disable_human_feedback=True
    )
    await msg.send()

    chat_openai = ChatOpenAI()

    # Create the retrieval-augmented QA chain
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
        vector_db_retriever=vector_db,
        llm=chat_openai,
        system_role_prompt=system_role_prompt,
        user_role_prompt=user_role_prompt
    )

    # Let the user know that the system is ready
    msg.content = f"Processing `{file.name}` is complete. You can now ask questions about `{file.name}`."
    await msg.update()

    cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
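
# Handle each incoming user message: run the retrieval pipeline for the
# question and stream the generated answer back to the chat.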
@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")
    msg = cl.Message(content="")

    result = await chain.arun_pipeline(message.content)

    async for stream_resp in result["response"]:
        await msg.stream_token(stream_resp)

    await msg.send()
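
# To run the app locally (assuming Chainlit and the other dependencies are installed):
#   chainlit run app.py -w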