alejandro commited on
Commit
efb8ba7
·
1 Parent(s): c796379

refactor: abstract out creation of mysql chain

Browse files
Files changed (1) hide show
  1. src/app.py +23 -20
src/app.py CHANGED
@@ -4,37 +4,40 @@ from langchain_core.output_parsers import StrOutputParser
4
  from langchain_core.runnables import RunnablePassthrough
5
  from langchain_openai import ChatOpenAI
6
  from langchain_core.messages import HumanMessage, AIMessage
 
7
  from dotenv import load_dotenv
8
 
9
  def initialize_database(host, port, username, password, database):
10
  db_uri = f"mysql+mysqlconnector://{username}:{password}@{host}:{port}/{database}"
11
  return SQLDatabase.from_uri(db_uri)
12
 
13
- def get_response(user_query, chat_history, db):
14
-
15
- from langchain_core.prompts import ChatPromptTemplate
 
16
 
17
- template = """
18
- Based on the table schema below, write a SQL query that would answer the user's question.
19
- {schema}
20
 
21
- Question: {question}
22
- SQL Query:
23
- """
 
 
 
24
 
25
- prompt = ChatPromptTemplate.from_template(template)
 
 
 
 
 
26
 
27
- llm = ChatOpenAI()
28
-
29
- def get_schema(_):
30
- return db.get_table_info()
31
 
32
- sql_chain = (
33
- RunnablePassthrough.assign(schema=get_schema)
34
- | prompt
35
- | llm.bind(stop="\nSQL Result:")
36
- | StrOutputParser()
37
- )
38
 
39
  return sql_chain.invoke({
40
  "question": user_query
 
4
  from langchain_core.runnables import RunnablePassthrough
5
  from langchain_openai import ChatOpenAI
6
  from langchain_core.messages import HumanMessage, AIMessage
7
+ from langchain_core.prompts import ChatPromptTemplate
8
  from dotenv import load_dotenv
9
 
10
  def initialize_database(host, port, username, password, database):
11
  db_uri = f"mysql+mysqlconnector://{username}:{password}@{host}:{port}/{database}"
12
  return SQLDatabase.from_uri(db_uri)
13
 
14
+ def get_sql_chain(db):
15
+ template = """
16
+ Based on the table schema below, write a SQL query that would answer the user's question.
17
+ {schema}
18
 
19
+ Question: {question}
20
+ SQL Query:
21
+ """
22
 
23
+ prompt = ChatPromptTemplate.from_template(template)
24
+
25
+ llm = ChatOpenAI()
26
+
27
+ def get_schema(_):
28
+ return db.get_table_info()
29
 
30
+ return (
31
+ RunnablePassthrough.assign(schema=get_schema)
32
+ | prompt
33
+ | llm.bind(stop="\nSQL Result:")
34
+ | StrOutputParser()
35
+ )
36
 
 
 
 
 
37
 
38
+ def get_response(user_query, chat_history, db):
39
+
40
+ sql_chain = get_sql_chain(db)
 
 
 
41
 
42
  return sql_chain.invoke({
43
  "question": user_query