File size: 4,067 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from langchain.memory import ConversationBufferMemory
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.sql_database import SQLDatabase
from langchain_core.prompts import PromptTemplate

from langchain_experimental.sql.base import SQLDatabaseChain, SQLDatabaseSequentialChain
from tests.unit_tests.fake_llm import FakeLLM

# Fake database for the SQL-chain tests: a SQLite database opened from the
# in-memory URI, so it starts out empty and has to be populated before use.
db = SQLDatabase.from_uri("sqlite:///:memory:")


def create_fake_db(db: SQLDatabase) -> SQLDatabase:
    """Seed ``db`` with the fixture table used by the SQL-chain tests.

    Creates table ``foo`` with a single ``TEXT`` column named ``baaz`` and
    inserts one row whose value is ``'baaz'``.

    Args:
        db: Database wrapper the statements are run against.

    Returns:
        The same ``db`` object that was passed in.
    """
    create_table_sql = """
    CREATE TABLE foo (baaz TEXT);
    """
    insert_row_sql = """
    INSERT INTO foo (baaz)
    VALUES ('baaz');
    """
    # Order matters: the table has to exist before the row can be inserted.
    for statement in (create_table_sql, insert_row_sql):
        db.run(statement)
    return db


# Populate the shared database once, at import time; every test in this
# module passes this same `db` object to the chain it builds.
db = create_fake_db(db)


def test_sql_chain_without_memory() -> None:
    """Run a ``SQLDatabaseChain`` without memory against the shared fake db."""
    expected_sql = "SELECT baaz from foo"
    # Every canned response of the fake LLM is the same SQL statement.
    canned_responses = dict.fromkeys(("foo", "foo2"), expected_sql)
    fake_llm = FakeLLM(queries=canned_responses, sequential_responses=True)
    chain = SQLDatabaseChain.from_llm(fake_llm, db, verbose=True)
    answer = chain.run("hello")
    assert answer == expected_sql


def test_sql_chain_sequential_without_memory() -> None:
    """Run a ``SQLDatabaseSequentialChain`` without memory on the shared db."""
    expected_sql = "SELECT baaz from foo"
    # Three canned responses, all carrying the same SQL statement.
    canned_responses = dict.fromkeys(("foo", "foo2", "foo3"), expected_sql)
    fake_llm = FakeLLM(queries=canned_responses, sequential_responses=True)
    chain = SQLDatabaseSequentialChain.from_llm(fake_llm, db, verbose=True)
    answer = chain.run("hello")
    assert answer == expected_sql


def test_sql_chain_with_memory() -> None:
    """Run ``SQLDatabaseChain`` with a buffer memory and a history-aware prompt."""
    expected_sql = "SELECT baaz from foo"
    # Custom prompt that exposes a {history} slot next to the usual SQL inputs.
    template_with_history = """
    Only use the following tables:
    {table_info}
    Question: {input}

    Given an input question, first create a syntactically correct
    {dialect} query to run.
    Always limit your query to at most {top_k} results.

    Relevant pieces of previous conversation:
    {history}

    (You do not need to use these pieces of information if not relevant)
    """
    history_prompt = PromptTemplate(
        template=template_with_history,
        input_variables=["input", "table_info", "dialect", "top_k", "history"],
    )
    fake_llm = FakeLLM(
        queries={"foo": expected_sql, "foo2": expected_sql},
        sequential_responses=True,
    )
    chain = SQLDatabaseChain.from_llm(
        fake_llm,
        db,
        memory=ConversationBufferMemory(),
        prompt=history_prompt,
        verbose=True,
    )
    answer = chain.run("hello")
    assert answer == expected_sql


def test_sql_chain_sequential_with_memory() -> None:
    """Run ``SQLDatabaseSequentialChain`` with memory and custom prompts.

    Supplies both a decider prompt (with a comma-separated-list output parser)
    and a query prompt that includes a ``{history}`` slot.
    """
    expected_sql = "SELECT baaz from foo"

    # Prompt used for generating the SQL query; includes conversation history.
    query_template = """
    Only use the following tables:
    {table_info}
    Question: {input}

    Given an input question, first create a syntactically correct
    {dialect} query to run.
    Always limit your query to at most {top_k} results.

    Relevant pieces of previous conversation:
    {history}

    (You do not need to use these pieces of information
    if not relevant)
    """
    # Prompt used for picking the relevant table names.
    decider_template = """Given the below input question and list of potential
    tables, output a comma separated list of the
    table names that may be necessary to answer this question.

    Question: {query}

    Table Names: {table_names}

    Relevant Table Names:"""

    query_prompt = PromptTemplate(
        template=query_template,
        input_variables=["input", "table_info", "dialect", "top_k", "history"],
    )
    decider_prompt = PromptTemplate(
        template=decider_template,
        input_variables=["query", "table_names"],
        output_parser=CommaSeparatedListOutputParser(),
    )

    canned_responses = dict.fromkeys(("foo", "foo2", "foo3"), expected_sql)
    fake_llm = FakeLLM(queries=canned_responses, sequential_responses=True)
    # The memory key matches the {history} slot of the query prompt.
    chat_memory = ConversationBufferMemory(memory_key="history", input_key="query")
    chain = SQLDatabaseSequentialChain.from_llm(
        fake_llm,
        db,
        memory=chat_memory,
        decider_prompt=decider_prompt,
        query_prompt=query_prompt,
        verbose=True,
    )
    answer = chain.run("hello")
    assert answer == expected_sql