Upload 4 files
adding new model. {DeepSeek Coder}
- My_SQL_Connection.py +37 -17
- app.py +99 -47
- model_functions.py +39 -9
- requirements.txt +3 -4
My_SQL_Connection.py
CHANGED
```diff
@@ -4,22 +4,23 @@ import mysql.connector
 import pandas as pd
 ###======================================================================database details-=======================================================
 def database_details(host,user,password):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    connection = mysql.connector.connect(
+        host = host,
+        user = user,
+        password = password,
+        buffered = True
+    )
+    cursor = connection.cursor()
+    databases = ("Show databases")
+    cursor.execute(databases)
+    db = []
+    for (databases) in cursor:
+        db.append(databases[0])
+
+    cursor.close()
+    connection.close()
+    return db, len(db)
 
 #### =========================================================================retrieving the tables==========================================================
 def tables_in_this_DB(host,user,password,db_name):
@@ -98,4 +99,23 @@ def create_table_command(host,user,password,db_name):
     cursor.close()
     connection.close()
 
-    return create_table_statements
+    return create_table_statements
+
+
+def retrieve_result(host,user,password,db_name,query):
+    db_config = {
+        'host': host,
+        'user': user,
+        'password': password,
+        'database': db_name,
+    }
+
+    connection = mysql.connector.connect(**db_config)
+    cursor = connection.cursor()
+    query = query
+    cursor.execute(query)
+    res = cursor.fetchall()
+
+    cursor.close()
+    connection.close()
+    return res
```
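For orientation, a minimal sketch of how the new `retrieve_result` helper (together with the reworked `database_details`) could be exercised outside the Streamlit app. The credentials, database name, and query below are placeholders, not values taken from this repository.

```python
from My_SQL_Connection import database_details, retrieve_result

# Placeholder connection details for a locally running MySQL server.
host, user, password = "localhost", "root", "your-password"

# database_details() returns the list of database names and how many there are.
databases, count = database_details(host, user, password)
print(f"{count} databases:", databases)

# retrieve_result() executes an arbitrary query against one database and returns
# cursor.fetchall(); "example_db" is a hypothetical database name.
rows = retrieve_result(host, user, password, "example_db", "SELECT 1;")
print(rows)
```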
app.py
CHANGED
```diff
@@ -1,13 +1,12 @@
 import streamlit as st
 
-from My_SQL_Connection import database_details, tables_in_this_DB, printing_tables, create_table_command
+from My_SQL_Connection import database_details, tables_in_this_DB, printing_tables, create_table_command,retrieve_result
 from streamlit_option_menu import option_menu
-from model_functions import LOAD_GEMMA
+from model_functions import LOAD_GEMMA,DeepSeekCoder,LOAD_GEMMA_GGUF
 import torch
-
-
-
-
+import mysql.connector
+user_name = 'arya & Shritama'
+st.set_page_config(page_title="My SQL Explorer", page_icon="🔍", layout="centered", initial_sidebar_state="expanded")
 
 
 if 'localhost' not in st.session_state:
@@ -18,23 +17,31 @@ if 'localhost' not in st.session_state:
 
 with st.sidebar:
     selected = option_menu("Querio Lingua", ["Log In", 'main functionalities','Chat with AI'],
-                           icons=['
+                           icons=['person-circle', 'info-circle-fill', 'chat-fill'], menu_icon="cast", default_index=0,
+                           styles={
+                               "container": {"padding": "5!important","background-color":'black'},
+                               "icon": {"color": "white", "font-size": "23px"},
+                               "nav-link": {"color":"white","font-size": "20px", "text-align": "left", "margin":"0px", "--hover-color": "gray"},
+                               "nav-link-selected": {"background-color": "#1B2135"},})
 
 if selected == 'Log In':
-    st.title(f'welcome to our web application :green[{user_name}]')
-    st.subheader('welcome to our MY SQL Database Explorer ~ ')
 
-    st.
+    st.subheader('Please Log in into your MySql server by providing the following details ~ ')
+
     st.session_state.localhost = st.text_input("what is your host, (localhost if in local) or give the url", 'localhost',help='host')
     st.session_state.user = st.text_input("what is your user name (usually root)", 'root')
     st.session_state.password = st.text_input('Password', type='password')
 
 elif selected == 'main functionalities':
-    st.title(f'welcome to our web application :green[{user_name}]')
     st.subheader('welcome to our MY SQL Database Explorer ~ ')
     if st.button('All your databases ~ '):
-
-
+        try:
+            db, l = database_details(st.session_state.localhost, st.session_state.user, st.session_state.password)
+            st.table(db)
+        except mysql.connector.Error as e:
+            error_code = e.errno
+            st.warning(f"An error occurred (Error Code: {error_code}). Please check your login details.")
+
 
     st.subheader('Now we will see details of any database~ ')
 
@@ -44,31 +51,43 @@ elif selected == 'main functionalities':
     if not st.session_state.db_name:
         st.warning('Input database name first')
     else:
-
-
-
-
+        try:
+            tables, l = tables_in_this_DB(st.session_state.localhost, st.session_state.user, st.session_state.password, st.session_state.db_name)
+            st.write(f'There is only {l} tables present in this database')
+            st.markdown(f"**:rainbow[{tables[0][0]}]**")
+        except mysql.connector.Error as e:
+            st.warning("An error occured. Please select the correct database from the above list or check that you are loged in into your server.")
     st.subheader('check out tables~ ')
 
     if st.button('Print the tables~'):
-
-
-
-
+        try:
+            tables_data = printing_tables(st.session_state.localhost, st.session_state.user, st.session_state.password, st.session_state.db_name)
+            for table_name, table_data in tables_data.items():
+                st.write(f"Table: {table_name}")
+                st.table(table_data)
+        except mysql.connector.Error as e:
+            st.warning("An error occured. Please check that you have selected a database or have loged in into your server.")
 
     st.subheader('Retrieve the CREATE TABLE Statements')
 
-
-
-
-
-
-
+    statement_options = st.radio("Choose the Context option for chat",["Generate the Context for chat AI based on your tables",
+                                                                       "Give custom chat context"])
+    if statement_options == 'Generate the Context for chat AI based on your tables':
+        if st.button('Generate context'):
+            try:
+                statements = create_table_command(st.session_state.localhost, st.session_state.user, st.session_state.password, st.session_state.db_name)
+                for table_name, table_statements in statements.items():
+                    st.write(f'{table_name}')
+                    st.session_state.table_commands = table_statements
+                    st.code(table_statements)
+            except mysql.connector.Error as e:
+                st.warning('An error occured. Please check that you have selected a database or have loged in into your server.')
+    elif statement_options == 'Give custom chat context':
+        context = st.text_area("Paste your context here (Usually the tables schema)")
+        st.session_state.table_commands = context
 
 
 elif selected == 'Chat with AI':
-    #st.set_page_config(page_title='🧠MemoryBot🤖', layout='wide')
-    # Initialize session states
     if "generated" not in st.session_state:
         st.session_state["generated"] = []
     if "past" not in st.session_state:
@@ -103,32 +122,65 @@ elif selected == 'Chat with AI':
         st.session_state["past"] = []
         st.session_state["input"] = ""
 
-    with st.sidebar.expander("
-
-
+    #with st.sidebar.expander("Available Fine Tuned Models", expanded=False):
+    MODEL = st.sidebar.selectbox(label='Available Fine Tuned Models', options=['GEMMA-2B','Gemma-GGUF', 'DeepSeekCoder 1.3B'])
+    st.sidebar.warning('Load only one model at a time as it loads the model into cache so it may cause cache overload',icon="⚠️")
 
-    st.title("
-    st.
+    st.title("Querio Lingua 🤖")
+    st.markdown("Your own SQL code helper⭐")
+    st.markdown(" Powered by GEMMA & DeepSeek🚀")
 
 
     st.sidebar.button("New Chat", on_click = new_chat, type='primary')
+
    user_input = get_text()
 
    if user_input:
-
-
-
-
-
-
-
-
-
-
+        if MODEL == 'GEMMA-2B':
+
+            gemma_tokenizer,gemma_model = LOAD_GEMMA()
+            device = torch.device("cpu")
+            alpeca_prompt = f"""Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables.
+            ### Instruction: {user_input}. ### Input: {st.session_state.table_commands}
+            ### Response:
+            """
+            with st.status('Generating Result',expanded=False) as status:
+                inputs = gemma_tokenizer([alpeca_prompt], return_tensors="pt").to(device)
+                outputs = gemma_model.generate(**inputs, max_new_tokens=30)
+                output = gemma_tokenizer.decode(outputs[0], skip_special_tokens=True)
+                response_portion = output.split("### Response:")[-1].strip()
+
+                st.session_state.past.append(user_input)
+                st.session_state.generated.append(response_portion)
+
+                status.update(label="Result Generated!", state="complete", expanded=False)
+
+
+        elif MODEL == 'Gemma-GGUF':
+
+            with st.status('Generating Result',expanded=False) as status:
+                response = LOAD_GEMMA_GGUF(user_input,st.session_state.table_commands)
+                response_portion = response.split("### Response:")[-1].strip()
                 st.session_state.past.append(user_input)
-        st.session_state.generated.append(
-
-
+                st.session_state.generated.append(response_portion)
+
+        elif MODEL == 'DeepSeekCoder 1.3B':
+
+            with st.status('Generating Result',expanded=False) as status:
+                try:
+                    response_portion = DeepSeekCoder(user_input,st.session_state.table_commands)
+                    final_output = response_portion + f"\n {retrieve_result(st.session_state.localhost, st.session_state.user, st.session_state.password, st.session_state.db_name,response_portion)}"
+
+                    st.session_state.past.append(user_input)
+                    st.session_state.generated.append(final_output)
+                    print(final_output)
+                except mysql.connector.Error as e:
+
+                    st.session_state.past.append(user_input)
+                    st.session_state.generated.append(response_portion + '{Query not executable}')
+
+                status.update(label="Result Generated!", state="complete", expanded=False)
+
 
    download_str = []
    # Display the conversation history using an expander, and allow the user to download it
```
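All three model branches added to app.py post-process the generation the same way: the decoded output is split on the `### Response:` marker and only the tail is kept. A minimal sketch of that pattern, using an invented output string:

```python
# Invented model output in the Alpaca-style format used by the app's prompts.
output = (
    "Below are sql tables schemas paired with instruction that describes a task.\n"
    "### Instruction: list all users. ### Input: CREATE TABLE users (id INT, name TEXT)\n"
    "### Response:\n"
    "SELECT * FROM users;"
)

# Keep only what the model produced after the marker, as app.py does.
response_portion = output.split("### Response:")[-1].strip()
print(response_portion)  # -> SELECT * FROM users;
```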
model_functions.py
CHANGED
```diff
@@ -1,19 +1,49 @@
 from transformers import AutoModelForCausalLM,AutoTokenizer
 import streamlit as st
+import torch
+from langchain.prompts import PromptTemplate
+from langchain_community.llms import CTransformers
 
 @st.cache_resource(show_spinner='Loading the Gemma model. Be patient🙏')
 def LOAD_GEMMA():
     model_id = "aryachakraborty/GEMMA-2B-NL-SQL"
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(model_id
-                                                 low_cpu_mem_usage = True
-                                                 ).cpu()
+    model = AutoModelForCausalLM.from_pretrained(model_id).cpu()
     return tokenizer,model
 
-
-def
-    model_id=''
+@st.cache_resource(show_spinner='Loading the DeepSeek Coder model. Be patient🙏')
+def DeepSeekCoder(user_input, context):
+    model_id='aryachakraborty/DeepSeek_1.3B_Fine_Tuned'
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(model_id
-
-
+    model = AutoModelForCausalLM.from_pretrained(model_id
+                                                 ).cpu()
+    device = torch.device("cpu")
+    alpeca_prompt = f"""You are an AI programming assistant, utilizing the Deepseek Coder model, developed by arya chakraborty, and your task is to convert natural language to sql queries. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n
+    Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables.
+    ### Instruction: {user_input}. ### Input: {context}
+    ### Response:
+    """
+    inputs = tokenizer([alpeca_prompt.format(user_input=user_input, context=context)], return_tensors="pt").to(device)
+    outputs = model.generate(**inputs, max_new_tokens=30)
+    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    response_portion = output.split("### Response:")[-1].strip()
+    return response_portion
+
+
+@st.cache_resource(show_spinner='Loading the GGUF GEMMA model. Be patient🙏')
+def LOAD_GEMMA_GGUF(user_input, context):
+    llm = CTransformers(
+        model='D:\ISnartech Folder\Project_Folder\Streamlit APP\GgufModels\Q4_K_M.gguf',
+        #model_type='llama',
+        config={'max_new_tokens': 256, 'temperature': 0.01}
+    )
+
+    alpeca_prompt = f"""Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables.
+    ### Instruction: {user_input}. ### Input: {context}
+    ### Response:
+    """
+
+    prompt = PromptTemplate(input_variables=['user_input', 'context'], template=alpeca_prompt)
+    #return prompt,llm
+    response = llm(prompt.format(user_input=user_input, context=context))
+    return response
```
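A minimal sketch of calling the new DeepSeekCoder helper on its own. The schema and question are invented examples; running this downloads the aryachakraborty/DeepSeek_1.3B_Fine_Tuned checkpoint from the Hugging Face Hub and runs it on CPU.

```python
from model_functions import DeepSeekCoder

# Hypothetical table schema and question, used only for illustration.
schema = "CREATE TABLE employees (id INT, name VARCHAR(50), salary INT);"
question = "List the names of employees earning more than 50000."

# Returns only the text the model generates after the '### Response:' marker.
print(DeepSeekCoder(question, schema))
```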
requirements.txt
CHANGED
```diff
@@ -1,6 +1,3 @@
-bitsandbytes==0.42.0
-peft==0.8.2
-trl==0.7.10
 accelerate==0.27.1
 datasets==2.17.0
 transformers==4.38.0
@@ -8,4 +5,6 @@ streamlit
 streamlit_option_menu
 torch
 mysql.connector
-pandas
+pandas
+langchain_community
+langchain
```