DrishtiSharma committed
Commit 9bd334d · verified · 1 Parent(s): 77389d5

Update app.py

Files changed (1):
  1. app.py +110 -101
app.py CHANGED
@@ -7,7 +7,6 @@ from pathlib import Path
 from datetime import datetime, timezone
 from crewai import Agent, Crew, Process, Task
 from crewai_tools import tool
-from langchain_core.prompts import ChatPromptTemplate
 from langchain_groq import ChatGroq
 from langchain.schema.output import LLMResult
 from langchain_core.callbacks.base import BaseCallbackHandler
@@ -21,14 +20,10 @@ from langchain_community.utilities.sql_database import SQLDatabase
 from datasets import load_dataset
 import tempfile
 
+# Setup API key
 os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
 
-class Event:
-    def __init__(self, event, text):
-        self.event = event
-        self.timestamp = datetime.now(timezone.utc).isoformat()
-        self.text = text
-
+# Callback handler for logging
 class LLMCallbackHandler(BaseCallbackHandler):
     def __init__(self, log_path: Path):
         self.log_path = log_path
@@ -42,6 +37,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
         with self.log_path.open("a", encoding="utf-8") as file:
             file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
 
+# LLM Setup
 llm = ChatGroq(
     temperature=0,
     model_name="mixtral-8x7b-32768",
@@ -49,102 +45,115 @@ llm = ChatGroq(
 )
 
 st.title("SQL-RAG using CrewAI 🚀")
-st.write("Analyze and summarize Hugging Face datasets using natural language queries with SQL-based retrieval.")
-
-default_dataset = "datascience/ds-salaries"
-st.text("Example dataset: `datascience/ds-salaries` (You can enter your own dataset name)")
-
-dataset_name = st.text_input("Enter Hugging Face dataset name:", value=default_dataset)
-
-if dataset_name:
-    with st.spinner("Loading dataset..."):
+st.write("Analyze and summarize data using natural language queries with SQL-based retrieval.")
+
+# File upload or Hugging Face dataset input
+option = st.radio("Choose your input method:", ["Upload a CSV file", "Enter Hugging Face dataset name"])
+
+if option == "Upload a CSV file":
+    uploaded_file = st.file_uploader("Upload your dataset (CSV format)", type=["csv"])
+    if uploaded_file:
+        df = pd.read_csv(uploaded_file)
+        st.success("File uploaded successfully!")
+else:
+    dataset_name = st.text_input("Enter Hugging Face dataset name:", placeholder="e.g., imdb, ag_news")
+    if dataset_name:
         try:
             dataset = load_dataset(dataset_name, split="train")
             df = pd.DataFrame(dataset)
             st.success(f"Dataset '{dataset_name}' loaded successfully!")
-            st.write("Preview of the dataset:")
-            st.dataframe(df.head())
-
-            temp_dir = tempfile.TemporaryDirectory()
-            db_path = os.path.join(temp_dir.name, "data.db")
-            connection = sqlite3.connect(db_path)
-            df.to_sql("data_table", connection, if_exists="replace", index=False)
-            db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
-
-            @tool("list_tables")
-            def list_tables() -> str:
-                return ListSQLDatabaseTool(db=db).invoke("")
-
-            @tool("tables_schema")
-            def tables_schema(tables: str) -> str:
-                return InfoSQLDatabaseTool(db=db).invoke(tables)
-
-            @tool("execute_sql")
-            def execute_sql(sql_query: str) -> str:
-                return QuerySQLDataBaseTool(db=db).invoke(sql_query)
-
-            @tool("check_sql")
-            def check_sql(sql_query: str) -> str:
-                return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
-
-            sql_dev = Agent(
-                role="Database Developer",
-                goal="Extract data from the database.",
-                llm=llm,
-                tools=[list_tables, tables_schema, execute_sql, check_sql],
-                allow_delegation=False,
-            )
-
-            data_analyst = Agent(
-                role="Data Analyst",
-                goal="Analyze and provide insights.",
-                llm=llm,
-                allow_delegation=False,
-            )
-
-            report_writer = Agent(
-                role="Report Editor",
-                goal="Summarize the analysis.",
-                llm=llm,
-                allow_delegation=False,
-            )
-
-            extract_data = Task(
-                description="Extract data required for the query: {query}.",
-                expected_output="Database result for the query",
-                agent=sql_dev,
-            )
-
-            analyze_data = Task(
-                description="Analyze the data for: {query}.",
-                expected_output="Detailed analysis text",
-                agent=data_analyst,
-                context=[extract_data],
-            )
-
-            write_report = Task(
-                description="Summarize the analysis into a short report.",
-                expected_output="Markdown report",
-                agent=report_writer,
-                context=[analyze_data],
-            )
-
-            crew = Crew(
-                agents=[sql_dev, data_analyst, report_writer],
-                tasks=[extract_data, analyze_data, write_report],
-                process=Process.sequential,
-                verbose=2,
-                memory=False,
-            )
-
-            query = st.text_input("Enter your query:", placeholder="e.g., 'How does salary vary by company size?'")
-            if query:
-                with st.spinner("Processing your query..."):
-                    inputs = {"query": query}
-                    result = crew.kickoff(inputs=inputs)
-                    st.markdown("### Analysis Report:")
-                    st.markdown(result)
-
-            temp_dir.cleanup()
         except Exception as e:
-            st.error(f"Error loading dataset: {e}")
+            st.error(f"Error loading Hugging Face dataset: {e}")
+            df = None
+
+if 'df' in locals() and not df.empty:
+    st.write("### Dataset Preview:")
+    st.dataframe(df.head())
+
+    # Create a temporary SQLite database
+    temp_dir = tempfile.TemporaryDirectory()
+    db_path = os.path.join(temp_dir.name, "data.db")
+    connection = sqlite3.connect(db_path)
+    df.to_sql("data_table", connection, if_exists="replace", index=False)
+    db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
+
+    # Tools
+    @tool("list_tables")
+    def list_tables() -> str:
+        return ListSQLDatabaseTool(db=db).invoke("")
+
+    @tool("tables_schema")
+    def tables_schema(tables: str) -> str:
+        return InfoSQLDatabaseTool(db=db).invoke(tables)
+
+    @tool("execute_sql")
+    def execute_sql(sql_query: str) -> str:
+        return QuerySQLDataBaseTool(db=db).invoke(sql_query)
+
+    @tool("check_sql")
+    def check_sql(sql_query: str) -> str:
+        return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
+
+    # Agents
+    sql_dev = Agent(
+        role="Database Developer",
+        goal="Extract data from the database.",
+        llm=llm,
+        tools=[list_tables, tables_schema, execute_sql, check_sql],
+        allow_delegation=False,
+    )
+
+    data_analyst = Agent(
+        role="Data Analyst",
+        goal="Analyze and provide insights.",
+        llm=llm,
+        allow_delegation=False,
+    )
+
+    report_writer = Agent(
+        role="Report Editor",
+        goal="Summarize the analysis.",
+        llm=llm,
+        allow_delegation=False,
+    )
+
+    # Tasks
+    extract_data = Task(
+        description="Extract data required for the query: {query}.",
+        expected_output="Database result for the query",
+        agent=sql_dev,
+    )
+
+    analyze_data = Task(
+        description="Analyze the data for: {query}.",
+        expected_output="Detailed analysis text",
+        agent=data_analyst,
+        context=[extract_data],
+    )
+
+    write_report = Task(
+        description="Summarize the analysis into a short report.",
+        expected_output="Markdown report",
+        agent=report_writer,
+        context=[analyze_data],
+    )
+
+    crew = Crew(
+        agents=[sql_dev, data_analyst, report_writer],
+        tasks=[extract_data, analyze_data, write_report],
+        process=Process.sequential,
+        verbose=2,
+        memory=False,
+    )
+
+    query = st.text_input("Enter your query:", placeholder="e.g., 'What are the top 5 highest salaries?'")
+    if query:
+        with st.spinner("Processing your query..."):
+            inputs = {"query": query}
+            result = crew.kickoff(inputs=inputs)
+            st.markdown("### Analysis Report:")
+            st.markdown(result)
+
+    temp_dir.cleanup()
+else:
+    st.warning("Please upload a valid file or provide a correct Hugging Face dataset name.")
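For reference, both versions of the file share the same core mechanism: the loaded DataFrame is written into a temporary SQLite file, which is then wrapped in a LangChain SQLDatabase for the agents to query. A minimal standalone sketch of that round trip follows; the sample data, column names, and query are illustrative assumptions, not taken from the app:

import os
import sqlite3
import tempfile

import pandas as pd
from langchain_community.utilities.sql_database import SQLDatabase

# Illustrative stand-in for the DataFrame loaded from a CSV or Hugging Face dataset.
df = pd.DataFrame({"job_title": ["Data Scientist", "ML Engineer"], "salary": [100000, 120000]})

temp_dir = tempfile.TemporaryDirectory()
db_path = os.path.join(temp_dir.name, "data.db")

# Write the frame into SQLite under the same table name the app uses.
connection = sqlite3.connect(db_path)
df.to_sql("data_table", connection, if_exists="replace", index=False)
connection.close()

# Wrap the SQLite file so LangChain's SQL tools (and the CrewAI agents) can query it.
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
print(db.run("SELECT job_title, salary FROM data_table ORDER BY salary DESC LIMIT 5"))

temp_dir.cleanup()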
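The rewritten body chains the three tasks through context=[...]: analyze_data consumes extract_data's output and write_report consumes analyze_data's, with Process.sequential running them in order. The {query} placeholder in each Task description is filled from the kickoff inputs. A short usage sketch, assuming the crew defined in the diff above is in scope:

# The inputs dict is interpolated into the {query} placeholder of each Task description.
inputs = {"query": "What are the top 5 highest salaries?"}
result = crew.kickoff(inputs=inputs)
print(result)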
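The LLMCallbackHandler retained by this commit appends one JSON object per line (JSONL) with event, text, and timestamp keys. A minimal sketch of reading such a log back for debugging, assuming a hypothetical prompts.jsonl path (the actual log path is set where the handler is constructed, outside the hunks shown):

import json
from pathlib import Path

log_path = Path("prompts.jsonl")  # assumed filename, not taken from the diff
if log_path.exists():
    for line in log_path.read_text(encoding="utf-8").splitlines():
        event = json.loads(line)
        print(event["event"], event["timestamp"])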