Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -15,29 +15,27 @@ from langchain_community.utilities.sql_database import SQLDatabase
|
|
15 |
from datasets import load_dataset
|
16 |
import tempfile
|
17 |
|
18 |
-
st.title("Blah Blah App 🚀")
|
19 |
st.write("Analyze datasets using natural language queries.")
|
20 |
|
21 |
-
# LLM Initialization
|
22 |
def initialize_llm(model_choice):
|
23 |
groq_api_key = os.getenv("GROQ_API_KEY")
|
24 |
openai_api_key = os.getenv("OPENAI_API_KEY")
|
25 |
|
26 |
if model_choice == "llama-3.3-70b":
|
27 |
if not groq_api_key:
|
28 |
-
st.error("Groq API key is missing.")
|
29 |
return None
|
30 |
return ChatGroq(groq_api_key=groq_api_key, model="groq/llama-3.3-70b-versatile")
|
31 |
elif model_choice == "GPT-4o":
|
32 |
if not openai_api_key:
|
33 |
-
st.error("OpenAI API key is missing.")
|
34 |
return None
|
35 |
return ChatOpenAI(api_key=openai_api_key, model="gpt-4o")
|
36 |
|
37 |
model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)
|
38 |
llm = initialize_llm(model_choice)
|
39 |
|
40 |
-
# Dataset Loading
|
41 |
def load_dataset_into_session():
|
42 |
input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
|
43 |
if input_option == "Use Hugging Face Dataset":
|
@@ -61,7 +59,6 @@ if "df" not in st.session_state:
|
|
61 |
st.session_state.df = None
|
62 |
load_dataset_into_session()
|
63 |
|
64 |
-
# Database Initialization
|
65 |
def initialize_database(df):
|
66 |
temp_dir = tempfile.TemporaryDirectory()
|
67 |
db_path = os.path.join(temp_dir.name, "patent_data.db")
|
@@ -70,23 +67,24 @@ def initialize_database(df):
|
|
70 |
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
|
71 |
return db, temp_dir
|
72 |
|
73 |
-
# SQL Tools
|
74 |
def create_sql_tools(db):
|
75 |
@tool("list_tables")
|
76 |
def list_tables() -> str:
|
|
|
77 |
return ListSQLDatabaseTool(db=db).invoke("")
|
78 |
|
79 |
@tool("tables_schema")
|
80 |
def tables_schema(tables: str) -> str:
|
|
|
81 |
return InfoSQLDatabaseTool(db=db).invoke(tables)
|
82 |
|
83 |
@tool("execute_sql")
|
84 |
def execute_sql(sql_query: str) -> str:
|
|
|
85 |
return QuerySQLDataBaseTool(db=db).invoke(sql_query)
|
86 |
|
87 |
return list_tables, tables_schema, execute_sql
|
88 |
|
89 |
-
# Agent Initialization
|
90 |
def initialize_agents(llm, tools):
|
91 |
list_tables, tables_schema, execute_sql = tools
|
92 |
|
@@ -114,7 +112,6 @@ def initialize_agents(llm, tools):
|
|
114 |
|
115 |
return sql_agent, analyst_agent, writer_agent
|
116 |
|
117 |
-
# Crew and Tasks Setup
|
118 |
def setup_crew(sql_agent, analyst_agent, writer_agent):
|
119 |
extract_task = Task(
|
120 |
description="Extract patents related to the query: {query}.",
|
@@ -143,7 +140,6 @@ def setup_crew(sql_agent, analyst_agent, writer_agent):
|
|
143 |
verbose=True,
|
144 |
)
|
145 |
|
146 |
-
# Execution Flow
|
147 |
if st.session_state.df is not None:
|
148 |
db, temp_dir = initialize_database(st.session_state.df)
|
149 |
tools = create_sql_tools(db)
|
@@ -152,9 +148,12 @@ if st.session_state.df is not None:
|
|
152 |
|
153 |
query = st.text_area("Enter Patent Analysis Query:", placeholder="e.g., 'How many patents related to Machine Learning were filed after 2016?'")
|
154 |
if st.button("Submit Query"):
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
|
|
|
|
|
|
159 |
else:
|
160 |
st.info("Please load a patent dataset to proceed.")
|
|
|
15 |
from datasets import load_dataset
|
16 |
import tempfile
|
17 |
|
18 |
+
st.title("Blah Blah App \U0001F680")
|
19 |
st.write("Analyze datasets using natural language queries.")
|
20 |
|
|
|
21 |
def initialize_llm(model_choice):
    """Build the chat model that matches the label picked in the UI.

    Args:
        model_choice: One of the "Select LLM" radio labels, i.e.
            "llama-3.3-70b" or "GPT-4o".

    Returns:
        A ``ChatGroq`` or ``ChatOpenAI`` instance for the chosen label, or
        ``None`` when the corresponding API key environment variable is unset
        or empty (an error is shown with ``st.error`` in that case) or when
        the label is not one of the two handled here.
    """
    if model_choice == "llama-3.3-70b":
        # Only read the key that the selected provider actually needs.
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key:
            st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
            return None
        # NOTE(review): the "groq/" prefix looks like a provider-routing
        # prefix rather than a bare Groq model id -- confirm the consumer of
        # this object expects it before changing the string.
        return ChatGroq(groq_api_key=groq_api_key, model="groq/llama-3.3-70b-versatile")

    if model_choice == "GPT-4o":
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if not openai_api_key:
            st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
            return None
        return ChatOpenAI(api_key=openai_api_key, model="gpt-4o")

    # Unrecognised label: keep the original fall-through result, but make it
    # explicit so callers can see that None is a possible return value.
    return None
|
35 |
|
36 |
model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)
|
37 |
llm = initialize_llm(model_choice)
|
38 |
|
|
|
39 |
def load_dataset_into_session():
|
40 |
input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
|
41 |
if input_option == "Use Hugging Face Dataset":
|
|
|
59 |
st.session_state.df = None
|
60 |
load_dataset_into_session()
|
61 |
|
|
|
62 |
def initialize_database(df):
|
63 |
temp_dir = tempfile.TemporaryDirectory()
|
64 |
db_path = os.path.join(temp_dir.name, "patent_data.db")
|
|
|
67 |
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
|
68 |
return db, temp_dir
|
69 |
|
|
|
70 |
def create_sql_tools(db):
    """Create the three SQL helper tools, each bound to ``db``.

    Every tool is a thin wrapper: it constructs the matching SQL database
    tool for ``db`` and forwards its argument through ``invoke``.

    Args:
        db: The database object handed to ``ListSQLDatabaseTool``,
            ``InfoSQLDatabaseTool`` and ``QuerySQLDataBaseTool``.

    Returns:
        The tuple ``(list_tables, tables_schema, execute_sql)``, in that
        order; callers unpack it positionally.
    """
    # NOTE(review): the inner docstrings are kept verbatim on purpose --
    # presumably the @tool decorator exposes them as the tool descriptions,
    # which would make them runtime text rather than plain documentation.
    # Confirm against the decorator actually imported by this file.

    @tool("list_tables")
    def list_tables() -> str:
        """List all tables in the patent database."""
        return ListSQLDatabaseTool(db=db).invoke("")

    @tool("tables_schema")
    def tables_schema(tables: str) -> str:
        """Get schema and sample rows for given tables."""
        return InfoSQLDatabaseTool(db=db).invoke(tables)

    @tool("execute_sql")
    def execute_sql(sql_query: str) -> str:
        """Execute a SQL query against the patent database."""
        return QuerySQLDataBaseTool(db=db).invoke(sql_query)

    return list_tables, tables_schema, execute_sql
|
87 |
|
|
|
88 |
def initialize_agents(llm, tools):
|
89 |
list_tables, tables_schema, execute_sql = tools
|
90 |
|
|
|
112 |
|
113 |
return sql_agent, analyst_agent, writer_agent
|
114 |
|
|
|
115 |
def setup_crew(sql_agent, analyst_agent, writer_agent):
|
116 |
extract_task = Task(
|
117 |
description="Extract patents related to the query: {query}.",
|
|
|
140 |
verbose=True,
|
141 |
)
|
142 |
|
|
|
143 |
if st.session_state.df is not None:
|
144 |
db, temp_dir = initialize_database(st.session_state.df)
|
145 |
tools = create_sql_tools(db)
|
|
|
148 |
|
149 |
query = st.text_area("Enter Patent Analysis Query:", placeholder="e.g., 'How many patents related to Machine Learning were filed after 2016?'")
|
150 |
if st.button("Submit Query"):
|
151 |
+
if query.strip():
|
152 |
+
with st.spinner("Processing your query..."):
|
153 |
+
result = crew.kickoff(inputs={"query": query})
|
154 |
+
st.markdown("### 📊 Patent Analysis Report")
|
155 |
+
st.markdown(result)
|
156 |
+
else:
|
157 |
+
st.warning("Please enter a valid query.")
|
158 |
else:
|
159 |
st.info("Please load a patent dataset to proceed.")
|