"""Flask backend for a natural-language-to-SQL chat bot.

A client POSTs Postgres connection fields plus a question to ``/chat``. The
handler introspects the target database, asks the LLM wrapped by
``ConversationGROQ`` to write SQL for the question, executes that SQL, and
asks the LLM again to summarise the returned rows in plain language.
"""
import json
import os
import re

import pandas as pd
from flask import Flask, g, jsonify, render_template, request, session
from flask_cors import CORS

from helpers.GROQ import ConversationGROQ
from helpers.postgres import DatabaseConnection
from helpers.prompts import PromptManager

app = Flask(__name__)

# Browsers send the Origin header without a trailing slash, and flask_cors
# matches non-regex origins by exact string comparison, so the configured
# origin must not end in "/" or every request from the frontend is rejected.
CORS(
    app,
    supports_credentials=True,
    origins="https://db-bot-amber.vercel.app",
    allow_headers=["Content-Type", "Authorization", "Access-Control-Allow-Credentials"],
    expose_headers=["Set-Cookie"],
)

prompt_manager = PromptManager()
prompt_manager.load_prompt('base_schema', 'prompts/base_prompts.txt')
base_prompt = prompt_manager.get_prompt('base_schema')

# Request fields that must be present and non-empty, in the order they are
# reported back to the client when missing.
_REQUIRED_FIELDS = ('prompt', 'DB_USER', 'DB_HOST', 'DB_PORT', 'DB_NAME', 'DB_PASSWORD')

# Captures the JSON string value that follows "sql": while honouring backslash
# escapes, so an escaped quote (\") inside the SQL does not end the match early.
_SQL_FIELD_PATTERN = re.compile(r'"sql"\s*:\s*"((?:[^"\\]|\\.)*)"', re.DOTALL)

# Second-pass prompt asking the model to explain the query result in plain language.
_SUMMARY_PROMPT_TEMPLATE = """
A user asked the following question: {user_question}

Based on this question, a query was executed and returned the following data:

{df}

Please provide a clear and concise summary of this data in non-technical language. Focus on the key insights and how they relate to the user's question. Avoid using technical terms and present the information in a way that's easy for anyone to understand. If there are any notable trends, patterns, or important points in the data, please highlight them.

If the data includes price or amount information, please also provide a brief comparison. For example, highlight the highest and lowest values, or compare average prices/amounts between different categories if applicable.

Additionally, if the data spans multiple time periods (e.g., different dates or years), please provide a brief overview of any trends or changes over time.

If applicable, include any relevant statistics or figures, but explain them in simple terms. Your summary should be informative yet accessible to someone without a technical background.
"""


def extract_sql_regex(input_string) -> str | None:
    r"""Return the SQL stored under the ``"sql"`` key in an LLM reply.

    The reply is expected to contain a JSON-style fragment such as
    ``{"sql": "SELECT ..."}``; only the first occurrence is used. JSON escape
    sequences (``\n``, ``\"``, ``\uXXXX``) are decoded so the database receives
    the query text itself rather than its escaped form.

    Returns None when *input_string* is empty/None or contains no ``"sql"``
    string value.
    """
    if not input_string:
        return None
    match = _SQL_FIELD_PATTERN.search(input_string)
    if match is None:
        return None
    raw = match.group(1)
    try:
        # strict=False tolerates literal newlines/tabs the model may emit
        # inside the string instead of escaping them.
        return json.loads(f'"{raw}"', strict=False)
    except json.JSONDecodeError:
        # Invalid escape sequence (e.g. "\d"): fall back to the captured text.
        return raw


def _introspect_database(db):
    """Return ``(schema_names, tables, table_info)`` describing *db*.

    ``tables`` holds the raw ``(table_name, columns_json)`` rows for the
    ``public`` schema; ``table_info`` maps each table name to its columns.
    """
    schema_rows = db.execute_query('SELECT schema_name FROM information_schema.schemata;').fetchall()
    schema_names = [row[0] for row in schema_rows]
    tables = db.execute_query('''SELECT table_name, json_object_agg(column_name, data_type) AS columns
FROM information_schema.columns
WHERE table_schema = 'public'
GROUP BY table_name
ORDER BY table_name;''').fetchall()
    table_info = {row[0]: row[1] for row in tables}
    return schema_names, tables, table_info


def _run_query(db, sql_query):
    """Execute *sql_query* on *db* and return the result rows as a DataFrame."""
    result = db.execute_query(sql_query)
    # Assuming execute_query returns a DB-API cursor: description is None for
    # statements that produce no rows (e.g. UPDATE), and fetchall() would fail.
    if result.description is None:
        return pd.DataFrame()
    columns = [desc[0] for desc in result.description]
    return pd.DataFrame(result.fetchall(), columns=columns)


@app.route('/chat', methods=['POST', 'OPTIONS'])
def chat():
    """Answer a natural-language question about a caller-supplied Postgres database.

    Expects a JSON body with ``prompt`` and the connection fields ``DB_USER``,
    ``DB_HOST``, ``DB_PORT``, ``DB_NAME`` and ``DB_PASSWORD``.

    Returns:
        * 400 with ``error`` when a field is missing or the database schema
          cannot be read.
        * 500 with ``error`` when ``GROQ_API_KEY`` is not set or the generated
          SQL fails to execute.
        * 200 with the raw LLM reply as ``message``/``response`` when the reply
          contains no SQL.
        * 200 with the plain-language summary as ``message``, the result table
          as HTML in ``df``, the raw reply as ``response`` and the executed
          ``sql`` otherwise.
    """
    # OPTIONS is listed explicitly, so Flask does not answer the CORS preflight
    # automatically; return the default OPTIONS response and let flask_cors
    # attach the Access-Control-* headers instead of parsing a body.
    if request.method == 'OPTIONS':
        return app.make_default_options_response()

    # silent=True: a missing or non-JSON body yields the "missing" 400 below
    # instead of an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    fields = {name: data.get(name, '') for name in _REQUIRED_FIELDS}
    missing_fields = [name for name, value in fields.items() if not value]
    if missing_fields:
        return jsonify({
            "error": f"Missing credentials: {', '.join(missing_fields)}",
            "format": "json",
        }), 400
    prompt = fields['prompt']

    # SECURITY: the API key must come from the environment, never from source.
    # (The key previously hard-coded here is compromised and must be revoked.)
    api_key = os.environ.get('GROQ_API_KEY')
    if not api_key:
        return jsonify({"error": "GROQ_API_KEY is not configured on the server", "format": "json"}), 500

    try:
        db = DatabaseConnection(
            db_host=fields['DB_HOST'],
            db_port=fields['DB_PORT'],
            db_name=fields['DB_NAME'],
            db_user=fields['DB_USER'],
            db_password=fields['DB_PASSWORD'],
        )
        # Stored on g so the close_db() teardown closes it when the request ends.
        g.db = db
        schema_names, tables, table_info = _introspect_database(db)
    except Exception:
        app.logger.exception("Database connection/introspection failed")
        return jsonify({"error": "Could not read the database schema", "format": "json"}), 400

    full_prompt = base_prompt.format(
        schema_list=schema_names, tables=tables, table_info=table_info, user_question=prompt
    )
    groq = ConversationGROQ(api_key=api_key)
    groq.create_conversation(full_prompt)
    response = groq.chat(prompt)

    sql_query = extract_sql_regex(response)
    if not sql_query:
        app.logger.info("No SQL query found")
        # "Sql" is kept for existing clients; "sql" matches the success payload.
        return jsonify({"message": response, "response": response, "Sql": sql_query,
                        "sql": sql_query, "format": "json"}), 200

    # SECURITY NOTE(review): this executes whatever SQL the model produced with
    # the caller's credentials, including writes/DDL. Consider a read-only role
    # or a read-only transaction on the connection.
    try:
        df = _run_query(db, sql_query)
    except Exception:
        app.logger.exception("Generated SQL failed: %s", sql_query)
        return jsonify({"error": "The generated SQL query failed to execute",
                        "response": response, "sql": sql_query, "format": "json"}), 500
    app.logger.debug("SQL %s returned %d rows", sql_query, len(df))

    summary_prompt = _SUMMARY_PROMPT_TEMPLATE.format(user_question=prompt, df=df)
    final_response = groq.chat(summary_prompt)
    app.logger.debug("Summary: %s", final_response)
    return jsonify({"message": final_response, "df": df.to_html(), "response": response,
                    "sql": sql_query, "format": "json"}), 200


# NOTE: this echo handler used to be registered on '/chat' (POST) too, which
# duplicated chat()'s rule; requests were always dispatched to chat(), so it
# was unreachable. It is kept as a plain, unrouted function.
def query():
    """Echo the request's JSON body back to the caller (debug helper, not routed)."""
    data = request.json
    print(data)
    return jsonify({"response": f"Received: {data}"})


@app.teardown_appcontext
def close_db(error):
    """Close the request's database connection, if one was stored on ``g.db``."""
    db = g.pop('db', None)
    if db is not None:
        db.close()


if __name__ == '__main__':
    # NOTE: debug=True enables the Werkzeug interactive debugger, which allows
    # executing arbitrary code; use this entry point for local development only.
    app.run(debug=True, port=5001)