cboettig committed on
Commit
057cfd5
1 Parent(s): 4dbbaf7
Files changed (3) hide show
  1. app.py +172 -66
  2. minimal-example.py +7 -1
  3. minimal-requirements.txt +1 -1
app.py CHANGED
@@ -1,87 +1,193 @@
1
- import streamlit as st
2
- from pathlib import Path
3
- from langchain.llms.openai import OpenAI
4
- from langchain.agents import create_sql_agent
5
- from langchain.sql_database import SQLDatabase
6
- from langchain.agents.agent_types import AgentType
7
- from langchain_community.callbacks import StreamlitCallbackHandler
8
- from langchain.agents.agent_toolkits import SQLDatabaseToolkit
9
- from sqlalchemy import create_engine
10
- import sqlite3
11
  import os
12
- from langchain_openai import ChatOpenAI
13
- os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
14
- st.set_page_config(page_title="Protected Areas Database Chat", page_icon="🦜", layout="wide")
15
- st.title("🦜 Protected Areas Database Chat")
16
 
17
- #db_uri = "duckdb:///:memory:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  db_uri = "duckdb:///pad.duckdb"
19
- engine = create_engine(db_uri)
20
- from sqlalchemy import text
21
- con = engine.connect()
22
- #con.execute(text("create or replace view agency_name as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-agency-name.parquet'"))
23
- #con.execute(text("create or replace view agency_name as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-agency-name.parquet'"))
24
- #con.execute(text("create or replace view agency_type as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-agency-type.parquet'"))
25
- #con.execute(text("create or replace view category as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-category.parquet'"))
26
- #con.execute(text("create or replace view designation_type as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-desgination-type.parquet'"))
27
- #con.execute(text("create or replace view easement as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-easement.parquet'"))
28
- #con.execute(text("create or replace view fee as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-fee.parquet'"))
29
- #con.execute(text("create or replace view marine as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-marine.parquet'"))
30
- #con.execute(text("create or replace view iucn as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-iucn.parquet'"))
31
- #con.execute(text("create or replace view public_access as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-public-access.parquet'"))
32
- #con.execute(text("create or replace view state_name as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-state-name.parquet'"))
33
- #con.execute(text("create or replace view combined as select * from 'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-combined.parquet'"))
34
-
35
  db = SQLDatabase(engine, view_support=True)
36
- db.get_usable_table_names()
37
 
 
 
 
38
 
 
 
39
 
 
 
40
 
41
- # User inputs
42
- radio_opt = ["US Protected Areas v3"]
43
- selected_opt = st.sidebar.radio(label="Choose suitable option", options=radio_opt)
44
 
45
- llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
46
- agent = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
47
 
48
- def handle_user_input(user_query):
49
- with history:
50
- st.session_state.messages.append({"role": "user", "content": user_query})
51
- #st.chat_message("user").write(user_query)
 
 
 
 
 
 
 
 
52
 
53
- with st.chat_message("assistant"):
54
- st_cb = StreamlitCallbackHandler(st.container())
55
- response = agent.run(user_query, callbacks=[st_cb])
56
- st.session_state.messages.append({"role": "assistant", "content": response})
57
- # st.write(response) # thinking is only shown transiently this way
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- if "messages" not in st.session_state:
61
- st.session_state["messages"] = []
62
 
63
  main = st.container()
 
 
64
  with main:
65
- history = st.container(height=400)
66
- # stores all questions and responses, but not the 'thinking'
67
- with history:
68
- for msg in st.session_state.messages:
69
- st.chat_message(msg["role"]).write(msg["content"])
70
- if user_query := st.chat_input(placeholder="Ask me about US Protected areas!"):
71
- handle_user_input(user_query)
72
-
73
- st.markdown("\n") #add some space for iphone users
74
 
 
 
 
 
75
 
 
 
 
 
76
 
 
77
 
78
- EXAMPLE_PROMPTS = ["What is the total area in each GAP_Sts category in the fee table?",
79
- "List the name of each table in the database",
80
- "How much BLM land (BLM is a Mang_Name in the fee table) is in each GAP_Sts category?",
81
- "Federal agencies are identified as 'FED' in the Mang_Type column in the 'combined' data table. The Mang_Name column indicates the different agencies. The full name of each agency is given in the agency_name table. Which federal agencies, by full name, manage the greatest area of GAP_Sts 1 or 2 land?"]
82
 
83
- with st.sidebar:
84
- with st.container():
85
- st.title("Examples")
86
- for prompt in EXAMPLE_PROMPTS:
87
- st.button(prompt, args=(prompt,), on_click=handle_user_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This example does not use a langchain agent,
2
+ # The langchain sql chain has knowledge of the database, but doesn't interact with it beyond initialization.
3
+ # The output of the sql chain is parsed separately and passed to `duckdb.sql()` by streamlit
4
+
 
 
 
 
 
 
5
  import os
6
+ os.environ["WEBSOCKET_TIMEOUT_MS"] = "300000" # no effect
 
 
 
7
 
8
+ import streamlit as st
9
+ import geopandas as gpd
10
+ from shapely import wkb
11
+ import leafmap.foliumap as leafmap
12
+
13
+ # Helper plotting functions
14
+ import pydeck as pdk
15
+ def deck_map(gdf):
16
+ st.write(
17
+ pdk.Deck(
18
+ map_style="mapbox://styles/mapbox/light-v9",
19
+ initial_view_state={
20
+ "latitude": 35,
21
+ "longitude": -100,
22
+ "zoom": 3,
23
+ "pitch": 50,
24
+ },
25
+ layers=[
26
+ pdk.Layer(
27
+ "GeoJsonLayer",
28
+ gdf,
29
+ pickable=True,
30
+ stroked=True,
31
+ filled=True,
32
+ extruded=True,
33
+ elevation_scale=10,
34
+ get_fill_color=[2, 200, 100],
35
+ get_line_color=[0,0,0],
36
+ line_width_min_pixels=0,
37
+ ),
38
+ ],
39
+ )
40
+ )
41
+
42
+ def leaf_map(gdf):
43
+ m = leafmap.Map(center=[35, -100], zoom=4, layers_control=True)
44
+ m.add_gdf(gdf)
45
+ return m.to_streamlit()
46
+
47
+
48
+ @st.cache_data
49
+ def query_database(response):
50
+ return con.sql(response).to_pandas().head(25)
51
+
52
+ @st.cache_data
53
+ def get_geom(tbl):
54
+ tbl['geometry'] = tbl['geometry'].apply(wkb.loads)
55
+ gdf = gpd.GeoDataFrame(tbl, geometry='geometry')
56
+ return gdf
57
+
58
+
59
+ ## Database connection
60
+ from sqlalchemy import create_engine
61
+ from langchain.sql_database import SQLDatabase
62
  db_uri = "duckdb:///pad.duckdb"
63
+ engine = create_engine(db_uri, connect_args={'read_only': True})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  db = SQLDatabase(engine, view_support=True)
 
65
 
66
+ import ibis
67
+ con = ibis.connect("duckdb://pad.duckdb", read_only=True)
68
+ con.load_extension("spatial")
69
 
70
+ ## ChatGPT Connection
71
+ from langchain_openai import ChatOpenAI
72
 
73
+ # Requires ollama server running locally
74
+ from langchain_community.llms import Ollama
75
 
76
+ ## should we use ChatOllama instead?
77
+ # from langchain_community.llms import ChatOllama
 
78
 
79
+ models = {"chatgpt3.5": ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=st.secrets["OPENAI_API_KEY"])}
 
80
 
81
+ other_models = {
82
+ "chatgpt4": ChatOpenAI(model="gpt-4", temperature=0, api_key=st.secrets["OPENAI_API_KEY"]),
83
+ "duckdb-nsql": Ollama(model="duckdb-nsql", temperature=0),
84
+ "command-r-plus": Ollama(model="command-r-plus", temperature=0),
85
+ "mixtral:8x22b": Ollama(model="mixtral:8x22b", temperature=0),
86
+ "wizardlm2:8x22b": Ollama(model="wizardlm2:8x22b", temperature=0),
87
+ "sqlcoder": Ollama(model="sqlcoder", temperature=0),
88
+ "zephyr": Ollama(model="zephyr", temperature=0),
89
+ "gemma:7b": Ollama(model="gemma:7b", temperature=0),
90
+ "codegemma": Ollama(model="codegemma", temperature=0),
91
+ "llama2": Ollama(model="llama2", temperature=0),
92
+ }
93
 
 
 
 
 
 
94
 
95
+ st.set_page_config(page_title="Protected Areas Database Chat", page_icon="🦜", layout="wide")
96
+ st.title("Protected Areas Database Chat")
97
+
98
+ map_tool = {"leafmap": leaf_map,
99
+ "deckgl": deck_map
100
+ }
101
+
102
+ with st.sidebar:
103
+ choice = st.radio("Select an LLM:", models)
104
+ llm = models[choice]
105
+ map_choice = st.radio("Select mapping tool", map_tool)
106
+ mapper = map_tool[map_choice]
107
+ ## A SQL Chain
108
+ from langchain.chains import create_sql_query_chain
109
+ chain = create_sql_query_chain(llm, db)
110
 
 
 
111
 
112
  main = st.container()
113
+
114
+ ## Does not preserve history
115
  with main:
 
 
 
 
 
 
 
 
 
116
 
117
+ '''
118
+ The Protected Areas Database of the United States (PAD-US) is the official national inventory of
119
+ America’s parks and other protected lands, and is published by the USGS Gap Analysis Project,
120
+ [https://doi.org/10.5066/P9Q9LQ4B.](https://doi.org/10.5066/P9Q9LQ4B).
121
 
122
+ This interactive tool allows users to explore the dataset, as well as a range of biodiversity
123
+ and climate indicators associated with each protected area. These indicators are integrated into
124
+ a single table format shown below. The chatbot assistant can turn natural language queries into
125
+ SQL queries based on the table schema.
126
 
127
+ See our [Protected Areas Explorer](https://huggingface.co/spaces/boettiger-lab/pad-us) for a companion non-chat-based tool.
128
 
129
+ ##### Example Queries returning summary tables
 
 
 
130
 
131
+ - What is the percent area in each gap code as a fraction of the total protected area?
132
+ - The manager_type column indicates whether a manager is federal, state, local, private, or NGO.
133
+ the manager_name column indicates the responsible agency (National Park Service, Bureau of Land Management,
134
+ etc) in the case of federal manager types. Which of the federal managers manage the most land in
135
+ gap_code 1 or 2, as a fraction of the total area?
136
+
137
+ When queries refer to specific managed areas, the chatbot can show those areas on an interactive map.
138
+ Due to software limitations, these maps will show no more than 25 polygons, even if more areas match the
139
+ requested search. The chatbot sometimes requires help identifying the right columns. In order to create
140
+ a map, the SQL query must also return the geometry column. Consider the following examples:
141
+
142
+ ##### Example queries returning maps + tables
143
+
144
+ - Show me all the national monuments (designation_type) in Utah. Include the geometry column
145
+ - Show examples of Bureau of Land Management (manager_name) with the highest species richness? Include the geometry column
146
+ - Which site has the overall highest range-size-rarity? Include the geometry column, manager_name, and IUCN category.
147
+
148
+ '''
149
+
150
+ st.markdown("## 🦜 Chatbot:")
151
+ chatbox = st.container()
152
+ with chatbox:
153
+ if prompt := st.chat_input(key="chain"):
154
+ st.chat_message("user").write(prompt)
155
+ with st.chat_message("assistant"):
156
+ response = chain.invoke({"question": prompt})
157
+ st.write(response)
158
+ tbl = query_database(response)
159
+ if 'geometry' in tbl:
160
+ gdf = get_geom(tbl)
161
+ mapper(gdf)
162
+ n = len(gdf)
163
+ st.write(f"matching features: {n}")
164
+ st.dataframe(tbl)
165
+
166
+
167
+ st.divider()
168
+
169
+ with st.container():
170
+ st.text("Database schema (top 3 rows)")
171
+ tbl = tbl = query_database("select * from pad limit 3")
172
+ st.dataframe(tbl)
173
+
174
+
175
+ st.divider()
176
+
177
+ '''
178
+ Experimental prototype.
179
+
180
+ - Author: [Carl Boettiger](https://carlboettiger.info)
181
+ - For data sources and processing, see: https://beta.source.coop/repositories/cboettig/pad-us-3/description/
182
+
183
+
184
+ '''
185
+
186
+ # duckdb_sql fails but chatgpt3.5 succeeds with a query like:
187
+ # use the st_area function and st_GeomFromWKB functions to compute the area of the Shape column in the fee table, and then use that to compute the total area under each GAP_Sts category
188
+
189
+ # For most queries, duckdb_sql does much better than alternative open models though
190
+
191
+ # Federal agencies are identified as 'FED' in the Mang_Type column in the 'combined' data table. The Mang_Name column indicates the different agencies. Which federal agencies manage the greatest area of GAP_Sts 1 or 2 land?
192
+
193
+ # Federal agencies are identified as 'FED' in the Mang_Type column in the table named "fee". The Mang_Name column indicates the different agencies. List which managers manage the largest total areas that identified as GAP_Sts '1' or '2' ?
minimal-example.py CHANGED
@@ -29,7 +29,10 @@ from langchain_community.llms import Ollama
29
 
30
  models = {"duckdb-nsql": Ollama(model="duckdb-nsql", temperature=0),
31
  "sqlcoder": Ollama(model="sqlcoder", temperature=0),
32
- "gemma": Ollama(model="gemma", temperature=0),
 
 
 
33
  "chatgpt3.5": chatgpt_llm,
34
  "chatgpt4": chatgpt4_llm}
35
  with st.sidebar:
@@ -57,5 +60,8 @@ if prompt := st.chat_input():
57
  # use the st_area function and st_GeomFromWKB functions to compute the area of the Shape column in the fee table, and then use that to compute the total area under each GAP_Sts category
58
 
59
 
 
 
60
  # Federal agencies are identified as 'FED' in the Mang_Type column in the 'combined' data table. The Mang_Name column indicates the different agencies. Which federal agencies manage the greatest area of GAP_Sts 1 or 2 land?
61
 
 
 
29
 
30
  models = {"duckdb-nsql": Ollama(model="duckdb-nsql", temperature=0),
31
  "sqlcoder": Ollama(model="sqlcoder", temperature=0),
32
+ "zephyr": Ollama(model="zephyr", temperature=0),
33
+ "gemma:7b": Ollama(model="gemma:7b", temperature=0),
34
+ "codegemma": Ollama(model="codegemma", temperature=0),
35
+ "llama2:70b": Ollama(model="llama2:70b", temperature=0),
36
  "chatgpt3.5": chatgpt_llm,
37
  "chatgpt4": chatgpt4_llm}
38
  with st.sidebar:
 
60
  # use the st_area function and st_GeomFromWKB functions to compute the area of the Shape column in the fee table, and then use that to compute the total area under each GAP_Sts category
61
 
62
 
63
+ # For most queries, duckdb_sql does much better than alternative open models though
64
+
65
  # Federal agencies are identified as 'FED' in the Mang_Type column in the 'combined' data table. The Mang_Name column indicates the different agencies. Which federal agencies manage the greatest area of GAP_Sts 1 or 2 land?
66
 
67
+ # Federal agencies are identified as 'FED' in the Mang_Type column in the table named "fee". The Mang_Name column indicates the different agencies. List which managers manage the largest total areas that identified as GAP_Sts '1' or '2' ?
minimal-requirements.txt CHANGED
@@ -4,4 +4,4 @@ langchain
4
  langchain-community
5
  langchain-openai
6
  SQLAlchemy==1.4.52
7
- streamlit
 
4
  langchain-community
5
  langchain-openai
6
  SQLAlchemy==1.4.52
7
+ streamlit