Spaces:
Sleeping
Sleeping
saman shrestha
commited on
Commit
·
6812fa3
1
Parent(s):
80a80a1
initial commit
Browse files- .gitignore +78 -0
- Dockerfile +31 -0
- flask_app.py +162 -0
- helpers/GROQ.py +115 -0
- helpers/postgres.py +45 -0
- helpers/prompts.py +11 -0
- prompts/base_prompts.txt +51 -0
- requirement.txt +0 -0
- requirements.txt +59 -0
.gitignore
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ignore all env directories
|
2 |
+
env/
|
3 |
+
venv/
|
4 |
+
.env/
|
5 |
+
.venv/
|
6 |
+
|
7 |
+
# Ignore environment-related files
|
8 |
+
*.env
|
9 |
+
.env.*
|
10 |
+
.envrc
|
11 |
+
|
12 |
+
# Ignore Python virtual environment files
|
13 |
+
pyvenv.cfg
|
14 |
+
# Ignore Python bytecode files
|
15 |
+
__pycache__/
|
16 |
+
*.py[cod]
|
17 |
+
*$py.class
|
18 |
+
|
19 |
+
# Ignore Python distribution / packaging
|
20 |
+
.Python
|
21 |
+
build/
|
22 |
+
develop-eggs/
|
23 |
+
dist/
|
24 |
+
downloads/
|
25 |
+
eggs/
|
26 |
+
.eggs/
|
27 |
+
lib/
|
28 |
+
lib64/
|
29 |
+
parts/
|
30 |
+
sdist/
|
31 |
+
var/
|
32 |
+
wheels/
|
33 |
+
share/python-wheels/
|
34 |
+
*.egg-info/
|
35 |
+
.installed.cfg
|
36 |
+
*.egg
|
37 |
+
MANIFEST
|
38 |
+
|
39 |
+
# Ignore pip logs
|
40 |
+
pip-log.txt
|
41 |
+
pip-delete-this-directory.txt
|
42 |
+
|
43 |
+
# Ignore Python testing
|
44 |
+
.tox/
|
45 |
+
.coverage
|
46 |
+
.coverage.*
|
47 |
+
.cache
|
48 |
+
nosetests.xml
|
49 |
+
coverage.xml
|
50 |
+
*.cover
|
51 |
+
*.py,cover
|
52 |
+
.hypothesis/
|
53 |
+
.pytest_cache/
|
54 |
+
cover/
|
55 |
+
|
56 |
+
# Ignore Jupyter Notebook
|
57 |
+
.ipynb_checkpoints
|
58 |
+
|
59 |
+
# Ignore IPython
|
60 |
+
profile_default/
|
61 |
+
ipython_config.py
|
62 |
+
|
63 |
+
# Ignore mypy
|
64 |
+
.mypy_cache/
|
65 |
+
.dmypy.json
|
66 |
+
dmypy.json
|
67 |
+
|
68 |
+
# Ignore Pylint
|
69 |
+
.pylintrc
|
70 |
+
|
71 |
+
# Ignore Python rope project settings
|
72 |
+
.ropeproject
|
73 |
+
|
74 |
+
# Ignore mkdocs documentation
|
75 |
+
/site
|
76 |
+
|
77 |
+
# Ignore Sphinx documentation
|
78 |
+
docs/_build/
|
Dockerfile
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10-slim-buster
|
2 |
+
|
3 |
+
### Set up user with permissions
|
4 |
+
# Set up a new user named "user" with user ID 1000
|
5 |
+
RUN useradd -m -u 1000 user
|
6 |
+
|
7 |
+
# Switch to the "user" user
|
8 |
+
USER user
|
9 |
+
|
10 |
+
# Set home to the user's home directory
|
11 |
+
ENV HOME=/home/user \
|
12 |
+
PATH=/home/user/.local/bin:$PATH
|
13 |
+
|
14 |
+
# Set the working directory to the user's home directory
|
15 |
+
WORKDIR $HOME/app
|
16 |
+
|
17 |
+
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
18 |
+
COPY --chown=user . $HOME/app
|
19 |
+
|
20 |
+
### Set up app-specific content
|
21 |
+
COPY requirements.txt requirements.txt
|
22 |
+
RUN pip3 install -r requirements.txt
|
23 |
+
|
24 |
+
COPY . .
|
25 |
+
|
26 |
+
### Update permissions for the app
|
27 |
+
USER root
|
28 |
+
RUN chmod 777 ~/app/*
|
29 |
+
USER user
|
30 |
+
|
31 |
+
CMD ["gunicorn", "-b", "0.0.0.0:7860", "flask_app:app"]
|
flask_app.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, g, render_template, request, jsonify, session
|
2 |
+
import os
|
3 |
+
|
4 |
+
from helpers.GROQ import ConversationGROQ
|
5 |
+
from helpers.postgres import DatabaseConnection
|
6 |
+
from helpers.prompts import PromptManager
|
7 |
+
import re
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
app = Flask(__name__)
|
11 |
+
app.secret_key = os.urandom(24) # Set a secret key for session encryption
|
12 |
+
|
13 |
+
prompt_manager = PromptManager()
|
14 |
+
prompt_manager.load_prompt('base_schema', 'prompts/base_prompts.txt')
|
15 |
+
base_prompt = prompt_manager.get_prompt('base_schema')
|
16 |
+
def extract_sql_regex(input_string) -> str | None:
|
17 |
+
# Pattern to match SQL query within double quotes after "sql":
|
18 |
+
pattern = r'"sql":\s*"(.*?)"'
|
19 |
+
|
20 |
+
match = re.search(pattern, input_string)
|
21 |
+
if match:
|
22 |
+
return match.group(1)
|
23 |
+
else:
|
24 |
+
return None
|
25 |
+
|
26 |
+
# Add a function to get or create the database connection
|
27 |
+
def get_db():
|
28 |
+
if 'db' not in g:
|
29 |
+
return DatabaseConnection(
|
30 |
+
db_host=session['DB_HOST'],
|
31 |
+
db_port=session['DB_PORT'],
|
32 |
+
db_name=session['DB_NAME'],
|
33 |
+
db_user=session['DB_USER'],
|
34 |
+
db_password=session['DB_PASSWORD']
|
35 |
+
)
|
36 |
+
return None
|
37 |
+
|
38 |
+
@app.route('/', methods=['POST'])
|
39 |
+
def index():
|
40 |
+
data = request.json
|
41 |
+
db_user = data.get('DB_USER', '')
|
42 |
+
db_host = data.get('DB_HOST', '')
|
43 |
+
db_port = data.get('DB_PORT', '')
|
44 |
+
db_name = data.get('DB_NAME', '')
|
45 |
+
db_password = data.get('DB_PASSWORD', '')
|
46 |
+
missing_fields = []
|
47 |
+
if not db_user:
|
48 |
+
missing_fields.append('DB_USER')
|
49 |
+
if not db_host:
|
50 |
+
missing_fields.append('DB_HOST')
|
51 |
+
if not db_port:
|
52 |
+
missing_fields.append('DB_PORT')
|
53 |
+
if not db_name:
|
54 |
+
missing_fields.append('DB_NAME')
|
55 |
+
if not db_password:
|
56 |
+
missing_fields.append('DB_PASSWORD')
|
57 |
+
|
58 |
+
if missing_fields:
|
59 |
+
return jsonify({
|
60 |
+
"error": f"Missing credentials: {', '.join(missing_fields)}",
|
61 |
+
"format": "json"
|
62 |
+
}), 400
|
63 |
+
|
64 |
+
# Store database credentials in session
|
65 |
+
session['DB_HOST'] = db_host
|
66 |
+
session['DB_PORT'] = db_port
|
67 |
+
session['DB_NAME'] = db_name
|
68 |
+
session['DB_USER'] = db_user
|
69 |
+
session['DB_PASSWORD'] = db_password
|
70 |
+
|
71 |
+
# Test the connection
|
72 |
+
try:
|
73 |
+
db = get_db()
|
74 |
+
if db is None:
|
75 |
+
return jsonify({"error": "Database connection failed", "format": "json"}), 500
|
76 |
+
return jsonify({"message": "Database connection successful", "format": "json"}), 200
|
77 |
+
except Exception as e:
|
78 |
+
return jsonify({"error": f"Database connection failed: {str(e)}", "format": "json"}), 500
|
79 |
+
|
80 |
+
@app.route('/chat', methods=['POST'])
|
81 |
+
def chat():
|
82 |
+
data = request.json
|
83 |
+
|
84 |
+
if 'DB_HOST' not in session:
|
85 |
+
return jsonify({"error": "Database connection not established", "format": "json"}), 400
|
86 |
+
|
87 |
+
prompt = data.get('prompt', '')
|
88 |
+
if not prompt:
|
89 |
+
return jsonify({"error": "Prompt is required", "format": "json"}), 400
|
90 |
+
|
91 |
+
db = get_db()
|
92 |
+
|
93 |
+
schema = db.execute_query('SELECT schema_name FROM information_schema.schemata;').fetchall()
|
94 |
+
schema = [schema[0] for schema in schema]
|
95 |
+
|
96 |
+
tables = db.execute_query('''SELECT
|
97 |
+
table_name,
|
98 |
+
json_object_agg(column_name, data_type) AS columns
|
99 |
+
FROM
|
100 |
+
information_schema.columns
|
101 |
+
WHERE
|
102 |
+
table_schema = 'public'
|
103 |
+
GROUP BY
|
104 |
+
table_name
|
105 |
+
ORDER BY
|
106 |
+
table_name;''').fetchall()
|
107 |
+
table_info = {table[0]: table[1] for table in tables}
|
108 |
+
full_prompt = base_prompt.format(schema_list=schema, tables=tables, table_info=table_info, user_question=prompt)
|
109 |
+
|
110 |
+
groq = ConversationGROQ(api_key='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')
|
111 |
+
groq.create_conversation(full_prompt)
|
112 |
+
response = groq.chat(prompt)
|
113 |
+
sql_query = extract_sql_regex(response)
|
114 |
+
if(sql_query is None):
|
115 |
+
print("No SQL query found")
|
116 |
+
return jsonify({"message": response, "response": response, "Sql": sql_query,"format": "json"}), 200
|
117 |
+
result = db.execute_query(sql_query)
|
118 |
+
print(sql_query, 'result')
|
119 |
+
row = result.fetchall()
|
120 |
+
df = pd.DataFrame(row, columns=[desc[0] for desc in result.description])
|
121 |
+
df = df.reset_index(drop=True)
|
122 |
+
print(df.to_markdown(index=False))
|
123 |
+
prompt = """
|
124 |
+
A user asked the following question:
|
125 |
+
{user_question}
|
126 |
+
|
127 |
+
Based on this question, a query was executed and returned the following data:
|
128 |
+
|
129 |
+
{df}
|
130 |
+
|
131 |
+
Please provide a clear and concise summary of this data in non-technical language.
|
132 |
+
Focus on the key insights and how they relate to the user's question.
|
133 |
+
Avoid using technical terms and present the information in a way that's easy for anyone to understand.
|
134 |
+
|
135 |
+
If there are any notable trends, patterns, or important points in the data, please highlight them.
|
136 |
+
If the data includes price or amount information, please also provide a brief comparison. For example, highlight the highest and lowest values, or compare average prices/amounts between different categories if applicable.
|
137 |
+
|
138 |
+
Additionally, if the data spans multiple time periods (e.g., different dates or years), please provide a brief overview of any trends or changes over time.
|
139 |
+
If applicable, include any relevant statistics or figures, but explain them in simple terms.
|
140 |
+
|
141 |
+
Your summary should be informative yet accessible to someone without a technical background.
|
142 |
+
""".format(user_question=prompt, df=df)
|
143 |
+
final_response = groq.chat(prompt)
|
144 |
+
print(final_response)
|
145 |
+
return jsonify({"message": final_response, "df": df.to_html(),"response": response, "sql": sql_query,"format": "json"}), 200
|
146 |
+
|
147 |
+
@app.route('/chat', methods=['POST'])
|
148 |
+
def query():
|
149 |
+
data = request.json
|
150 |
+
# Process the query here
|
151 |
+
print(data)
|
152 |
+
# For now, we'll just echo back the received data
|
153 |
+
return jsonify({"response": f"Received: {data}"})
|
154 |
+
|
155 |
+
@app.teardown_appcontext
|
156 |
+
def close_db(error):
|
157 |
+
db = g.pop('db', None)
|
158 |
+
if db is not None:
|
159 |
+
db.close()
|
160 |
+
|
161 |
+
if __name__ == '__main__':
|
162 |
+
app.run(debug=True, port=5001)
|
helpers/GROQ.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from groq import Groq
|
2 |
+
from langchain_groq import ChatGroq
|
3 |
+
from langchain_core.prompts import (
|
4 |
+
ChatPromptTemplate,
|
5 |
+
HumanMessagePromptTemplate,
|
6 |
+
MessagesPlaceholder,
|
7 |
+
)
|
8 |
+
from langchain.chains import LLMChain
|
9 |
+
from langchain_core.messages import SystemMessage
|
10 |
+
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
|
11 |
+
from typing import Dict, Optional
|
12 |
+
import pandas as pd
|
13 |
+
|
14 |
+
class GROQ:
|
15 |
+
def __init__(self, api_key: str = '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'):
|
16 |
+
self.client: Groq = Groq(
|
17 |
+
api_key=api_key
|
18 |
+
)
|
19 |
+
|
20 |
+
def chat(self, prompt: str, model: str, response_format: Optional[Dict]) -> str:
|
21 |
+
completion = self.client.chat.completions.create(
|
22 |
+
model=model, messages=[{"role": "user", "content": prompt}], response_format=response_format)
|
23 |
+
|
24 |
+
return completion.choices[0].message.content
|
25 |
+
|
26 |
+
def errorChat(self, user_question: str, sql_query: str, error: str, model: str) -> str:
|
27 |
+
# Check the ai need user feedback or not
|
28 |
+
prompt = """
|
29 |
+
User question: {user_question}
|
30 |
+
Error: {error}
|
31 |
+
Error Occured in thisSQL Query: {sql_query}
|
32 |
+
Update the SQL query to fix the error.
|
33 |
+
if its need user feedback, return the feedback prompt. If not, return None.
|
34 |
+
Response in json {{"sql": <sql query here>, "feedback": <feedback prompt here>, "summarization": <summarization prompt here>,
|
35 |
+
"user_feedback": boolean if true send {{"user_feedback": true}} if false send {{"user_feedback": false}}
|
36 |
+
""".format(user_question = user_question, sql_query = sql_query, error = error)
|
37 |
+
return self.chat(prompt, model, None)
|
38 |
+
|
39 |
+
def get_summarization(self, user_question: str, df: pd.DataFrame, model: str) -> str:
|
40 |
+
"""
|
41 |
+
This function generates a summarization prompt based on the user's question and the resulting data.
|
42 |
+
It then sends this summarization prompt to the Groq API and retrieves the AI's response.
|
43 |
+
|
44 |
+
Parameters:
|
45 |
+
client (Groqcloud): The Groq API client.
|
46 |
+
user_question (str): The user's question.
|
47 |
+
df (DataFrame): The DataFrame resulting from the SQL query.
|
48 |
+
model (str): The AI model to use for the response.
|
49 |
+
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
str: The content of the AI's response to the summarization prompt.
|
53 |
+
"""
|
54 |
+
prompt = '''
|
55 |
+
A user asked the following question pertaining to local database tables:
|
56 |
+
|
57 |
+
{user_question}
|
58 |
+
|
59 |
+
To answer the question, a dataframe was returned:
|
60 |
+
Dataframe should be shown as a table.
|
61 |
+
* Dataframe is structured as easy to read table.
|
62 |
+
* Dataframe is clean and dates are converted to a readable format.
|
63 |
+
|
64 |
+
Dataframe:
|
65 |
+
{df}
|
66 |
+
* Ensure all numeric values are formatted with appropriate precision.
|
67 |
+
* If there are any percentages, display them with the % symbol.
|
68 |
+
* Format any currency values with the appropriate currency symbol and decimal places.
|
69 |
+
* If there are any date columns, format them as 'YYYY-MM-DD' for clarity.
|
70 |
+
* If the dataframe has more than 10 rows, show only the first 10 rows and indicate there are more.
|
71 |
+
* Include the total number of rows in the dataframe.
|
72 |
+
|
73 |
+
In a few sentences and show the dataframe, summarize the data in the table as it pertains to the original user question. Avoid qualifiers like "based on the data" and do not comment on the structure or metadata of the table itself
|
74 |
+
'''.format(user_question = user_question, df = df.to_markdown(index=False))
|
75 |
+
# Response format is set to 'None'
|
76 |
+
return self.chat(prompt,model,None)
|
77 |
+
|
78 |
+
|
79 |
+
class ConversationGROQ:
|
80 |
+
def __init__(self, conversational_memory_length: int = 10, api_key: str = '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', model: str = 'llama3-8b-8192'):
|
81 |
+
self.client: ChatGroq = ChatGroq(
|
82 |
+
groq_api_key=api_key,
|
83 |
+
model=model
|
84 |
+
)
|
85 |
+
self.memory: ConversationBufferWindowMemory = ConversationBufferWindowMemory(k=conversational_memory_length, memory_key="chat_history", return_messages=True)
|
86 |
+
self.conversation: Optional[LLMChain] = None
|
87 |
+
|
88 |
+
def create_template(self, base_prompt: str) -> ChatPromptTemplate:
|
89 |
+
return ChatPromptTemplate.from_messages([
|
90 |
+
SystemMessage(
|
91 |
+
content=base_prompt
|
92 |
+
), # This is the persistent system prompt that is always included at the start of the chat.
|
93 |
+
|
94 |
+
MessagesPlaceholder(
|
95 |
+
variable_name="chat_history"
|
96 |
+
), # This placeholder will be replaced by the actual chat history during the conversation. It helps in maintaining context.
|
97 |
+
|
98 |
+
HumanMessagePromptTemplate.from_template(
|
99 |
+
"{human_input}"
|
100 |
+
), # This template is where the user's current input will be injected into the prompt.
|
101 |
+
])
|
102 |
+
|
103 |
+
def create_conversation(self, prompt: str) -> LLMChain:
|
104 |
+
self.conversation = LLMChain(
|
105 |
+
llm=self.client,
|
106 |
+
memory=self.memory,
|
107 |
+
prompt=self.create_template(prompt),
|
108 |
+
verbose=True
|
109 |
+
)
|
110 |
+
return self.conversation
|
111 |
+
|
112 |
+
def chat(self, user_input: str) -> str:
|
113 |
+
if self.conversation is None:
|
114 |
+
raise ValueError("Conversation not initialized. Call create_conversation() first.")
|
115 |
+
return self.conversation.predict(human_input =user_input)
|
helpers/postgres.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import psycopg2
|
3 |
+
from psycopg2 import pool, OperationalError
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
from typing import Optional, Union, List, Dict, Tuple
|
6 |
+
import psycopg2.extensions
|
7 |
+
|
8 |
+
# Load environment variables
|
9 |
+
load_dotenv()
|
10 |
+
def log_method(func):
|
11 |
+
def wrapper(*args, **kwargs):
|
12 |
+
print(f"Calling method {func.__name__}")
|
13 |
+
return func(*args, **kwargs)
|
14 |
+
return wrapper
|
15 |
+
|
16 |
+
class DatabaseConnection:
|
17 |
+
_instance = None # For Singleton pattern (optional)
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
def __init__(self, db_user, db_password, db_host, db_port, db_name):
|
22 |
+
self.db_user = db_user
|
23 |
+
self.db_password = db_password
|
24 |
+
self.db_host = db_host
|
25 |
+
self.db_port = db_port
|
26 |
+
self.db_name = db_name
|
27 |
+
try:
|
28 |
+
# Create a connection pool (min and max connection count)
|
29 |
+
self.connection_pool = psycopg2.pool.SimpleConnectionPool(
|
30 |
+
1, 10, # Min and max number of connections
|
31 |
+
user=db_user,
|
32 |
+
password=db_password,
|
33 |
+
host=db_host,
|
34 |
+
port=db_port,
|
35 |
+
database=db_name
|
36 |
+
)
|
37 |
+
if self.connection_pool:
|
38 |
+
print("Connection pool created successfully")
|
39 |
+
except OperationalError as e:
|
40 |
+
print(f"Error while connecting to PostgreSQL: {e}")
|
41 |
+
def execute_query(self, query: str, params: Optional[Union[List, Dict]] = None, connection: Optional[psycopg2.extensions.connection] = None) -> List[Tuple]:
|
42 |
+
cursor = self.connection_pool.getconn().cursor()
|
43 |
+
cursor.execute(query, params)
|
44 |
+
|
45 |
+
return cursor
|
helpers/prompts.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
class PromptManager:
|
3 |
+
def __init__(self):
|
4 |
+
self.prompts = {}
|
5 |
+
|
6 |
+
def load_prompt(self, name, file_path):
|
7 |
+
with open(file_path, 'r') as file:
|
8 |
+
self.prompts[name] = file.read()
|
9 |
+
|
10 |
+
def get_prompt(self, name):
|
11 |
+
return self.prompts.get(name, '')
|
prompts/base_prompts.txt
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
You are SQL Advisor, tasked with generating SQL queries for PostgreSQL based on user questions about data stored in these schemas:
|
2 |
+
Perform a detailed analysis of the following PostgreSQL database schemas. For each schema, determine its purpose, possible relationships with other schemas, and its overall role within the database system. Categorize schemas into system-level (e.g., those related to database management) and user-defined schemas (e.g., for specific business processes). Provide recommendations for improving schema organization, minimizing redundancy, and optimizing performance.
|
3 |
+
|
4 |
+
Analysis is the schema list: {schema_list}.
|
5 |
+
|
6 |
+
Key Areas to Focus On:
|
7 |
+
|
8 |
+
Categorizing schemas into system-level and user-defined.
|
9 |
+
Analyzing the function and use case of each schema.
|
10 |
+
Offering suggestions for schema optimization and best practices for management.
|
11 |
+
|
12 |
+
Important Notice:
|
13 |
+
* This system is designed for read-only operations. Queries that modify data (INSERT, UPDATE, DELETE) are not permitted.
|
14 |
+
* If a user requests data modification, respond with an error message explaining that such operations are restricted.
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
Read this table and remember the table name and all the columns you need to search for the user's question.
|
19 |
+
Here is the table format:
|
20 |
+
Format: {{
|
21 |
+
"table_name": {{
|
22 |
+
"column_name": "data_type",
|
23 |
+
...
|
24 |
+
}},
|
25 |
+
...
|
26 |
+
}} to understand the table structure.
|
27 |
+
Table information is :
|
28 |
+
{table_info}
|
29 |
+
|
30 |
+
Reminder:
|
31 |
+
* If the user asks for a greeting or introduction then respond with some greetings.
|
32 |
+
|
33 |
+
Given a user's question about data in a specific schema, write a valid PostgreSQL SQL query that accurately extracts or calculates the requested information from the tables in that schema, adhering to SQL best practices for PostgreSQL, optimizing for readability and performance where applicable.
|
34 |
+
|
35 |
+
Here are some tips for writing PostgreSQL queries:
|
36 |
+
* Use standard SQL syntax for querying tables
|
37 |
+
* Include the schema name when referencing tables (e.g., schema_name.table_name)
|
38 |
+
* Include appropriate JOIN clauses when querying across multiple tables
|
39 |
+
* Use CURRENT_DATE to get today's date
|
40 |
+
* Alias aggregated fields like COUNT(*) for clarity
|
41 |
+
|
42 |
+
Question:
|
43 |
+
--------
|
44 |
+
{user_question}
|
45 |
+
--------
|
46 |
+
Reminder: Generate a PostgreSQL SQL query to answer the question:
|
47 |
+
* respond as a valid JSON Document
|
48 |
+
* [Best] If the question can be answered with the available tables: {{"sql": <sql here>}}
|
49 |
+
* If the question cannot be answered with the available tables: {{"error": <explanation here>}}
|
50 |
+
* Ensure that the entire output is returned on only one single line
|
51 |
+
* Keep your query as simple and straightforward as possible; avoid unnecessary subqueries
|
requirement.txt
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiohappyeyeballs==2.4.0
|
2 |
+
aiohttp==3.10.5
|
3 |
+
aiosignal==1.3.1
|
4 |
+
annotated-types==0.7.0
|
5 |
+
anyio==4.6.0
|
6 |
+
async-timeout==4.0.3
|
7 |
+
attrs==24.2.0
|
8 |
+
blinker==1.8.2
|
9 |
+
certifi==2024.8.30
|
10 |
+
charset-normalizer==3.3.2
|
11 |
+
click==8.1.7
|
12 |
+
dataclasses-json==0.6.7
|
13 |
+
distro==1.9.0
|
14 |
+
exceptiongroup==1.2.2
|
15 |
+
Flask==3.0.3
|
16 |
+
frozenlist==1.4.1
|
17 |
+
groq==0.11.0
|
18 |
+
h11==0.14.0
|
19 |
+
httpcore==1.0.5
|
20 |
+
httpx==0.27.2
|
21 |
+
idna==3.10
|
22 |
+
itsdangerous==2.2.0
|
23 |
+
Jinja2==3.1.4
|
24 |
+
jsonpatch==1.33
|
25 |
+
jsonpointer==3.0.0
|
26 |
+
langchain==0.1.16
|
27 |
+
langchain-community==0.0.38
|
28 |
+
langchain-core==0.1.52
|
29 |
+
langchain-groq==0.1.5
|
30 |
+
langchain-text-splitters==0.0.2
|
31 |
+
langsmith==0.1.125
|
32 |
+
load-dotenv==0.1.0
|
33 |
+
MarkupSafe==2.1.5
|
34 |
+
marshmallow==3.22.0
|
35 |
+
multidict==6.1.0
|
36 |
+
mypy-extensions==1.0.0
|
37 |
+
numpy==1.26.4
|
38 |
+
orjson==3.10.7
|
39 |
+
packaging==23.2
|
40 |
+
pandas==2.2.3
|
41 |
+
psycopg2-binary==2.9.9
|
42 |
+
pydantic==2.9.2
|
43 |
+
pydantic_core==2.23.4
|
44 |
+
python-dateutil==2.9.0.post0
|
45 |
+
python-dotenv==1.0.1
|
46 |
+
pytz==2024.2
|
47 |
+
PyYAML==6.0.2
|
48 |
+
requests==2.32.3
|
49 |
+
six==1.16.0
|
50 |
+
sniffio==1.3.1
|
51 |
+
SQLAlchemy==2.0.35
|
52 |
+
tabulate==0.9.0
|
53 |
+
tenacity==8.5.0
|
54 |
+
typing-inspect==0.9.0
|
55 |
+
typing_extensions==4.12.2
|
56 |
+
tzdata==2024.1
|
57 |
+
urllib3==2.2.3
|
58 |
+
Werkzeug==3.0.4
|
59 |
+
yarl==1.11.1
|