from typing import Optional, TypedDict

from langgraph.graph import StateGraph, START, END

from agents.table_selection import table_selection_agent
from agents.data_retrieval import sample_data_retrieval_agent
from agents.sql_generation import sql_generation_agent
from agents.validation import query_validation_and_optimization
from agents.execution import execution_agent
from utils.bigquery_utils import init_bigquery_connection


# Define the state schema shared by every node in the workflow.
class SQLExecutionState(TypedDict):
    """State passed between the nodes of the natural-language-to-SQL workflow."""

    sql_query: str  # Natural language query (despite the name, not SQL)
    client: Optional[object]  # BigQuery client
    relevant_tables: Optional[list]  # Tables identified as relevant
    sample_data: Optional[dict]  # Sample data from relevant tables
    generated_sql: Optional[str]  # The actual SQL query (not JSON)
    validation_result: Optional[dict]  # Output of the validation step; shape defined in agents.validation
    optimized_sql: Optional[str]  # Presumably the SQL after optimization — confirm in agents.validation
    execution_result: Optional[dict]  # Output of the execution step; shape defined in agents.execution


def initialize_client(state: SQLExecutionState) -> dict:
    """Open a BigQuery connection and hand it to the rest of the workflow.

    Args:
        state: Current workflow state. It is not read here; the parameter
            exists because every graph node receives the state.

    Returns:
        A partial state update containing only the ``client`` key.
    """
    # NOTE(review): the client object is carried in graph state. That is fine
    # with compile() and no checkpointer (as in create_workflow); revisit if a
    # checkpointer that serializes state is ever added.
    client = init_bigquery_connection()
    return {"client": client}


def create_workflow():
    """Create and return the compiled workflow graph.

    The graph is a strictly linear pipeline:
    Initialize Client -> Table Selection -> Sample Data Retrieval ->
    SQL Generation -> Query Validation & Optimization -> SQL Execution.

    Returns:
        The result of ``StateGraph.compile()`` for that pipeline.
    """
    # Ordered (node name, node callable) pairs. Declaring the sequence once
    # keeps node registration and edge wiring from drifting apart.
    steps = [
        ("Initialize Client", initialize_client),
        ("Table Selection", table_selection_agent),
        ("Sample Data Retrieval", sample_data_retrieval_agent),
        ("SQL Generation", sql_generation_agent),
        ("Query Validation & Optimization", query_validation_and_optimization),
        ("SQL Execution", execution_agent),
    ]

    # Initialize the LangGraph workflow.
    graph = StateGraph(state_schema=SQLExecutionState)

    # Add nodes.
    for name, node in steps:
        graph.add_node(name, node)

    # Define execution flow: START -> each step in order -> END.
    names = [name for name, _ in steps]
    graph.add_edge(START, names[0])
    for src, dst in zip(names, names[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(names[-1], END)

    # Compile the graph.
    return graph.compile()