Spaces:
Sleeping
Sleeping
alejandro
commited on
Commit
·
bda4e9b
1
Parent(s):
411b037
feat: update prompt template && add groq LLM
Browse files- requirements.txt +2 -0
- src/app.py +13 -5
requirements.txt
CHANGED
@@ -4,3 +4,5 @@ langchain-community==0.0.21
|
|
4 |
langchain-core==0.1.24
|
5 |
langchain-openai==0.0.6
|
6 |
mysql-connector-python==8.3.0
|
|
|
|
|
|
4 |
langchain-core==0.1.24
|
5 |
langchain-openai==0.0.6
|
6 |
mysql-connector-python==8.3.0
|
7 |
+
groq==0.4.2
|
8 |
+
langchain-groq==0.0.1
|
src/app.py
CHANGED
@@ -3,6 +3,7 @@ from langchain_community.utilities import SQLDatabase
|
|
3 |
from langchain_core.output_parsers import StrOutputParser
|
4 |
from langchain_core.runnables import RunnablePassthrough
|
5 |
from langchain_openai import ChatOpenAI
|
|
|
6 |
from langchain_core.messages import HumanMessage, AIMessage
|
7 |
from langchain_core.prompts import ChatPromptTemplate
|
8 |
from dotenv import load_dotenv
|
@@ -16,6 +17,11 @@ def get_sql_chain(db):
|
|
16 |
Based on the table schema below, write a SQL query that would answer the user's question.
|
17 |
{schema}
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
Question: {question}
|
20 |
SQL Query:
|
21 |
"""
|
@@ -23,6 +29,7 @@ def get_sql_chain(db):
|
|
23 |
prompt = ChatPromptTemplate.from_template(template)
|
24 |
|
25 |
llm = ChatOpenAI()
|
|
|
26 |
|
27 |
def get_schema(_):
|
28 |
return db.get_table_info()
|
@@ -49,6 +56,7 @@ def get_response(user_query, chat_history, db):
|
|
49 |
|
50 |
prompt = ChatPromptTemplate.from_template(template)
|
51 |
|
|
|
52 |
llm = ChatOpenAI()
|
53 |
|
54 |
def get_schema(_):
|
@@ -85,11 +93,11 @@ with st.sidebar:
|
|
85 |
st.title("Chat with a MySQL Database")
|
86 |
st.write("This is a simple chat application allows you to chat with a MySQL database.")
|
87 |
|
88 |
-
st.text_input("Host", key="name")
|
89 |
-
st.text_input("Port", key="port")
|
90 |
-
st.text_input("Username", key="username")
|
91 |
-
st.text_input("Password", key="password")
|
92 |
-
st.text_input("Database", key="database")
|
93 |
|
94 |
if st.button("Connect"):
|
95 |
with st.spinner("Connecting to the database..."):
|
|
|
3 |
from langchain_core.output_parsers import StrOutputParser
|
4 |
from langchain_core.runnables import RunnablePassthrough
|
5 |
from langchain_openai import ChatOpenAI
|
6 |
+
from langchain_groq import ChatGroq
|
7 |
from langchain_core.messages import HumanMessage, AIMessage
|
8 |
from langchain_core.prompts import ChatPromptTemplate
|
9 |
from dotenv import load_dotenv
|
|
|
17 |
Based on the table schema below, write a SQL query that would answer the user's question.
|
18 |
{schema}
|
19 |
|
20 |
+
Write only the SQL query and nothing else. For example:
|
21 |
+
Question: which 3 artists have the most tracks?
|
22 |
+
SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
|
23 |
+
Question: Name 10 artists
|
24 |
+
SQL Query: SELECT Name FROM Artist LIMIT 10;
|
25 |
Question: {question}
|
26 |
SQL Query:
|
27 |
"""
|
|
|
29 |
prompt = ChatPromptTemplate.from_template(template)
|
30 |
|
31 |
llm = ChatOpenAI()
|
32 |
+
# llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")
|
33 |
|
34 |
def get_schema(_):
|
35 |
return db.get_table_info()
|
|
|
56 |
|
57 |
prompt = ChatPromptTemplate.from_template(template)
|
58 |
|
59 |
+
# llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")
|
60 |
llm = ChatOpenAI()
|
61 |
|
62 |
def get_schema(_):
|
|
|
93 |
st.title("Chat with a MySQL Database")
|
94 |
st.write("This is a simple chat application allows you to chat with a MySQL database.")
|
95 |
|
96 |
+
st.text_input("Host", key="name", value="localhost")
|
97 |
+
st.text_input("Port", key="port", value="3306")
|
98 |
+
st.text_input("Username", key="username", value="root")
|
99 |
+
st.text_input("Password", key="password", type="password", value="admin")
|
100 |
+
st.text_input("Database", key="database", value="Chinook")
|
101 |
|
102 |
if st.button("Connect"):
|
103 |
with st.spinner("Connecting to the database..."):
|