Spaces:
Running
Running
File size: 4,067 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from langchain.memory import ConversationBufferMemory
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.sql_database import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_experimental.sql.base import SQLDatabaseChain, SQLDatabaseSequentialChain
from tests.unit_tests.fake_llm import FakeLLM
def create_fake_db(db: SQLDatabase) -> SQLDatabase:
    """Add the ``foo`` fixture table to ``db`` and hand the same object back.

    Two statements are executed through ``db.run``: one creates table ``foo``
    with a single TEXT column ``baaz``, the other inserts the value ``'baaz'``.
    Returning ``db`` lets the call be used directly in an assignment.
    """
    setup_statements = (
        "\nCREATE TABLE foo (baaz TEXT);\n",
        "\nINSERT INTO foo (baaz)\nVALUES ('baaz');\n",
    )
    for statement in setup_statements:
        db.run(statement)
    return db


# Module-wide fake database for the SQL-chain tests: an in-memory SQLite
# instance, populated once at import time and shared by every test below.
db = create_fake_db(SQLDatabase.from_uri("sqlite:///:memory:"))
def test_sql_chain_without_memory() -> None:
    """Run SQLDatabaseChain (no memory) and compare its answer to FakeLLM's."""
    canned_sql = "SELECT baaz from foo"
    # With sequential_responses=True the fake LLM presumably hands these out
    # in insertion order, one per LLM call -- see FakeLLM to confirm.
    fake_llm = FakeLLM(
        queries={"foo": canned_sql, "foo2": canned_sql},
        sequential_responses=True,
    )
    chain = SQLDatabaseChain.from_llm(fake_llm, db, verbose=True)
    answer = chain.run("hello")
    assert answer == canned_sql
def test_sql_chain_sequential_without_memory() -> None:
    """Run SQLDatabaseSequentialChain (no memory) against canned FakeLLM replies."""
    canned_sql = "SELECT baaz from foo"
    # Three identical canned replies (keys "foo", "foo2", "foo3", in order).
    # NOTE(review): presumably one per LLM call the sequential chain makes --
    # confirm against SQLDatabaseSequentialChain if this test changes.
    replies = dict.fromkeys(("foo", "foo2", "foo3"), canned_sql)
    fake_llm = FakeLLM(queries=replies, sequential_responses=True)
    sequential_chain = SQLDatabaseSequentialChain.from_llm(
        fake_llm, db, verbose=True
    )
    assert sequential_chain.run("hello") == canned_sql
def test_sql_chain_with_memory() -> None:
    """Run SQLDatabaseChain with ConversationBufferMemory and a custom prompt.

    The prompt declares a ``history`` input variable alongside the usual SQL
    prompt variables, so the chain must accept memory-supplied context.
    """
    # Built line by line; the text is identical to a triple-quoted literal
    # that starts and ends with a newline.
    template = (
        "\n"
        "Only use the following tables:\n"
        "{table_info}\n"
        "Question: {input}\n"
        "Given an input question, first create a syntactically correct\n"
        "{dialect} query to run.\n"
        "Always limit your query to at most {top_k} results.\n"
        "Relevant pieces of previous conversation:\n"
        "{history}\n"
        "(You do not need to use these pieces of information if not relevant)\n"
    )
    prompt = PromptTemplate(
        template=template,
        input_variables=["input", "table_info", "dialect", "top_k", "history"],
    )
    canned_sql = "SELECT baaz from foo"
    fake_llm = FakeLLM(
        queries={"foo": canned_sql, "foo2": canned_sql},
        sequential_responses=True,
    )
    chain = SQLDatabaseChain.from_llm(
        fake_llm,
        db,
        memory=ConversationBufferMemory(),
        prompt=prompt,
        verbose=True,
    )
    assert chain.run("hello") == canned_sql
def test_sql_chain_sequential_with_memory() -> None:
    """Run SQLDatabaseSequentialChain with memory and custom decider/query prompts.

    The query prompt includes a ``history`` slot filled from a
    ConversationBufferMemory whose memory key is ``"history"``; the decider
    prompt's reply is parsed with CommaSeparatedListOutputParser.
    """
    # Both templates are assembled line by line; the text matches the
    # equivalent triple-quoted literals exactly.
    query_template = (
        "\n"
        "Only use the following tables:\n"
        "{table_info}\n"
        "Question: {input}\n"
        "Given an input question, first create a syntactically correct\n"
        "{dialect} query to run.\n"
        "Always limit your query to at most {top_k} results.\n"
        "Relevant pieces of previous conversation:\n"
        "{history}\n"
        "(You do not need to use these pieces of information\n"
        "if not relevant)\n"
    )
    decider_template = (
        "Given the below input question and list of potential\n"
        "tables, output a comma separated list of the\n"
        "table names that may be necessary to answer this question.\n"
        "Question: {query}\n"
        "Table Names: {table_names}\n"
        "Relevant Table Names:"
    )
    query_prompt = PromptTemplate(
        template=query_template,
        input_variables=["input", "table_info", "dialect", "top_k", "history"],
    )
    decider_prompt = PromptTemplate(
        template=decider_template,
        input_variables=["query", "table_names"],
        output_parser=CommaSeparatedListOutputParser(),
    )
    canned_sql = "SELECT baaz from foo"
    fake_llm = FakeLLM(
        queries=dict.fromkeys(("foo", "foo2", "foo3"), canned_sql),
        sequential_responses=True,
    )
    # NOTE(review): input_key="query" presumably matches the sequential
    # chain's input key so the memory records the question -- confirm.
    memory = ConversationBufferMemory(memory_key="history", input_key="query")
    chain = SQLDatabaseSequentialChain.from_llm(
        fake_llm,
        db,
        memory=memory,
        decider_prompt=decider_prompt,
        query_prompt=query_prompt,
        verbose=True,
    )
    assert chain.run("hello") == canned_sql
|