fahmiaziz committed on
Commit
af733da
·
verified ·
1 Parent(s): 4180c11

Upload 9 files

Browse files
Files changed (9) hide show
  1. .gitattributes +1 -0
  2. Chinook_Sqlite.sqlite +3 -0
  3. app.py +60 -2
  4. constant.py +10 -0
  5. lib.py +102 -0
  6. music-database-schema.json +757 -0
  7. prompt.py +14 -0
  8. requirements.txt +8 -0
  9. utils.py +22 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Chinook_Sqlite.sqlite filter=lfs diff=lfs merge=lfs -text
Chinook_Sqlite.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bdf635be69850bd3be09c9a2dbeef7ddfb80036bd3ef3381383cd03b61e4a61a
3
+ size 1067008
app.py CHANGED
@@ -1,5 +1,63 @@
1
  import streamlit as st
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from lib import Text2SQLRAG
3
+ from utils import execute_query_and_return_df
4
 
 
 
5
 
6
st.set_page_config(page_title="Text2SQLRAG")
st.title("Text2SQLRAG")


@st.cache_resource
def _get_text2sql():
    """Build the Text2SQLRAG pipeline once and reuse it across script reruns."""
    return Text2SQLRAG()


# Streamlit re-executes this script on every interaction; the cached factory
# avoids constructing a new pipeline each time.
text2sql = _get_text2sql()

# Initialize session state for storing chat messages
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display conversation history from session state
for message in st.session_state.messages:
    role = message.get("role", "assistant")
    with st.chat_message(role):
        # User turns carry "text"; assistant turns carry "reasoning" and
        # "sql_query". Render whichever parts this message actually has.
        if message.get("text"):
            st.markdown(message["text"])
        if message.get("reasoning"):
            with st.expander("Reasoning", expanded=True):
                st.markdown(message["reasoning"])
        if message.get("sql_query"):
            with st.expander("SQL Query", expanded=True):
                st.code(message["sql_query"])

# Get user input
input_text = st.chat_input("Chat with your bot here...")

if input_text:
    # Display user input
    with st.chat_message("user"):
        st.markdown(input_text)

    # Add user input to chat history (same list the replay loop reads from)
    st.session_state.messages.append({"role": "user", "text": input_text})

    # Get chatbot response
    response = text2sql.run(input_text)
    sql_query = response.query
    reasoning = response.reasoning
    # Only hit the database when the model actually produced a query.
    df = execute_query_and_return_df(sql_query) if sql_query else None

    with st.chat_message("assistant"):
        if sql_query:
            with st.expander("SQL Query", expanded=True):
                st.code(sql_query)

        with st.expander("Reasoning", expanded=True):
            st.write(reasoning)

        if df is not None:
            st.dataframe(df)
        else:
            st.error("Error executing query")

    # Append assistant response to session state
    st.session_state.messages.append(
        {
            "role": "assistant",
            "reasoning": reasoning,
            "sql_query": sql_query,
        }
    )
constant.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv, find_dotenv
3
+ load_dotenv(find_dotenv())
4
+
5
+
6
# API key for the Gemini model, read from the environment (None when unset).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
# Embedding model name and chat model name used by the pipeline.
EMBEDDING_MODEL = "intfloat/multilingual-e5-small"
GEMINI_MODEL = "gemini-2.0-flash"
# Location of the JSON schema description.
PATH_SCHEMA = "/data/music-database-schema.json"
# File name corrected from "Cinhook" to "Chinook" so it matches the database
# file shipped with the project (Chinook_Sqlite.sqlite).
PATH_DB = "/data/Chinook_Sqlite.sqlite"
lib.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pydantic import BaseModel, Field
3
+ from langchain_google_genai import ChatGoogleGenerativeAI
4
+ from langchain_core.runnables import RunnablePassthrough
5
+ from langchain_chroma import Chroma
6
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
7
+ from langchain_community.document_loaders import JSONLoader
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+
10
+ from prompt import prompt
11
+ from utils import execute_query_and_return_df
12
+ from constant import (
13
+ GEMINI_MODEL,
14
+ GOOGLE_API_KEY,
15
+ PATH_SCHEMA,
16
+ PATH_DB,
17
+ EMBEDDING_MODEL
18
+ )
19
+
20
+
21
+
22
class SQLOutput(BaseModel):
    """Structured answer schema: a SQL statement plus the reasoning behind it."""

    # SQL text intended to be executed against the database.
    query: str = Field(description="The SQL query to run.")
    # Explanation accompanying the generated query.
    reasoning: str = Field(description="Reasoning to understand the SQL query.")
25
+
26
class Text2SQLRAG:
    """Generate SQL queries from natural-language questions.

    The schema JSON is indexed in a vector store; for each question the most
    relevant schema chunks are retrieved and passed, with the question, through
    the prompt to a Gemini chat model configured for structured output.
    """

    def __init__(self,
                 path_schema: str = PATH_SCHEMA,
                 path_db: str = PATH_DB,
                 model: str = GEMINI_MODEL,
                 api_key: str = GOOGLE_API_KEY,
                 embedding_model: str = EMBEDDING_MODEL
                 ):
        """
        Set up the embeddings, chat model, schema retriever and RAG chain.

        Args:
            path_schema: Path of the JSON file describing the database schema.
            path_db: Path of the SQLite database; stored on the instance only.
            model: Name of the Gemini chat model.
            api_key: API key passed to the chat model.
            embedding_model: Name of the HuggingFace embedding model.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.info('Initializing Text2SQLRAG')

        # NOTE(review): "max_tokens" and "stop_sequences" look like another
        # provider's parameter names - confirm ChatGoogleGenerativeAI honors
        # them when supplied through model_kwargs.
        model_kwargs = {
            "max_tokens": 512,
            "temperature": 0.2,
            "top_k": 250,
            "top_p": 1,
            "stop_sequences": ["\n\nHuman:"]
        }
        self.schema = path_schema
        self.db = path_db
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

        self.model = ChatGoogleGenerativeAI(
            model=model,
            api_key=api_key,
            model_kwargs=model_kwargs
        )
        self.llm = self.model.with_structured_output(SQLOutput)
        self.retriever = self._indexing_vectore()

        # The chain depends only on instance state, so build it once here
        # rather than on every run() call.
        self._rag_chain = (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | prompt
            | self.llm
        )

    def _indexing_vectore(self):
        """
        Index the database schema in a vector store and return a retriever.

        Loads the schema JSON, splits it into chunks, embeds the chunks with
        ``self.embeddings`` and stores them in a Chroma vector store.

        Returns:
            retriever: A retriever that returns the top 2 matching schema
                chunks for a query.
        """
        self.logger.info('Indexing schema')
        db_schema_loader = JSONLoader(
            file_path=self.schema,
            jq_schema='.',
            text_content=False
        )
        # Split on the literal text "separator", which is the key each table
        # entry in the schema JSON starts with.
        text_splitter = RecursiveCharacterTextSplitter(
            separators=["separator"],
            chunk_size=10000,
            chunk_overlap=100
        )

        docs = text_splitter.split_documents(db_schema_loader.load())

        vectorstore = Chroma.from_documents(documents=docs,
                                            embedding=self.embeddings)

        retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
        self.logger.info('Finished indexing schema')
        return retriever

    def run(self, question: str):
        """
        Generate a SQL query for a natural-language question.

        Args:
            question: The user's question.

        Returns:
            The structured model output; presumably a ``SQLOutput`` exposing
            ``query`` and ``reasoning`` - verify against the installed
            langchain version.
        """
        # Lazy %-style arguments: the message is only formatted if INFO is on.
        self.logger.info('Running Text2SQLRAG for question: %s', question)
        return self._rag_chain.invoke(question)
music-database-schema.json ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tables": [
3
+ {
4
+ "separator": "table_1",
5
+ "name": "Album",
6
+ "schema": "CREATE TABLE Album (AlbumId integer PRIMARY KEY, Title character varying(160), ArtistId integer);",
7
+ "description": "This table stores information about music albums.",
8
+ "columns": [
9
+ {
10
+ "name": "AlbumId",
11
+ "description": "unique identifier for albums.",
12
+ "synonyms": [
13
+ "album id"
14
+ ]
15
+ },
16
+ {
17
+ "name": "Title",
18
+ "description": "title of the album",
19
+ "synonyms": [
20
+ "album title",
21
+ "album name"
22
+ ]
23
+ },
24
+ {
25
+ "name": "ArtistId",
26
+ "description": "Id of the artist associated with the album",
27
+ "synonyms": [
28
+ "artist id"
29
+ ]
30
+ }
31
+ ],
32
+ "sample_queries": [
33
+ {
34
+ "query": "SELECT a.Title FROM Album a JOIN Artist ar ON a.ArtistId = ar.ArtistId WHERE ar.Name = 'AC/DC'",
35
+ "user_input": "Get all albums by AC/DC"
36
+ },
37
+ {
38
+ "query": "SELECT COUNT(*) FROM Album WHERE ArtistId = 1",
39
+ "user_input": "Count how many albums AC/DC has"
40
+ },
41
+ {
42
+ "query": "SELECT a.Title, ar.Name FROM Album a JOIN Artist ar ON a.ArtistId = ar.ArtistId ORDER BY a.Title ASC",
43
+ "user_input": "List all albums alphabetically with their artists"
44
+ },
45
+ {
46
+ "query": "SELECT a.Title FROM Album a JOIN Track t ON a.AlbumId = t.AlbumId GROUP BY a.AlbumId ORDER BY COUNT(t.TrackId) DESC LIMIT 1",
47
+ "user_input": "Find the album with the most tracks"
48
+ }
49
+ ]
50
+ },
51
+ {
52
+ "separator": "table_2",
53
+ "name": "Artist",
54
+ "schema": "CREATE TABLE Artist (ArtistId integer PRIMARY KEY, Name character varying(120));",
55
+ "description": "This table stores information about music artists.",
56
+ "columns": [
57
+ {
58
+ "name": "ArtistId",
59
+ "description": "unique identifier for artists.",
60
+ "synonyms": [
61
+ "artist id"
62
+ ]
63
+ },
64
+ {
65
+ "name": "Name",
66
+ "description": "name of the artist",
67
+ "synonyms": [
68
+ "artist name"
69
+ ]
70
+ }
71
+ ],
72
+ "sample_queries": [
73
+ {
74
+ "query": "SELECT Name FROM Artist WHERE ArtistId = 1",
75
+ "user_input": "Get artist name with id 1"
76
+ },
77
+ {
78
+ "query": "SELECT ar.Name, COUNT(a.AlbumId) as AlbumCount FROM Artist ar LEFT JOIN Album a ON ar.ArtistId = a.ArtistId GROUP BY ar.ArtistId ORDER BY AlbumCount DESC",
79
+ "user_input": "List artists by number of albums they have"
80
+ },
81
+ {
82
+ "query": "SELECT ar.Name FROM Artist ar JOIN Album a ON ar.ArtistId = a.ArtistId JOIN Track t ON a.AlbumId = t.AlbumId GROUP BY ar.ArtistId ORDER BY SUM(t.Milliseconds) DESC LIMIT 5",
83
+ "user_input": "Find top 5 artists with the longest total music duration"
84
+ },
85
+ {
86
+ "query": "SELECT ar.Name FROM Artist ar JOIN Album a ON ar.ArtistId = a.ArtistId JOIN Track t ON a.AlbumId = t.AlbumId JOIN InvoiceLine il ON t.TrackId = il.TrackId GROUP BY ar.ArtistId ORDER BY SUM(il.UnitPrice * il.Quantity) DESC LIMIT 3",
87
+ "user_input": "Find the top 3 best-selling artists by revenue"
88
+ }
89
+ ]
90
+ },
91
+ {
92
+ "separator": "table_3",
93
+ "name": "Customer",
94
+ "schema": "CREATE TABLE Customer (CustomerId integer PRIMARY KEY, FirstName character varying(40), LastName character varying(20), Company character varying(80), Address character varying(70), City character varying(40), State character varying(40), Country character varying(40), PostalCode character varying(10), Phone character varying(24), Fax character varying(24), Email character varying(60), SupportRepId integer);",
95
+ "description": "This table stores information about customers who purchase music.",
96
+ "columns": [
97
+ {
98
+ "name": "CustomerId",
99
+ "description": "unique identifier for customers",
100
+ "synonyms": [
101
+ "customer id"
102
+ ]
103
+ },
104
+ {
105
+ "name": "FirstName",
106
+ "description": "first name of the customer",
107
+ "synonyms": [
108
+ "first name"
109
+ ]
110
+ },
111
+ {
112
+ "name": "LastName",
113
+ "description": "last name of the customer",
114
+ "synonyms": [
115
+ "last name",
116
+ "surname"
117
+ ]
118
+ },
119
+ {
120
+ "name": "Company",
121
+ "description": "company name of the customer if applicable",
122
+ "synonyms": [
123
+ "company name",
124
+ "organization"
125
+ ]
126
+ },
127
+ {
128
+ "name": "Address",
129
+ "description": "street address of the customer",
130
+ "synonyms": [
131
+ "street address"
132
+ ]
133
+ },
134
+ {
135
+ "name": "City",
136
+ "description": "city where the customer lives",
137
+ "synonyms": [
138
+ "customer city"
139
+ ]
140
+ },
141
+ {
142
+ "name": "State",
143
+ "description": "state or province where customer lives",
144
+ "synonyms": [
145
+ "province"
146
+ ]
147
+ },
148
+ {
149
+ "name": "Country",
150
+ "description": "country where customer lives",
151
+ "synonyms": [
152
+ "nation"
153
+ ]
154
+ },
155
+ {
156
+ "name": "PostalCode",
157
+ "description": "postal or zip code of customer's address",
158
+ "synonyms": [
159
+ "zip code"
160
+ ]
161
+ },
162
+ {
163
+ "name": "Phone",
164
+ "description": "phone number of the customer",
165
+ "synonyms": [
166
+ "telephone"
167
+ ]
168
+ },
169
+ {
170
+ "name": "Fax",
171
+ "description": "fax number of the customer if available",
172
+ "synonyms": [
173
+ "facsimile"
174
+ ]
175
+ },
176
+ {
177
+ "name": "Email",
178
+ "description": "email address of the customer",
179
+ "synonyms": [
180
+ "email address"
181
+ ]
182
+ },
183
+ {
184
+ "name": "SupportRepId",
185
+ "description": "Id of employee assigned as the customer's support representative",
186
+ "synonyms": [
187
+ "support rep id",
188
+ "employee id"
189
+ ]
190
+ }
191
+ ],
192
+ "sample_queries": [
193
+ {
194
+ "query": "SELECT FirstName, LastName, Email FROM Customer WHERE Country = 'Brazil'",
195
+ "user_input": "Get all Brazilian customers with their emails"
196
+ },
197
+ {
198
+ "query": "SELECT Country, COUNT(*) as CustomerCount FROM Customer GROUP BY Country ORDER BY CustomerCount DESC",
199
+ "user_input": "List countries by number of customers"
200
+ },
201
+ {
202
+ "query": "SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY TotalSpent DESC LIMIT 5",
203
+ "user_input": "Find top 5 customers by total spending"
204
+ },
205
+ {
206
+ "query": "SELECT c.Email FROM Customer c LEFT JOIN Invoice i ON c.CustomerId = i.CustomerId WHERE i.InvoiceId IS NULL",
207
+ "user_input": "Find customers who haven't made any purchases"
208
+ }
209
+ ]
210
+ },
211
+ {
212
+ "separator": "table_4",
213
+ "name": "Employee",
214
+ "schema": "CREATE TABLE Employee (EmployeeId integer PRIMARY KEY, LastName character varying(20), FirstName character varying(20), Title character varying(30), ReportsTo integer, BirthDate timestamp without time zone, HireDate timestamp without time zone, Address character varying(70), City character varying(40), State character varying(40), Country character varying(40), PostalCode character varying(10), Phone character varying(24), Fax character varying(24), Email character varying(60));",
215
+ "description": "This table stores information about employees who work for the music store.",
216
+ "columns": [
217
+ {
218
+ "name": "EmployeeId",
219
+ "description": "unique identifier for employees",
220
+ "synonyms": [
221
+ "employee id"
222
+ ]
223
+ },
224
+ {
225
+ "name": "LastName",
226
+ "description": "last name of the employee",
227
+ "synonyms": [
228
+ "surname"
229
+ ]
230
+ },
231
+ {
232
+ "name": "FirstName",
233
+ "description": "first name of the employee",
234
+ "synonyms": [
235
+ "given name"
236
+ ]
237
+ },
238
+ {
239
+ "name": "Title",
240
+ "description": "job title of the employee",
241
+ "synonyms": [
242
+ "position",
243
+ "job title"
244
+ ]
245
+ },
246
+ {
247
+ "name": "ReportsTo",
248
+ "description": "Id of the employee's manager",
249
+ "synonyms": [
250
+ "manager id",
251
+ "supervisor id"
252
+ ]
253
+ },
254
+ {
255
+ "name": "BirthDate",
256
+ "description": "date of birth of the employee",
257
+ "synonyms": [
258
+ "date of birth",
259
+ "DOB"
260
+ ]
261
+ },
262
+ {
263
+ "name": "HireDate",
264
+ "description": "date when the employee was hired",
265
+ "synonyms": [
266
+ "start date",
267
+ "employment date"
268
+ ]
269
+ },
270
+ {
271
+ "name": "Address",
272
+ "description": "street address of the employee",
273
+ "synonyms": [
274
+ "street address"
275
+ ]
276
+ },
277
+ {
278
+ "name": "City",
279
+ "description": "city where the employee lives",
280
+ "synonyms": [
281
+ "employee city"
282
+ ]
283
+ },
284
+ {
285
+ "name": "State",
286
+ "description": "state or province where employee lives",
287
+ "synonyms": [
288
+ "province"
289
+ ]
290
+ },
291
+ {
292
+ "name": "Country",
293
+ "description": "country where employee lives",
294
+ "synonyms": [
295
+ "nation"
296
+ ]
297
+ },
298
+ {
299
+ "name": "PostalCode",
300
+ "description": "postal or zip code of employee's address",
301
+ "synonyms": [
302
+ "zip code"
303
+ ]
304
+ },
305
+ {
306
+ "name": "Phone",
307
+ "description": "phone number of the employee",
308
+ "synonyms": [
309
+ "telephone"
310
+ ]
311
+ },
312
+ {
313
+ "name": "Fax",
314
+ "description": "fax number of the employee if available",
315
+ "synonyms": [
316
+ "facsimile"
317
+ ]
318
+ },
319
+ {
320
+ "name": "Email",
321
+ "description": "email address of the employee",
322
+ "synonyms": [
323
+ "email address"
324
+ ]
325
+ }
326
+ ],
327
+ "sample_queries": [
328
+ {
329
+ "query": "SELECT FirstName, LastName FROM Employee WHERE Title = 'Sales Manager'",
330
+ "user_input": "Get all sales managers"
331
+ },
332
+ {
333
+ "query": "SELECT e1.FirstName, e1.LastName, e2.FirstName as ManagerFirstName, e2.LastName as ManagerLastName FROM Employee e1 LEFT JOIN Employee e2 ON e1.ReportsTo = e2.EmployeeId",
334
+ "user_input": "List all employees with their managers"
335
+ },
336
+ {
337
+ "query": "SELECT COUNT(*) FROM Employee WHERE STRFTIME('%Y', HireDate) = '2002'",
338
+ "user_input": "Count employees hired in 2002"
339
+ },
340
+ {
341
+ "query": "SELECT e.FirstName, e.LastName, SUM(i.Total) as TotalSales FROM Employee e JOIN Customer c ON e.EmployeeId = c.SupportRepId JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY e.EmployeeId ORDER BY TotalSales DESC",
342
+ "user_input": "Rank employees by total sales from their customers"
343
+ }
344
+ ]
345
+ },
346
+ {
347
+ "separator": "table_5",
348
+ "name": "Genre",
349
+ "schema": "CREATE TABLE Genre (GenreId integer PRIMARY KEY, Name character varying(120));",
350
+ "description": "This table stores music genres.",
351
+ "columns": [
352
+ {
353
+ "name": "GenreId",
354
+ "description": "unique identifier for genres",
355
+ "synonyms": [
356
+ "genre id"
357
+ ]
358
+ },
359
+ {
360
+ "name": "Name",
361
+ "description": "name of the genre",
362
+ "synonyms": [
363
+ "genre name"
364
+ ]
365
+ }
366
+ ],
367
+ "sample_queries": [
368
+ {
369
+ "query": "SELECT Name FROM Genre WHERE GenreId = 1",
370
+ "user_input": "Get the genre name with id 1"
371
+ },
372
+ {
373
+ "query": "SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId ORDER BY TrackCount DESC",
374
+ "user_input": "List genres by number of tracks"
375
+ },
376
+ {
377
+ "query": "SELECT g.Name, SUM(t.Milliseconds)/3600000 as Hours FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId ORDER BY Hours DESC",
378
+ "user_input": "Total hours of music by genre"
379
+ },
380
+ {
381
+ "query": "SELECT g.Name, SUM(il.UnitPrice * il.Quantity) as Revenue FROM Genre g JOIN Track t ON g.GenreId = t.GenreId JOIN InvoiceLine il ON t.TrackId = il.TrackId GROUP BY g.GenreId ORDER BY Revenue DESC",
382
+ "user_input": "Revenue generated by each genre"
383
+ }
384
+ ]
385
+ },
386
+ {
387
+ "separator": "table_6",
388
+ "name": "Invoice",
389
+ "schema": "CREATE TABLE Invoice (InvoiceId integer PRIMARY KEY, CustomerId integer, InvoiceDate timestamp without time zone, BillingAddress character varying(70), BillingCity character varying(40), BillingState character varying(40), BillingCountry character varying(40), BillingPostalCode character varying(10), Total numeric(10,2));",
390
+ "description": "This table stores invoices for customer purchases.",
391
+ "columns": [
392
+ {
393
+ "name": "InvoiceId",
394
+ "description": "unique identifier for invoices",
395
+ "synonyms": [
396
+ "invoice id"
397
+ ]
398
+ },
399
+ {
400
+ "name": "CustomerId",
401
+ "description": "Id of the customer associated with the invoice",
402
+ "synonyms": [
403
+ "customer id"
404
+ ]
405
+ },
406
+ {
407
+ "name": "InvoiceDate",
408
+ "description": "date when the invoice was generated",
409
+ "synonyms": [
410
+ "date",
411
+ "purchase date"
412
+ ]
413
+ },
414
+ {
415
+ "name": "BillingAddress",
416
+ "description": "billing address for the invoice",
417
+ "synonyms": [
418
+ "address"
419
+ ]
420
+ },
421
+ {
422
+ "name": "BillingCity",
423
+ "description": "city for the billing address",
424
+ "synonyms": [
425
+ "city"
426
+ ]
427
+ },
428
+ {
429
+ "name": "BillingState",
430
+ "description": "state or province for the billing address",
431
+ "synonyms": [
432
+ "state",
433
+ "province"
434
+ ]
435
+ },
436
+ {
437
+ "name": "BillingCountry",
438
+ "description": "country for the billing address",
439
+ "synonyms": [
440
+ "country"
441
+ ]
442
+ },
443
+ {
444
+ "name": "BillingPostalCode",
445
+ "description": "postal code for the billing address",
446
+ "synonyms": [
447
+ "zip code",
448
+ "postal code"
449
+ ]
450
+ },
451
+ {
452
+ "name": "Total",
453
+ "description": "total amount of the invoice",
454
+ "synonyms": [
455
+ "amount",
456
+ "price"
457
+ ]
458
+ }
459
+ ],
460
+ "sample_queries": [
461
+ {
462
+ "query": "SELECT SUM(Total) as TotalSales FROM Invoice WHERE strftime('%Y', InvoiceDate) = '2009'",
463
+ "user_input": "Get total sales for 2009"
464
+ },
465
+ {
466
+ "query": "SELECT strftime('%Y-%m', InvoiceDate) as Month, SUM(Total) as MonthlyRevenue FROM Invoice GROUP BY Month ORDER BY Month",
467
+ "user_input": "Monthly revenue over time"
468
+ },
469
+ {
470
+ "query": "SELECT BillingCountry, SUM(Total) as Revenue FROM Invoice GROUP BY BillingCountry ORDER BY Revenue DESC",
471
+ "user_input": "Total revenue by country"
472
+ },
473
+ {
474
+ "query": "SELECT i.InvoiceId, COUNT(il.InvoiceLineId) as Items FROM Invoice i JOIN InvoiceLine il ON i.InvoiceId = il.InvoiceId GROUP BY i.InvoiceId ORDER BY Items DESC LIMIT 1",
475
+ "user_input": "Find the invoice with the most items"
476
+ }
477
+ ]
478
+ },
479
+ {
480
+ "separator": "table_7",
481
+ "name": "InvoiceLine",
482
+ "schema": "CREATE TABLE InvoiceLine (InvoiceLineId integer PRIMARY KEY, InvoiceId integer, TrackId integer, UnitPrice numeric(10,2), Quantity integer);",
483
+ "description": "This table stores individual line items for each invoice.",
484
+ "columns": [
485
+ {
486
+ "name": "InvoiceLineId",
487
+ "description": "unique identifier for invoice lines",
488
+ "synonyms": [
489
+ "line id"
490
+ ]
491
+ },
492
+ {
493
+ "name": "InvoiceId",
494
+ "description": "Id of the invoice this line belongs to",
495
+ "synonyms": [
496
+ "invoice id"
497
+ ]
498
+ },
499
+ {
500
+ "name": "TrackId",
501
+ "description": "Id of the track purchased",
502
+ "synonyms": [
503
+ "track id"
504
+ ]
505
+ },
506
+ {
507
+ "name": "UnitPrice",
508
+ "description": "price of the track",
509
+ "synonyms": [
510
+ "price"
511
+ ]
512
+ },
513
+ {
514
+ "name": "Quantity",
515
+ "description": "number of tracks purchased",
516
+ "synonyms": [
517
+ "amount",
518
+ "count"
519
+ ]
520
+ }
521
+ ],
522
+ "sample_queries": [
523
+ {
524
+ "query": "SELECT t.Name, il.UnitPrice, il.Quantity FROM InvoiceLine il JOIN Track t ON il.TrackId = t.TrackId WHERE il.InvoiceId = 1",
525
+ "user_input": "Get details of purchase in invoice 1"
526
+ },
527
+ {
528
+ "query": "SELECT t.Name, SUM(il.Quantity) as TimesPurchased FROM InvoiceLine il JOIN Track t ON il.TrackId = t.TrackId GROUP BY il.TrackId ORDER BY TimesPurchased DESC LIMIT 10",
529
+ "user_input": "Find the top 10 most purchased tracks"
530
+ },
531
+ {
532
+ "query": "SELECT i.InvoiceId, SUM(il.UnitPrice * il.Quantity) as TotalAmount FROM InvoiceLine il JOIN Invoice i ON il.InvoiceId = i.InvoiceId GROUP BY i.InvoiceId HAVING TotalAmount > 10",
533
+ "user_input": "Find invoices with total amount greater than $10"
534
+ },
535
+ {
536
+ "query": "SELECT AVG(UnitPrice) as AveragePrice FROM InvoiceLine",
537
+ "user_input": "Find the average unit price of purchased items"
538
+ }
539
+ ]
540
+ },
541
+ {
542
+ "separator": "table_8",
543
+ "name": "MediaType",
544
+ "schema": "CREATE TABLE MediaType (MediaTypeId integer PRIMARY KEY, Name character varying(120));",
545
+ "description": "This table stores types of media formats for music tracks.",
546
+ "columns": [
547
+ {
548
+ "name": "MediaTypeId",
549
+ "description": "unique identifier for media types",
550
+ "synonyms": [
551
+ "media type id"
552
+ ]
553
+ },
554
+ {
555
+ "name": "Name",
556
+ "description": "name of the media type",
557
+ "synonyms": [
558
+ "media type name",
559
+ "format"
560
+ ]
561
+ }
562
+ ],
563
+ "sample_queries": [
564
+ {
565
+ "query": "SELECT Name FROM MediaType",
566
+ "user_input": "List all available media types"
567
+ },
568
+ {
569
+ "query": "SELECT mt.Name, COUNT(t.TrackId) as TrackCount FROM MediaType mt JOIN Track t ON mt.MediaTypeId = t.MediaTypeId GROUP BY mt.MediaTypeId ORDER BY TrackCount DESC",
570
+ "user_input": "Count tracks by media type"
571
+ },
572
+ {
573
+ "query": "SELECT mt.Name, SUM(il.UnitPrice * il.Quantity) as Revenue FROM MediaType mt JOIN Track t ON mt.MediaTypeId = t.MediaTypeId JOIN InvoiceLine il ON t.TrackId = il.TrackId GROUP BY mt.MediaTypeId ORDER BY Revenue DESC",
574
+ "user_input": "Revenue generated by each media type"
575
+ },
576
+ {
577
+ "query": "SELECT mt.Name, SUM(t.Milliseconds)/3600000 as Hours FROM MediaType mt JOIN Track t ON mt.MediaTypeId = t.MediaTypeId GROUP BY mt.MediaTypeId ORDER BY Hours DESC",
578
+ "user_input": "Total duration in hours of each media type"
579
+ }
580
+ ]
581
+ },
582
+ {
583
+ "separator": "table_9",
584
+ "name": "Playlist",
585
+ "schema": "CREATE TABLE Playlist (PlaylistId integer PRIMARY KEY, Name character varying(120));",
586
+ "description": "This table stores playlists of tracks.",
587
+ "columns": [
588
+ {
589
+ "name": "PlaylistId",
590
+ "description": "unique identifier for playlists",
591
+ "synonyms": [
592
+ "playlist id"
593
+ ]
594
+ },
595
+ {
596
+ "name": "Name",
597
+ "description": "name of the playlist",
598
+ "synonyms": [
599
+ "playlist name"
600
+ ]
601
+ }
602
+ ],
603
+ "sample_queries": [
604
+ {
605
+ "query": "SELECT Name FROM Playlist WHERE PlaylistId = 1",
606
+ "user_input": "Get the name of playlist with id 1"
607
+ },
608
+ {
609
+ "query": "SELECT p.Name, COUNT(pt.TrackId) as TrackCount FROM Playlist p JOIN PlaylistTrack pt ON p.PlaylistId = pt.PlaylistId GROUP BY p.PlaylistId ORDER BY TrackCount DESC",
610
+ "user_input": "List playlists by number of tracks"
611
+ },
612
+ {
613
+ "query": "SELECT p.Name, SUM(t.Milliseconds)/60000 as Minutes FROM Playlist p JOIN PlaylistTrack pt ON p.PlaylistId = pt.PlaylistId JOIN Track t ON pt.TrackId = t.TrackId GROUP BY p.PlaylistId ORDER BY Minutes DESC",
614
+ "user_input": "Find the total duration of each playlist in minutes"
615
+ },
616
+ {
617
+ "query": "SELECT p.Name, COUNT(DISTINCT g.GenreId) as GenreCount FROM Playlist p JOIN PlaylistTrack pt ON p.PlaylistId = pt.PlaylistId JOIN Track t ON pt.TrackId = t.TrackId JOIN Genre g ON t.GenreId = g.GenreId GROUP BY p.PlaylistId ORDER BY GenreCount DESC",
618
+ "user_input": "Count how many different genres are in each playlist"
619
+ }
620
+ ]
621
+ },
622
+ {
623
+ "separator": "table_10",
624
+ "name": "PlaylistTrack",
625
+ "schema": "CREATE TABLE PlaylistTrack (PlaylistId integer, TrackId integer, PRIMARY KEY (PlaylistId, TrackId));",
626
+ "description": "This junction table connects playlists and tracks in a many-to-many relationship.",
627
+ "columns": [
628
+ {
629
+ "name": "PlaylistId",
630
+ "description": "Id of the playlist",
631
+ "synonyms": [
632
+ "playlist id"
633
+ ]
634
+ },
635
+ {
636
+ "name": "TrackId",
637
+ "description": "Id of the track in the playlist",
638
+ "synonyms": [
639
+ "track id"
640
+ ]
641
+ }
642
+ ],
643
+ "sample_queries": [
644
+ {
645
+ "query": "SELECT t.Name FROM Track t JOIN PlaylistTrack pt ON t.TrackId = pt.TrackId WHERE pt.PlaylistId = 1",
646
+ "user_input": "Get all tracks in playlist 1"
647
+ },
648
+ {
649
+ "query": "SELECT t.Name, COUNT(pt.PlaylistId) as PlaylistCount FROM Track t JOIN PlaylistTrack pt ON t.TrackId = pt.TrackId GROUP BY t.TrackId ORDER BY PlaylistCount DESC LIMIT 10",
650
+ "user_input": "Find the top 10 tracks that appear in the most playlists"
651
+ },
652
+ {
653
+ "query": "SELECT p1.Name, p2.Name, COUNT(pt1.TrackId) as CommonTracks FROM Playlist p1 JOIN PlaylistTrack pt1 ON p1.PlaylistId = pt1.PlaylistId JOIN PlaylistTrack pt2 ON pt1.TrackId = pt2.TrackId JOIN Playlist p2 ON pt2.PlaylistId = p2.PlaylistId WHERE p1.PlaylistId < p2.PlaylistId GROUP BY p1.PlaylistId, p2.PlaylistId ORDER BY CommonTracks DESC LIMIT 5",
654
+ "user_input": "Find the top 5 playlist pairs that share the most tracks"
655
+ },
656
+ {
657
+ "query": "SELECT g.Name, COUNT(pt.TrackId) as TracksInPlaylists FROM Genre g JOIN Track t ON g.GenreId = t.GenreId JOIN PlaylistTrack pt ON t.TrackId = pt.TrackId GROUP BY g.GenreId ORDER BY TracksInPlaylists DESC",
658
+ "user_input": "Popularity of genres in playlists by track count"
659
+ }
660
+ ]
661
+ },
662
+ {
663
+ "separator": "table_11",
664
+ "name": "Track",
665
+ "schema": "CREATE TABLE Track (TrackId integer PRIMARY KEY, Name character varying(200), AlbumId integer, MediaTypeId integer, GenreId integer, Composer character varying(220), Milliseconds integer, Bytes integer, UnitPrice numeric(10,2));",
666
+ "description": "This table stores information about individual music tracks.",
667
+ "columns": [
668
+ {
669
+ "name": "TrackId",
670
+ "description": "unique identifier for tracks",
671
+ "synonyms": [
672
+ "track id"
673
+ ]
674
+ },
675
+ {
676
+ "name": "Name",
677
+ "description": "name of the track",
678
+ "synonyms": [
679
+ "track name",
680
+ "song title"
681
+ ]
682
+ },
683
+ {
684
+ "name": "AlbumId",
685
+ "description": "Id of the album this track belongs to",
686
+ "synonyms": [
687
+ "album id"
688
+ ]
689
+ },
690
+ {
691
+ "name": "MediaTypeId",
692
+ "description": "Id of the media type for this track",
693
+ "synonyms": [
694
+ "media type id",
695
+ "format id"
696
+ ]
697
+ },
698
+ {
699
+ "name": "GenreId",
700
+ "description": "Id of the genre this track belongs to",
701
+ "synonyms": [
702
+ "genre id"
703
+ ]
704
+ },
705
+ {
706
+ "name": "Composer",
707
+ "description": "name of the composer of this track",
708
+ "synonyms": [
709
+ "writer",
710
+ "songwriter"
711
+ ]
712
+ },
713
+ {
714
+ "name": "Milliseconds",
715
+ "description": "length of the track in milliseconds",
716
+ "synonyms": [
717
+ "duration",
718
+ "length"
719
+ ]
720
+ },
721
+ {
722
+ "name": "Bytes",
723
+ "description": "size of the track file in bytes",
724
+ "synonyms": [
725
+ "file size"
726
+ ]
727
+ },
728
+ {
729
+ "name": "UnitPrice",
730
+ "description": "price of the track",
731
+ "synonyms": [
732
+ "price",
733
+ "cost"
734
+ ]
735
+ }
736
+ ],
737
+ "sample_queries": [
738
+ {
739
+ "query": "SELECT t.Name, a.Title FROM Track t JOIN Album a ON t.AlbumId = a.AlbumId WHERE t.Composer LIKE '%Johnson%'",
740
+ "user_input": "Find all tracks composed by someone with Johnson in their name"
741
+ },
742
+ {
743
+ "query": "SELECT Name, Milliseconds/60000 as Minutes FROM Track ORDER BY Milliseconds DESC LIMIT 10",
744
+ "user_input": "Find the 10 longest tracks by duration"
745
+ },
746
+ {
747
+ "query": "SELECT g.Name, AVG(t.UnitPrice) as AvgPrice FROM Track t JOIN Genre g ON t.GenreId = g.GenreId GROUP BY t.GenreId ORDER BY AvgPrice DESC",
748
+ "user_input": "Average price of tracks by genre"
749
+ },
750
+ {
751
+ "query": "SELECT Name, Composer, Milliseconds/60000 as Minutes FROM Track WHERE Composer IS NOT NULL AND Milliseconds = (SELECT MAX(Milliseconds) FROM Track)",
752
+ "user_input": "Find the longest track and its composer"
753
+ }
754
+ ]
755
+ }
756
+ ]
757
+ }
prompt.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate
2
+
3
# Human-turn template: retrieved schema context first, then the user's question.
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{context}

Question: {question}
SQL Query:"""

# System-turn instruction sent ahead of the human template.
_SYSTEM_INSTRUCTION = "Given an input question, convert it to a SQL query. No pre-amble."

# Two-message chat prompt: system instruction followed by the human template.
_messages = [
    ("system", _SYSTEM_INSTRUCTION),
    ("human", template),
]
prompt = ChatPromptTemplate.from_messages(_messages)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ jq
2
+ langchain
3
+ langchain-huggingface
4
+ sentence_transformers
5
+ langchain-google-genai
6
+ google-ai-generativelanguage==0.6.15
7
+ langchain-community
8
+ langchain-chroma
utils.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import pandas as pd
3
+
4
def execute_query_and_return_df(query, db_path="/content/Chinook_Sqlite.sqlite"):
    """
    Executes a SQL query and returns the results as a Pandas DataFrame.

    The database is opened read-only, so a statement that tries to modify it
    fails instead of changing the data, and a missing database file is
    reported as an error rather than silently created.

    Args:
        query: The SQL query to execute.
        db_path: Path of the SQLite database file. Defaults to the location
            this function previously had hard-coded.

    Returns:
        A Pandas DataFrame containing the query results, or None if the
        database could not be opened or the query failed.
    """
    # Local import keeps the block self-contained without touching file-level imports.
    from pathlib import Path

    conn = None
    try:
        # as_uri() percent-encodes the absolute path; mode=ro asks SQLite for
        # a read-only connection.
        db_uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
        # connect() is inside the try so an unopenable database also yields None.
        conn = sqlite3.connect(db_uri, uri=True)
        return pd.read_sql_query(query, conn)
    except Exception as e:
        print(f"Error executing query: {e}")
        return None
    finally:
        if conn is not None:
            conn.close()