alejandro commited on
Commit
bda4e9b
·
1 Parent(s): 411b037

feat: update prompt template && add groq LLM

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -0
  2. src/app.py +13 -5
requirements.txt CHANGED
@@ -4,3 +4,5 @@ langchain-community==0.0.21
4
  langchain-core==0.1.24
5
  langchain-openai==0.0.6
6
  mysql-connector-python==8.3.0
 
 
 
4
  langchain-core==0.1.24
5
  langchain-openai==0.0.6
6
  mysql-connector-python==8.3.0
7
+ groq==0.4.2
8
+ langchain-groq==0.0.1
src/app.py CHANGED
@@ -3,6 +3,7 @@ from langchain_community.utilities import SQLDatabase
3
  from langchain_core.output_parsers import StrOutputParser
4
  from langchain_core.runnables import RunnablePassthrough
5
  from langchain_openai import ChatOpenAI
 
6
  from langchain_core.messages import HumanMessage, AIMessage
7
  from langchain_core.prompts import ChatPromptTemplate
8
  from dotenv import load_dotenv
@@ -16,6 +17,11 @@ def get_sql_chain(db):
16
  Based on the table schema below, write a SQL query that would answer the user's question.
17
  {schema}
18
 
 
 
 
 
 
19
  Question: {question}
20
  SQL Query:
21
  """
@@ -23,6 +29,7 @@ def get_sql_chain(db):
23
  prompt = ChatPromptTemplate.from_template(template)
24
 
25
  llm = ChatOpenAI()
 
26
 
27
  def get_schema(_):
28
  return db.get_table_info()
@@ -49,6 +56,7 @@ def get_response(user_query, chat_history, db):
49
 
50
  prompt = ChatPromptTemplate.from_template(template)
51
 
 
52
  llm = ChatOpenAI()
53
 
54
  def get_schema(_):
@@ -85,11 +93,11 @@ with st.sidebar:
85
  st.title("Chat with a MySQL Database")
86
  st.write("This is a simple chat application allows you to chat with a MySQL database.")
87
 
88
- st.text_input("Host", key="name")
89
- st.text_input("Port", key="port")
90
- st.text_input("Username", key="username")
91
- st.text_input("Password", key="password")
92
- st.text_input("Database", key="database")
93
 
94
  if st.button("Connect"):
95
  with st.spinner("Connecting to the database..."):
 
3
  from langchain_core.output_parsers import StrOutputParser
4
  from langchain_core.runnables import RunnablePassthrough
5
  from langchain_openai import ChatOpenAI
6
+ from langchain_groq import ChatGroq
7
  from langchain_core.messages import HumanMessage, AIMessage
8
  from langchain_core.prompts import ChatPromptTemplate
9
  from dotenv import load_dotenv
 
17
  Based on the table schema below, write a SQL query that would answer the user's question.
18
  {schema}
19
 
20
+ Write only the SQL query and nothing else. For example:
21
+ Question: which 3 artists have the most tracks?
22
+ SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
23
+ Question: Name 10 artists
24
+ SQL Query: SELECT Name FROM Artist LIMIT 10;
25
  Question: {question}
26
  SQL Query:
27
  """
 
29
  prompt = ChatPromptTemplate.from_template(template)
30
 
31
  llm = ChatOpenAI()
32
+ # llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")
33
 
34
  def get_schema(_):
35
  return db.get_table_info()
 
56
 
57
  prompt = ChatPromptTemplate.from_template(template)
58
 
59
+ # llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")
60
  llm = ChatOpenAI()
61
 
62
  def get_schema(_):
 
93
  st.title("Chat with a MySQL Database")
94
  st.write("This is a simple chat application allows you to chat with a MySQL database.")
95
 
96
+ st.text_input("Host", key="name", value="localhost")
97
+ st.text_input("Port", key="port", value="3306")
98
+ st.text_input("Username", key="username", value="root")
99
+ st.text_input("Password", key="password", type="password", value="admin")
100
+ st.text_input("Database", key="database", value="Chinook")
101
 
102
  if st.button("Connect"):
103
  with st.spinner("Connecting to the database..."):