Wick-RAQA / app.py
JFreemanBRG's picture
Initial Commit
679f269
raw
history blame
5.66 kB
import os
import chainlit as cl
import llama_index
from llama_index import set_global_handler
from llama_index.embeddings import OpenAIEmbedding
from llama_index import ServiceContext
from llama_index.llms import OpenAI
from llama_index import SimpleDirectoryReader
from llama_index.ingestion import IngestionPipeline
from llama_index.node_parser import TokenTextSplitter
from llama_index import load_index_from_storage
from llama_index.tools import FunctionTool
from llama_index.vector_stores.types import (
VectorStoreInfo,
MetadataInfo,
ExactMatchFilter,
MetadataFilters,
)
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from typing import List
from AutoRetrieveModel import AutoRetrieveModel
from llama_index.agent import OpenAIAgent
from sqlalchemy import create_engine
from llama_index import SQLDatabase
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine
from llama_index.tools.query_engine import QueryEngineTool
import pandas as pd
import openai
# Install Weights & Biases as LlamaIndex's global handler, logging to the
# "llamaindex-demo-v1" project. This runs at import time.
set_global_handler("wandb", run_args={"project": "llamaindex-demo-v1"})
# Keep a module-level reference to the handler installed above; it is used
# later to load the persisted index via `load_storage_context`.
# NOTE(review): presumably this is the W&B callback handler object - confirm.
wandb_callback = llama_index.global_handler
def create_semantic_agent(service_context):
    """Build the semantic (vector-search) tool over the stored wiki index.

    Despite the name, this returns a tool rather than an agent: a
    ``FunctionTool`` wrapping a retrieval function that filters the index by
    exact metadata matches and answers the query with a retriever query
    engine.

    Args:
        service_context: Service context passed both to
            ``load_index_from_storage`` and to the per-call query engine.

    Returns:
        The ``FunctionTool`` named ``semantic-film-info``.
    """
    # Load in wikipedia index
    storage_context = wandb_callback.load_storage_context(
        artifact_url="jfreeman/llamaindex-demo-v1/wiki-index:v1"
    )
    index = load_index_from_storage(storage_context, service_context=service_context)

    def auto_retrieve_fn(
        query: str, filter_key_list: List[str], filter_value_list: List[str]
    ):
        """Auto retrieval function.

        Performs auto-retrieval from a vector database, and then applies a set of filters.
        """
        # Fall back to a placeholder so an empty query string is never sent.
        query = query or "Query"

        # Keys and values are paired positionally; zip stops at the shorter
        # list, so an unmatched trailing key or value is ignored.
        exact_match_filters = [
            ExactMatchFilter(key=k, value=v)
            for k, v in zip(filter_key_list, filter_value_list)
        ]
        # The retriever's parameter is `similarity_top_k`; a bare `top_k`
        # keyword is swallowed by **kwargs and the default would be used.
        retriever = VectorIndexRetriever(
            index,
            filters=MetadataFilters(filters=exact_match_filters),
            similarity_top_k=3,
        )
        query_engine = RetrieverQueryEngine.from_args(
            retriever, service_context=service_context
        )
        response = query_engine.query(query)
        return str(response)

    # Schema description shown to the LLM so it knows which metadata key
    # (and which values) it may filter on.
    vector_store_info = VectorStoreInfo(
        content_info="semantic information about movies",
        metadata_info=[MetadataInfo(
            name="title",
            type="str",
            description="title of the movie, one of [John Wick (film), John Wick: Chapter 2, John Wick: Chapter 3 – Parabellum, John Wick: Chapter 4]",
        )]
    )
    # The string body is kept flush-left so no indentation leaks into the
    # tool description.
    description = f"""\
Use this tool to look up semantic information about films.
The vector database schema is given below:
{vector_store_info.json()}
"""
    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="semantic-film-info",
        description=description,
        fn_schema=AutoRetrieveModel,
    )
    return auto_retrieve_tool
def create_sql_agent(service_context):
    """Build the natural-language-to-SQL tool over the review tables.

    Despite the name, this returns a tool rather than an agent. Four CSV
    files (``wick_tables/jw1.csv`` .. ``jw4.csv``) are loaded into an
    in-memory SQLite database, one table per film, and exposed through an
    ``NLSQLTableQueryEngine`` wrapped in a ``QueryEngineTool``.

    Args:
        service_context: Service context passed to the SQL query engine.

    Returns:
        The ``QueryEngineTool`` named ``sql-query``.
    """
    engine = create_engine("sqlite+pysqlite:///:memory:")

    # Single source of truth for the table names, shared by the loader, the
    # SQLDatabase wrapper and the query engine.
    table_names = [f"John Wick {i}" for i in range(1, 5)]

    for i, table_name in enumerate(table_names, start=1):
        csv_path = os.path.join('wick_tables', f'jw{i}.csv')
        df = pd.read_csv(csv_path)
        df.to_sql(table_name, engine)

    sql_database = SQLDatabase(
        engine=engine,
        include_tables=table_names,
    )
    sql_query_engine = NLSQLTableQueryEngine(
        sql_database=sql_database,
        tables=table_names,
        service_context=service_context,
    )
    sql_tool = QueryEngineTool.from_defaults(
        query_engine=sql_query_engine,
        name="sql-query",
        # Adjacent literals are concatenated verbatim, so each table entry
        # ends with an explicit newline to keep the entries separate.
        description=(
            "Useful for translating a natural language query into a SQL query over a table containing:\n"
            "John Wick 1, containing information related to reviews of the first John Wick movie called 'John Wick'\n"
            "John Wick 2, containing information related to reviews of the second John Wick movie called 'John Wick: Chapter 2'\n"
            "John Wick 3, containing information related to reviews of the third John Wick movie called 'John Wick: Chapter 3 - Parabellum'\n"
            "John Wick 4, containing information related to reviews of the fourth John Wick movie called 'John Wick: Chapter 4'"
        ),
    )
    return sql_tool
# Greeting sent to the user by `start_chat` once the session's agent is ready.
welcome_message = "Welcome to the John Wick RAQA demo! Ask me anything about the John Wick movies."
@cl.on_chat_start  # marks a function that will be executed at the start of a user session
async def start_chat():
    """Set up the per-session agent and greet the user.

    Builds a service context (OpenAI LLM + OpenAI embeddings), creates the
    SQL and semantic tools from it, wraps them in an ``OpenAIAgent``, stores
    the agent in the Chainlit user session under ``"agent"`` and sends the
    welcome message.
    """
    # Create the service context
    embed_model = OpenAIEmbedding()
    chunk_size = 500
    llm = OpenAI(
        temperature=0,
        model='gpt-4-1106-preview',
        streaming=True,
    )
    service_context = ServiceContext.from_defaults(
        llm=llm,
        chunk_size=chunk_size,
        embed_model=embed_model,
    )

    auto_retrieve_tool = create_semantic_agent(service_context)
    sql_tool = create_sql_agent(service_context)

    # NOTE(review): the agent is built without an explicit `llm`, so it uses
    # whatever default `OpenAIAgent.from_tools` picks rather than the model
    # configured above - confirm this is intended.
    agent = OpenAIAgent.from_tools(
        tools=[sql_tool, auto_retrieve_tool],
    )

    cl.user_session.set("agent", agent)
    await cl.Message(content=welcome_message).send()
@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    """Forward the user's message to the session agent and send back its reply."""
    session_agent = cl.user_session.get("agent")
    reply = await session_agent.achat(message.content)
    # The agent response is stringified before being sent as chat content.
    await cl.Message(content=str(reply)).send()