# Chainlit app: John Wick RAQA demo (LlamaIndex + Weights & Biases).
# (Removed non-code residue from the hosting page: "Spaces: / Sleeping / Sleeping".)
# Imports use the legacy llama_index (<0.10) flat namespaces; stripped the
# ` | |` table-extraction artifacts that made every line a syntax error.
import os
from typing import List

import chainlit as cl
import openai
import pandas as pd
from sqlalchemy import create_engine

import llama_index
from llama_index import (
    ServiceContext,
    SimpleDirectoryReader,
    SQLDatabase,
    load_index_from_storage,
    set_global_handler,
)
from llama_index.agent import OpenAIAgent
from llama_index.embeddings import OpenAIEmbedding
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine
from llama_index.ingestion import IngestionPipeline
from llama_index.llms import OpenAI
from llama_index.node_parser import TokenTextSplitter
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.retrievers import VectorIndexRetriever
from llama_index.tools import FunctionTool
from llama_index.tools.query_engine import QueryEngineTool
from llama_index.vector_stores.types import (
    ExactMatchFilter,
    MetadataFilters,
    MetadataInfo,
    VectorStoreInfo,
)

from AutoRetrieveModel import AutoRetrieveModel

# Route all LlamaIndex callback events to Weights & Biases for tracing, and
# keep a handle on the handler — it is also used to load index artifacts.
set_global_handler("wandb", run_args={"project": "llamaindex-demo-v1"})
wandb_callback = llama_index.global_handler
def create_semantic_agent(service_context):
    """Build a FunctionTool that answers semantic questions about the films.

    Loads the pre-built wiki index from a W&B artifact, then wraps an
    auto-retrieval function (metadata-filtered vector search) as a tool the
    agent can call.

    Args:
        service_context: LlamaIndex ServiceContext (LLM + embedding model).

    Returns:
        A FunctionTool named "semantic-film-info".
    """
    # Load in wikipedia index from the W&B artifact store.
    storage_context = wandb_callback.load_storage_context(
        artifact_url="jfreeman/llamaindex-demo-v1/wiki-index:v1"
    )
    index = load_index_from_storage(storage_context, service_context=service_context)

    def auto_retrieve_fn(
        query: str, filter_key_list: List[str], filter_value_list: List[str]
    ) -> str:
        """Auto retrieval function.

        Performs auto-retrieval from a vector database, and then applies a set of filters.
        """
        query = query or "Query"
        exact_match_filters = [
            ExactMatchFilter(key=k, value=v)
            for k, v in zip(filter_key_list, filter_value_list)
        ]
        # BUG FIX: the legacy VectorIndexRetriever keyword is `similarity_top_k`;
        # the original passed `top_k`, which raises TypeError at call time.
        retriever = VectorIndexRetriever(
            index,
            filters=MetadataFilters(filters=exact_match_filters),
            similarity_top_k=3,
        )
        query_engine = RetrieverQueryEngine.from_args(
            retriever, service_context=service_context
        )
        response = query_engine.query(query)
        return str(response)

    vector_store_info = VectorStoreInfo(
        content_info="semantic information about movies",
        metadata_info=[
            MetadataInfo(
                name="title",
                type="str",
                # NOTE(review): the third title contained a mojibake dash; restored
                # as an en dash — confirm it matches the titles stored in the index.
                description=(
                    "title of the movie, one of [John Wick (film), "
                    "John Wick: Chapter 2, John Wick: Chapter 3 \u2013 Parabellum, "
                    "John Wick: Chapter 4]"
                ),
            )
        ],
    )
    description = f"""\
Use this tool to look up semantic information about films.
The vector database schema is given below:
{vector_store_info.json()}
"""
    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="semantic-film-info",
        description=description,
        fn_schema=AutoRetrieveModel,
    )
    return auto_retrieve_tool
def create_sql_agent(service_context):
    """Build a QueryEngineTool for NL-to-SQL queries over the review tables.

    Loads wick_tables/jw1.csv .. jw4.csv into an in-memory SQLite database
    (rebuilt on every call) and wraps an NLSQLTableQueryEngine as a tool.

    Args:
        service_context: LlamaIndex ServiceContext (LLM + embedding model).

    Returns:
        A QueryEngineTool named "sql-query".
    """
    engine = create_engine("sqlite+pysqlite:///:memory:")
    table_names = [f"John Wick {i}" for i in range(1, 5)]
    for i, table in enumerate(table_names, start=1):
        fn = os.path.join("wick_tables", f"jw{i}.csv")
        df = pd.read_csv(fn)
        # NOTE(review): to_sql's default also writes the DataFrame index as a
        # column — confirm that is intended for these tables.
        df.to_sql(table, engine)

    sql_database = SQLDatabase(engine=engine, include_tables=table_names)
    sql_query_engine = NLSQLTableQueryEngine(
        sql_database=sql_database,
        tables=table_names,
        service_context=service_context,
    )
    sql_tool = QueryEngineTool.from_defaults(
        query_engine=sql_query_engine,
        name="sql-query",
        # Typos fixed ("natrual" -> "natural", "call" -> "called",
        # "forth" -> "fourth") so the agent receives a clean tool description.
        description=(
            "Useful for translating a natural language query into a SQL query over a table containing: "
            "John Wick 1, containing information related to reviews of the first John Wick movie called 'John Wick'"
            "John Wick 2, containing information related to reviews of the second John Wick movie called 'John Wick: Chapter 2'"
            "John Wick 3, containing information related to reviews of the third John Wick movie called 'John Wick: Chapter 3 - Parabellum'"
            "John Wick 4, containing information related to reviews of the fourth John Wick movie called 'John Wick: Chapter 4'"
        ),
    )
    return sql_tool
# Greeting sent to the user at the start of each chat session.
welcome_message = "Welcome to the John Wick RAQA demo! Ask me anything about the John Wick movies."
# Restored the chainlit decorator that the surviving comment ("marks a function
# that will be executed at the start of a user session") described; without it
# chainlit never calls this hook.
@cl.on_chat_start
async def start_chat():
    """Initialize the per-session agent and greet the user.

    Builds the ServiceContext (GPT-4-turbo + OpenAI embeddings), creates the
    semantic-retrieval and SQL tools, wraps them in an OpenAIAgent, and stores
    the agent in the chainlit user session for use by the message handler.
    """
    # Create the service context
    embed_model = OpenAIEmbedding()
    chunk_size = 500
    llm = OpenAI(
        temperature=0,
        model="gpt-4-1106-preview",
        streaming=True,
    )
    service_context = ServiceContext.from_defaults(
        llm=llm,
        chunk_size=chunk_size,
        embed_model=embed_model,
    )

    auto_retrieve_tool = create_semantic_agent(service_context)
    sql_tool = create_sql_agent(service_context)

    # (Removed a commented-out duplicate agent construction with the tools in
    # the opposite order — dead code.)
    agent = OpenAIAgent.from_tools(
        tools=[sql_tool, auto_retrieve_tool],
    )

    cl.user_session.set("agent", agent)
    await cl.Message(content=welcome_message).send()
# Restored the chainlit decorator that the surviving comment ("should be run
# each time the chatbot receives a message") described; without it chainlit
# never routes messages here.
@cl.on_message
async def main(message: cl.Message):
    """Forward the user's message to the session agent and send back its reply."""
    agent = cl.user_session.get("agent")
    res = await agent.achat(message.content)
    answer = str(res)
    await cl.Message(content=answer).send()