DrishtiSharma committed on
Commit
a849379
·
verified ·
1 Parent(s): c9e66b7

Update interim.py

Files changed (1)
  1. interim.py +43 -48
interim.py CHANGED
@@ -20,10 +20,9 @@ from langchain_community.utilities.sql_database import SQLDatabase
 from datasets import load_dataset
 import tempfile
 
-# Setup API key
 os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
 
-# Callback handler for logging
+# LLM Logging
 class LLMCallbackHandler(BaseCallbackHandler):
     def __init__(self, log_path: Path):
         self.log_path = log_path
@@ -37,103 +36,100 @@ class LLMCallbackHandler(BaseCallbackHandler):
         with self.log_path.open("a", encoding="utf-8") as file:
             file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
 
-# LLM Setup
 llm = ChatGroq(
     temperature=0,
     model_name="mixtral-8x7b-32768",
     callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
 )
 
-st.title("SQL-RAG using CrewAI 🚀")
-st.write("Analyze and summarize data using natural language queries with SQL-based retrieval.")
+st.title("SQL-RAG Using CrewAI 🚀")
+st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
 
-# File upload or Hugging Face dataset input
-option = st.radio("Choose your input method:", ["Upload a CSV file", "Enter Hugging Face dataset name"])
+# Data Input Options
+input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
+df = None
 
-if option == "Upload a CSV file":
-    uploaded_file = st.file_uploader("Upload your dataset (CSV format)", type=["csv"])
+if input_option == "Use Hugging Face Dataset":
+    dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="Einstellung/demo-salaries")
+    if st.button("Load Dataset"):
+        try:
+            with st.spinner("Loading Hugging Face dataset..."):
+                dataset = load_dataset(dataset_name, split="train")
+                df = pd.DataFrame(dataset)
+                st.success(f"Dataset '{dataset_name}' loaded successfully!")
+                st.dataframe(df.head())
+        except Exception as e:
+            st.error(f"Error loading dataset: {e}")
+else:
+    uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
     if uploaded_file:
         df = pd.read_csv(uploaded_file)
         st.success("File uploaded successfully!")
-else:
-    dataset_name = st.text_input("Enter Hugging Face dataset name:", placeholder="e.g., imdb, ag_news")
-    if dataset_name:
-        try:
-            dataset = load_dataset(dataset_name, split="train")
-            df = pd.DataFrame(dataset)
-            st.success(f"Dataset '{dataset_name}' loaded successfully!")
-        except Exception as e:
-            st.error(f"Error loading Hugging Face dataset: {e}")
-            df = None
-
-if 'df' in locals() and not df.empty:
-    st.write("### Dataset Preview:")
-    st.dataframe(df.head())
+        st.dataframe(df.head())
 
-    # Create a temporary SQLite database
+# SQL-RAG and Query Workflow
+if df is not None:
     temp_dir = tempfile.TemporaryDirectory()
     db_path = os.path.join(temp_dir.name, "data.db")
     connection = sqlite3.connect(db_path)
-    df.to_sql("data_table", connection, if_exists="replace", index=False)
+    df.to_sql("salaries", connection, if_exists="replace", index=False)
     db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
 
-    # Tools
     @tool("list_tables")
     def list_tables() -> str:
+        """List all tables in the database."""
         return ListSQLDatabaseTool(db=db).invoke("")
 
     @tool("tables_schema")
     def tables_schema(tables: str) -> str:
+        """Return schema and example rows for given tables."""
         return InfoSQLDatabaseTool(db=db).invoke(tables)
 
     @tool("execute_sql")
     def execute_sql(sql_query: str) -> str:
+        """Execute a SQL query and return results."""
         return QuerySQLDataBaseTool(db=db).invoke(sql_query)
 
     @tool("check_sql")
     def check_sql(sql_query: str) -> str:
+        """Check SQL query validity."""
         return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
 
-    # Agents
     sql_dev = Agent(
-        role="Database Developer",
-        goal="Extract data from the database.",
+        role="Senior Database Developer",
+        goal="Construct and execute SQL queries.",
         llm=llm,
         tools=[list_tables, tables_schema, execute_sql, check_sql],
-        allow_delegation=False,
     )
 
     data_analyst = Agent(
-        role="Data Analyst",
-        goal="Analyze and provide insights.",
+        role="Senior Data Analyst",
+        goal="Analyze the data returned from SQL queries.",
         llm=llm,
-        allow_delegation=False,
     )
 
     report_writer = Agent(
-        role="Report Editor",
-        goal="Summarize the analysis.",
+        role="Senior Report Editor",
+        goal="Summarize the analysis into a short report.",
         llm=llm,
-        allow_delegation=False,
    )
 
-    # Tasks
     extract_data = Task(
-        description="Extract data required for the query: {query}.",
-        expected_output="Database result for the query",
+        description="Extract data for the query: {query}.",
+        expected_output="Database query results.",
         agent=sql_dev,
     )
 
     analyze_data = Task(
-        description="Analyze the data for: {query}.",
-        expected_output="Detailed analysis text",
+        description="Analyze the query results for: {query}.",
+        expected_output="Detailed analysis report.",
         agent=data_analyst,
         context=[extract_data],
     )
 
     write_report = Task(
-        description="Summarize the analysis into a short report.",
-        expected_output="Markdown report",
+        description="Summarize the analysis into a brief executive summary.",
+        expected_output="Markdown report.",
         agent=report_writer,
         context=[analyze_data],
     )
@@ -143,12 +139,11 @@ if 'df' in locals() and not df.empty:
         tasks=[extract_data, analyze_data, write_report],
         process=Process.sequential,
         verbose=2,
-        memory=False,
     )
 
-    query = st.text_input("Enter your query:", placeholder="e.g., 'What are the top 5 highest salaries?'")
-    if query:
-        with st.spinner("Processing your query..."):
+    query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary by experience level?'")
+    if st.button("Submit Query"):
+        with st.spinner("Processing your query with CrewAI..."):
             inputs = {"query": query}
             result = crew.kickoff(inputs=inputs)
             st.markdown("### Analysis Report:")
@@ -156,4 +151,4 @@ if 'df' in locals() and not df.empty:
 
     temp_dir.cleanup()
 else:
-    st.warning("Please upload a valid file or provide a correct Hugging Face dataset name.")
+    st.info("Load a dataset to proceed.")
 
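For reference, the four @tool wrappers in this revision delegate straight to langchain_community's SQL toolkit; each function body is a one-line .invoke() on the corresponding tool. Below is a minimal standalone sketch of that toolchain outside Streamlit and CrewAI, where the demo.db path and the sample frame are illustrative rather than taken from the commit.

import sqlite3
import pandas as pd
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.tools.sql_database.tool import (
    ListSQLDatabaseTool,
    InfoSQLDatabaseTool,
    QuerySQLDataBaseTool,
)

# Build a small SQLite table the same way interim.py does: DataFrame -> to_sql.
df = pd.DataFrame({"experience_level": ["EN", "SE"], "salary_in_usd": [60000, 150000]})
connection = sqlite3.connect("demo.db")  # illustrative path, not from the commit
df.to_sql("salaries", connection, if_exists="replace", index=False)
connection.close()

db = SQLDatabase.from_uri("sqlite:///demo.db")

print(ListSQLDatabaseTool(db=db).invoke(""))          # table names, e.g. "salaries"
print(InfoSQLDatabaseTool(db=db).invoke("salaries"))  # CREATE TABLE statement + sample rows
print(QuerySQLDataBaseTool(db=db).invoke("SELECT AVG(salary_in_usd) FROM salaries"))

Pointing all four tools at the same SQLDatabase instance is what lets the agent list tables, inspect schemas, and run queries against the uploaded data without extra plumbing.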
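The remaining hunks rename the agents and tighten the task wording, but the Agent -> Task -> Crew wiring pattern is unchanged (the crew = Crew(...) constructor itself sits outside the visible context lines). A self-contained sketch of that pattern under stated assumptions: the backstory string is assumed, since most CrewAI releases require the field even though it never appears in the hunks, and the model and query strings are illustrative.

from crewai import Agent, Task, Crew, Process
from langchain_groq import ChatGroq

# Same model configuration as the committed file; requires GROQ_API_KEY to be set.
llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")

analyst = Agent(
    role="Senior Data Analyst",
    goal="Analyze the data returned from SQL queries.",
    backstory="A careful analyst who writes concise, well-grounded findings.",  # assumed; not in the hunks
    llm=llm,
    allow_delegation=False,
)

analyze = Task(
    description="Analyze the query results for: {query}.",
    expected_output="Detailed analysis report.",
    agent=analyst,
)

# Sequential process runs tasks in order; a task's `context` feeds it prior outputs.
crew = Crew(agents=[analyst], tasks=[analyze], process=Process.sequential, verbose=2)
result = crew.kickoff(inputs={"query": "average salary by experience level"})
print(result)

Note that verbose=2 matches the file as committed; recent CrewAI releases type verbose as a boolean, so a newer pinned version may need verbose=True instead.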