alejandro commited on
Commit
5bea3fb
·
1 Parent(s): bda4e9b

finish tutorial

Browse files
Files changed (1) hide show
  1. src/app.py +89 -90
src/app.py CHANGED
@@ -1,140 +1,139 @@
1
- import streamlit as st
 
 
 
2
  from langchain_community.utilities import SQLDatabase
3
  from langchain_core.output_parsers import StrOutputParser
4
- from langchain_core.runnables import RunnablePassthrough
5
  from langchain_openai import ChatOpenAI
6
  from langchain_groq import ChatGroq
7
- from langchain_core.messages import HumanMessage, AIMessage
8
- from langchain_core.prompts import ChatPromptTemplate
9
- from dotenv import load_dotenv
10
 
11
- def initialize_database(host, port, username, password, database):
12
- db_uri = f"mysql+mysqlconnector://{username}:{password}@{host}:{port}/{database}"
13
  return SQLDatabase.from_uri(db_uri)
14
 
15
  def get_sql_chain(db):
16
- template = """
17
- Based on the table schema below, write a SQL query that would answer the user's question.
18
- {schema}
19
-
20
- Write only the SQL query and nothing else. For example:
 
 
 
 
 
 
21
  Question: which 3 artists have the most tracks?
22
  SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
23
  Question: Name 10 artists
24
  SQL Query: SELECT Name FROM Artist LIMIT 10;
 
 
 
25
  Question: {question}
26
  SQL Query:
27
  """
28
-
29
- prompt = ChatPromptTemplate.from_template(template)
30
-
31
- llm = ChatOpenAI()
32
- # llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")
33
 
34
- def get_schema(_):
35
- return db.get_table_info()
36
-
37
- return (
38
- RunnablePassthrough.assign(schema=get_schema)
39
- | prompt
40
- | llm.bind(stop="\nSQL Result:")
41
- | StrOutputParser()
42
- )
43
 
44
- def get_response(user_query, chat_history, db):
45
-
 
 
 
 
 
 
 
 
 
 
 
 
46
  sql_chain = get_sql_chain(db)
47
 
48
  template = """
49
- Based on the table schema below, question, sql query, and sql response, write a natural language response:
50
- {schema}
51
-
52
- Conversation History: {chat_history}
53
- Question: {question}
54
- SQL Query: {query}
55
- SQL Response: {response}"""
56
-
 
57
  prompt = ChatPromptTemplate.from_template(template)
58
 
59
- # llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")
60
- llm = ChatOpenAI()
61
-
62
- def get_schema(_):
63
- return db.get_table_info()
64
 
65
  chain = (
66
  RunnablePassthrough.assign(query=sql_chain).assign(
67
- schema=get_schema,
68
- response= lambda vars: db.run(vars["query"])
69
- )
70
- | prompt
71
- | llm
72
- | StrOutputParser()
73
  )
74
 
75
- return chain.stream({
76
  "question": user_query,
77
  "chat_history": chat_history,
78
  })
 
79
 
 
 
 
 
 
80
  load_dotenv()
81
 
82
- st.set_page_config(initial_sidebar_state="expanded", page_title="Chat with a MySQL Database", page_icon=":speech_balloon:")
83
 
84
- if 'chat_history' not in st.session_state:
85
- st.session_state.chat_history = [
86
- AIMessage(content="Hello! I'm a chatbot that can help you with your SQL queries. Ask me anything about your database!")
87
- ]
88
-
89
- if 'db' not in st.session_state:
90
- st.session_state.db = None
91
 
92
  with st.sidebar:
93
- st.title("Chat with a MySQL Database")
94
- st.write("This is a simple chat application allows you to chat with a MySQL database.")
95
-
96
- st.text_input("Host", key="name", value="localhost")
97
- st.text_input("Port", key="port", value="3306")
98
- st.text_input("Username", key="username", value="root")
99
- st.text_input("Password", key="password", type="password", value="admin")
100
- st.text_input("Database", key="database", value="Chinook")
101
 
102
  if st.button("Connect"):
103
- with st.spinner("Connecting to the database..."):
104
- st.session_state.db = initialize_database(
105
- username=st.session_state.username,
106
- password=st.session_state.password,
107
- host=st.session_state.name,
108
- port=st.session_state.port,
109
- database=st.session_state.database
110
- )
111
- if st.session_state.db is not None:
112
- st.success("Connected to the database!")
113
 
114
- user_query = st.chat_input("Type a message...")
115
-
116
- # conversation
117
  for message in st.session_state.chat_history:
118
  if isinstance(message, AIMessage):
119
  with st.chat_message("AI"):
120
- st.write(message.content)
121
  elif isinstance(message, HumanMessage):
122
  with st.chat_message("Human"):
123
- st.write(message.content)
124
-
125
 
126
- if user_query is not None and user_query != "":
 
127
  st.session_state.chat_history.append(HumanMessage(content=user_query))
128
-
129
  with st.chat_message("Human"):
130
  st.markdown(user_query)
131
-
132
  with st.chat_message("AI"):
133
- response = st.write_stream(get_response(
134
- user_query,
135
- st.session_state.chat_history,
136
- st.session_state.db
137
- ))
138
 
139
- st.session_state.chat_history.append(AIMessage(content=response))
140
-
 
1
+ from dotenv import load_dotenv
2
+ from langchain_core.messages import AIMessage, HumanMessage
3
+ from langchain_core.prompts import ChatPromptTemplate
4
+ from langchain_core.runnables import RunnablePassthrough
5
  from langchain_community.utilities import SQLDatabase
6
  from langchain_core.output_parsers import StrOutputParser
 
7
  from langchain_openai import ChatOpenAI
8
  from langchain_groq import ChatGroq
9
+ import streamlit as st
 
 
10
 
11
+ def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
12
+ db_uri = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}"
13
  return SQLDatabase.from_uri(db_uri)
14
 
15
  def get_sql_chain(db):
16
+ template = """
17
+ You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
18
+ Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
19
+
20
+ <SCHEMA>{schema}</SCHEMA>
21
+
22
+ Conversation History: {chat_history}
23
+
24
+ Write only the SQL query and nothing else. Do not wrap the SQL query in any other text, not even backticks.
25
+
26
+ For example:
27
  Question: which 3 artists have the most tracks?
28
  SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
29
  Question: Name 10 artists
30
  SQL Query: SELECT Name FROM Artist LIMIT 10;
31
+
32
+ Your turn:
33
+
34
  Question: {question}
35
  SQL Query:
36
  """
 
 
 
 
 
37
 
38
+ prompt = ChatPromptTemplate.from_template(template)
 
 
 
 
 
 
 
 
39
 
40
+ # llm = ChatOpenAI(model="gpt-4-0125-preview")
41
+ llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
42
+
43
+ def get_schema(_):
44
+ return db.get_table_info()
45
+
46
+ return (
47
+ RunnablePassthrough.assign(schema=get_schema)
48
+ | prompt
49
+ | llm
50
+ | StrOutputParser()
51
+ )
52
+
53
+ def get_response(user_query: str, db: SQLDatabase, chat_history: list):
54
  sql_chain = get_sql_chain(db)
55
 
56
  template = """
57
+ You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
58
+ Based on the table schema below, question, sql query, and sql response, write a natural language response.
59
+ <SCHEMA>{schema}</SCHEMA>
60
+
61
+ Conversation History: {chat_history}
62
+ SQL Query: <SQL>{query}</SQL>
63
+ User question: {question}
64
+ SQL Response: {response}"""
65
+
66
  prompt = ChatPromptTemplate.from_template(template)
67
 
68
+ # llm = ChatOpenAI(model="gpt-4-0125-preview")
69
+ llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
 
 
 
70
 
71
  chain = (
72
  RunnablePassthrough.assign(query=sql_chain).assign(
73
+ schema=lambda _: db.get_table_info(),
74
+ response=lambda vars: db.run(vars["query"]),
75
+ )
76
+ | prompt
77
+ | llm
78
+ | StrOutputParser()
79
  )
80
 
81
+ return chain.invoke({
82
  "question": user_query,
83
  "chat_history": chat_history,
84
  })
85
+
86
 
87
+ if "chat_history" not in st.session_state:
88
+ st.session_state.chat_history = [
89
+ AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
90
+ ]
91
+
92
  load_dotenv()
93
 
94
+ st.set_page_config(page_title="Chat with MySQL", page_icon=":speech_balloon:")
95
 
96
+ st.title("Chat with MySQL")
 
 
 
 
 
 
97
 
98
  with st.sidebar:
99
+ st.subheader("Settings")
100
+ st.write("This is a simple chat application using MySQL. Connect to the database and start chatting.")
101
+
102
+ st.text_input("Host", value="localhost", key="Host")
103
+ st.text_input("Port", value="3306", key="Port")
104
+ st.text_input("User", value="root", key="User")
105
+ st.text_input("Password", type="password", value="admin", key="Password")
106
+ st.text_input("Database", value="Chinook", key="Database")
107
 
108
  if st.button("Connect"):
109
+ with st.spinner("Connecting to database..."):
110
+ db = init_database(
111
+ st.session_state["User"],
112
+ st.session_state["Password"],
113
+ st.session_state["Host"],
114
+ st.session_state["Port"],
115
+ st.session_state["Database"]
116
+ )
117
+ st.session_state.db = db
118
+ st.success("Connected to database!")
119
 
 
 
 
120
  for message in st.session_state.chat_history:
121
  if isinstance(message, AIMessage):
122
  with st.chat_message("AI"):
123
+ st.markdown(message.content)
124
  elif isinstance(message, HumanMessage):
125
  with st.chat_message("Human"):
126
+ st.markdown(message.content)
 
127
 
128
+ user_query = st.chat_input("Type a message...")
129
+ if user_query is not None and user_query.strip() != "":
130
  st.session_state.chat_history.append(HumanMessage(content=user_query))
131
+
132
  with st.chat_message("Human"):
133
  st.markdown(user_query)
134
+
135
  with st.chat_message("AI"):
136
+ response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
137
+ st.markdown(response)
 
 
 
138
 
139
+ st.session_state.chat_history.append(AIMessage(content=response))