DrishtiSharma committed on
Commit 3c32de9 · verified · 1 Parent(s): a1ef31a

Update app.py

Files changed (1)
  1. app.py +13 -14
app.py CHANGED
@@ -15,29 +15,27 @@ from langchain_community.utilities.sql_database import SQLDatabase
 from datasets import load_dataset
 import tempfile
 
-st.title("Blah Blah App 🚀")
+st.title("Blah Blah App \U0001F680")
 st.write("Analyze datasets using natural language queries.")
 
-# LLM Initialization
 def initialize_llm(model_choice):
     groq_api_key = os.getenv("GROQ_API_KEY")
     openai_api_key = os.getenv("OPENAI_API_KEY")
 
     if model_choice == "llama-3.3-70b":
         if not groq_api_key:
-            st.error("Groq API key is missing.")
+            st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
             return None
         return ChatGroq(groq_api_key=groq_api_key, model="groq/llama-3.3-70b-versatile")
     elif model_choice == "GPT-4o":
         if not openai_api_key:
-            st.error("OpenAI API key is missing.")
+            st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
            return None
         return ChatOpenAI(api_key=openai_api_key, model="gpt-4o")
 
 model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)
 llm = initialize_llm(model_choice)
 
-# Dataset Loading
 def load_dataset_into_session():
     input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
     if input_option == "Use Hugging Face Dataset":
@@ -61,7 +59,6 @@ if "df" not in st.session_state:
61
  st.session_state.df = None
62
  load_dataset_into_session()
63
 
64
- # Database Initialization
65
  def initialize_database(df):
66
  temp_dir = tempfile.TemporaryDirectory()
67
  db_path = os.path.join(temp_dir.name, "patent_data.db")
@@ -70,23 +67,24 @@ def initialize_database(df):
     db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
     return db, temp_dir
 
-# SQL Tools
 def create_sql_tools(db):
     @tool("list_tables")
     def list_tables() -> str:
+        """List all tables in the patent database."""
         return ListSQLDatabaseTool(db=db).invoke("")
 
     @tool("tables_schema")
     def tables_schema(tables: str) -> str:
+        """Get schema and sample rows for given tables."""
         return InfoSQLDatabaseTool(db=db).invoke(tables)
 
     @tool("execute_sql")
     def execute_sql(sql_query: str) -> str:
+        """Execute a SQL query against the patent database."""
         return QuerySQLDataBaseTool(db=db).invoke(sql_query)
 
     return list_tables, tables_schema, execute_sql
 
-# Agent Initialization
 def initialize_agents(llm, tools):
     list_tables, tables_schema, execute_sql = tools
 
@@ -114,7 +112,6 @@ def initialize_agents(llm, tools):
 
     return sql_agent, analyst_agent, writer_agent
 
-# Crew and Tasks Setup
 def setup_crew(sql_agent, analyst_agent, writer_agent):
     extract_task = Task(
         description="Extract patents related to the query: {query}.",
@@ -143,7 +140,6 @@ def setup_crew(sql_agent, analyst_agent, writer_agent):
         verbose=True,
     )
 
-# Execution Flow
 if st.session_state.df is not None:
     db, temp_dir = initialize_database(st.session_state.df)
     tools = create_sql_tools(db)
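For context on the setup_crew section, whose body is mostly outside this diff: in CrewAI, each Task is bound to an agent, the tasks are assembled into a Crew, and kickoff later fills template variables such as {query} from its inputs. A generic wiring sketch, not the actual code from app.py (the expected_output strings and the second task are placeholders; the agent variables are assumed to come from initialize_agents):

from crewai import Crew, Process, Task

extract_task = Task(
    description="Extract patents related to the query: {query}.",
    expected_output="A table of matching patents.",    # placeholder
    agent=sql_agent,                                    # agents from initialize_agents()
)
report_task = Task(
    description="Write a short report on the extracted patents.",  # placeholder
    expected_output="A markdown report.",                           # placeholder
    agent=writer_agent,
)

crew = Crew(
    agents=[sql_agent, analyst_agent, writer_agent],
    tasks=[extract_task, report_task],
    process=Process.sequential,  # run tasks in order, passing outputs along
    verbose=True,
)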
@@ -152,9 +148,12 @@ if st.session_state.df is not None:
 
     query = st.text_area("Enter Patent Analysis Query:", placeholder="e.g., 'How many patents related to Machine Learning were filed after 2016?'")
     if st.button("Submit Query"):
-        with st.spinner("Processing your query..."):
-            result = crew.kickoff(inputs={"query": query})
-            st.markdown("### 📊 Patent Analysis Report")
-            st.markdown(result)
+        if query.strip():
+            with st.spinner("Processing your query..."):
+                result = crew.kickoff(inputs={"query": query})
+                st.markdown("### 📊 Patent Analysis Report")
+                st.markdown(result)
+        else:
+            st.warning("Please enter a valid query.")
 else:
     st.info("Please load a patent dataset to proceed.")
 