Spaces:
Runtime error
Runtime error
Upload 22 files
Browse files- app.py +555 -0
- app_data.pickle +3 -0
- crashes.csv +0 -0
- cyclists.csv +0 -0
- lib/.DS_Store +0 -0
- lib/.ipynb_checkpoints/__init__-checkpoint.py +0 -0
- lib/.ipynb_checkpoints/get_data-checkpoint.py +74 -0
- lib/.ipynb_checkpoints/study_classif-checkpoint.py +787 -0
- lib/.ipynb_checkpoints/transform_data-checkpoint.py +83 -0
- lib/.ipynb_checkpoints/vis_data-checkpoint.py +287 -0
- lib/__init__.py +0 -0
- lib/__pycache__/__init__.cpython-310.pyc +0 -0
- lib/__pycache__/get_data.cpython-310.pyc +0 -0
- lib/__pycache__/study_class.cpython-310.pyc +0 -0
- lib/__pycache__/study_classif.cpython-310.pyc +0 -0
- lib/__pycache__/transform_data.cpython-310.pyc +0 -0
- lib/__pycache__/vis_data.cpython-310.pyc +0 -0
- lib/study_classif.py +787 -0
- lib/transform_data.py +83 -0
- lib/vis_data.py +287 -0
- requirements.txt +77 -0
- study.pkl +3 -0
app.py
ADDED
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import streamlit as st
|
4 |
+
import plotly.express as px
|
5 |
+
from scipy import stats
|
6 |
+
import pickle
|
7 |
+
import shap
|
8 |
+
import lightgbm as lgb
|
9 |
+
from lightgbm import LGBMClassifier
|
10 |
+
# from sklearn.ensemble import HistGradientBoostingClassifier
|
11 |
+
|
12 |
+
########################
|
13 |
+
### Helper functions ###
|
14 |
+
########################
|
15 |
+
|
16 |
+
@st.cache_data
def get_data(filename):
    """
    Load a CSV file into a pandas DataFrame.

    Cached by Streamlit so repeated reruns of the script do not
    re-read the file from disk.

    Parameters
    ----------
    filename : str
        Path of the CSV file to read.

    Returns
    -------
    pandas.DataFrame
    """
    frame = pd.read_csv(filename)
    return frame
|
22 |
+
|
23 |
+
|
24 |
+
#############
### Setup ###
#############

# Load dataframes
crashes = get_data('crashes.csv')    # crash-level samples
cyclists = get_data('cyclists.csv')  # cyclist-level samples

# Load in prepared labeling data for app components.
# NOTE: the unpickled tuple is positional — the order below must match
# the order in which app_data.pickle was written.
with open('app_data.pickle', 'rb') as file:
    period_data, cohort_data, time_cat_data, time_bin_data,\
    geo_data, county_data, feature_names, ord_features,\
    cat_features,flag_features,model_cat_data,\
    model_bin_data,veh_data = pickle.load(file)

# Master list of user-selectable features, sorted case-insensitively
# by their human-readable display names.
features = cat_features+flag_features+ord_features
features.sort(key=lambda x:feature_names[x].lower())
41 |
+
|
42 |
+
# Load trained classifier study object
@st.cache_resource(show_spinner=False)
def load_study():
    """
    Unpickle and return the trained classifier pipeline study object.

    Cached as a Streamlit resource so the pickle is only loaded once
    per session.
    """
    with open('study.pkl', 'rb') as f:
        return pickle.load(f)
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
################################
|
56 |
+
### Initialize app structure ###
|
57 |
+
################################
|
58 |
+
|
59 |
+
st.header('BikeSaferPA: understanding cyclist outcomes')
|
60 |
+
tabs = st.tabs([
|
61 |
+
'Welcome',
|
62 |
+
'Crashes over time',
|
63 |
+
'Mapping crashes',
|
64 |
+
'Feature distributions',
|
65 |
+
'BikeSaferPA predictions',
|
66 |
+
])
|
67 |
+
with tabs[0]:
|
68 |
+
intro_container = st.container()
|
69 |
+
with tabs[1]:
|
70 |
+
time_intro_container = st.container()
|
71 |
+
time_settings_container = st.container()
|
72 |
+
time_plot_container = st.container()
|
73 |
+
with tabs[2]:
|
74 |
+
map_intro_container = st.container()
|
75 |
+
map_settings_container = st.container()
|
76 |
+
map_plot_container = st.container()
|
77 |
+
with tabs[3]:
|
78 |
+
feature_intro_container = st.container()
|
79 |
+
feature_settings_container = st.container()
|
80 |
+
feature_plot_container = st.container()
|
81 |
+
with tabs[4]:
|
82 |
+
model_intro_container = st.container()
|
83 |
+
model_settings_container = st.container()
|
84 |
+
model_result_container = st.container()
|
85 |
+
model_shap_container = st.container()
|
86 |
+
|
87 |
+
############################
|
88 |
+
### Populate welcome tab ###
|
89 |
+
############################
|
90 |
+
|
91 |
+
# Welcome tab: static introductory text.
# Fix: corrected the misspelling "publically" -> "publicly" in the
# user-facing markdown.
with intro_container:
    st.markdown(
        """
        This app provides a suite of tools to accompany Eamonn Tweedy's [BikeSaferPA project](https://github.com/e-tweedy/BikeSaferPA). These tools allow the user to:
        - Visualize data related to crashes involving bicycles in Pennsylvania during the years 2002-2021, which was collected from a publicly available [PENNDOT crash dataset](https://pennshare.maps.arcgis.com/apps/webappviewer/index.html?id=8fdbf046e36e41649bbfd9d7dd7c7e7e).
        - Experiment with the BikeSaferPA model, which was trained on this cyclist crash data and designed to predict severity outcomes for cyclists based on crash data.

        Navigate the tabs using the menu at the top to try them out.
        """)
|
100 |
+
|
101 |
+
######################################
|
102 |
+
### Populate crashes over time tab ###
|
103 |
+
######################################
|
104 |
+
|
105 |
+
### Intro text ###
|
106 |
+
|
107 |
+
# Crashes-over-time tab: static introductory text.
# Fix: corrected the misspelling "Philadelpha" -> "Philadelphia" in the
# user-facing markdown.
with time_intro_container:
    st.subheader('Visualizing bicycle crashes in PA over time')

    st.markdown("""
    This tool provides plots of cyclist crash counts by year, month of the year, day of the week, or hour of the day and can stratify the counts by various crash features.

    You also have the option to restrict to Philadelphia county only, or the PA counties in the greater Philadelphia area (Bucks, Chester, Delaware, Montgomery, and Philadelphia).

    Expand the toolbox below to choose plot options.
    """)
|
117 |
+
|
118 |
+
### User input - settings for plot ###
|
119 |
+
|
120 |
+
with time_settings_container:
|
121 |
+
# Expander containing plot option user input
|
122 |
+
with st.expander('Click here to expand or collapse plot options menu'):
|
123 |
+
col1,col2 = st.columns([0.4,0.6])
|
124 |
+
with col1:
|
125 |
+
# Geographic restriction selectbox
|
126 |
+
geo = st.selectbox(
|
127 |
+
'Geographic scope:',
|
128 |
+
list(geo_data.keys()),index=0,
|
129 |
+
format_func = lambda x:geo_data[x][0],
|
130 |
+
key = 'time_geo_select',
|
131 |
+
)
|
132 |
+
# Time period selectbox
|
133 |
+
period = st.selectbox(
|
134 |
+
'Time period:',
|
135 |
+
list(period_data.keys()),index=3,
|
136 |
+
format_func = lambda x:period_data[x][0],
|
137 |
+
key = 'time_period_select',
|
138 |
+
)
|
139 |
+
|
140 |
+
with col2:
|
141 |
+
# Cyclist cohort selectbox
|
142 |
+
cohort = st.selectbox(
|
143 |
+
'Crash severity:',
|
144 |
+
list(cohort_data.keys()),index=0,
|
145 |
+
format_func = lambda x:cohort_data[x],
|
146 |
+
key = 'time_cohort_select',
|
147 |
+
)
|
148 |
+
# Category stratification selectbox
|
149 |
+
stratify = st.selectbox('Stratify crashes by:',
|
150 |
+
['no']+list(time_cat_data.keys()),index=0,
|
151 |
+
key = 'time_cat_stratify_select',
|
152 |
+
format_func = lambda x:time_cat_data[x][0]\
|
153 |
+
if x!='no' else 'do not stratify',
|
154 |
+
)
|
155 |
+
st.markdown('Restrict to crashes containing the following factor(s):')
|
156 |
+
title_add = ''
|
157 |
+
|
158 |
+
cols = st.columns(len(time_bin_data))
|
159 |
+
# Columns of binary feature checkboxes
|
160 |
+
for k,col in enumerate(cols):
|
161 |
+
with col:
|
162 |
+
for feat in time_bin_data[k]:
|
163 |
+
# make checkbox
|
164 |
+
time_bin_data[k][feat][2]=st.checkbox(time_bin_data[k][feat][0],key=f'time_{feat}')
|
165 |
+
# if checked, filter samples and add feature to plot title addendum
|
166 |
+
if time_bin_data[k][feat][2]:
|
167 |
+
crashes = crashes[crashes[time_bin_data[k][feat][1]]==1]
|
168 |
+
title_add+= ', '+time_bin_data[k][feat][0].split('one ')[-1]
|
169 |
+
|
170 |
+
### Post-process user-selected setting data ###
|
171 |
+
|
172 |
+
# Geographic restriction
|
173 |
+
if geo != 'statewide':
|
174 |
+
crashes = crashes[crashes.COUNTY.isin(geo_data[geo][1])]
|
175 |
+
# Relegate rare categories to 'other' for plot readability
|
176 |
+
if stratify=='int_type':
|
177 |
+
crashes['INTERSECT_TYPE']=crashes['INTERSECT_TYPE']\
|
178 |
+
.replace({cat:'other' for cat in crashes.INTERSECT_TYPE.value_counts().index[3:]})
|
179 |
+
if stratify=='coll_type':
|
180 |
+
crashes['COLLISION_TYPE']=crashes['COLLISION_TYPE']\
|
181 |
+
.replace({cat:'other' for cat in crashes.COLLISION_TYPE.value_counts().index[6:]})
|
182 |
+
if stratify=='weather':
|
183 |
+
crashes['WEATHER']=crashes['WEATHER']\
|
184 |
+
.replace({cat:'other' for cat in crashes.WEATHER.value_counts().index[5:]})
|
185 |
+
if stratify=='tcd':
|
186 |
+
crashes['TCD_TYPE']=crashes['TCD_TYPE']\
|
187 |
+
.replace({cat:'other' for cat in crashes.TCD_TYPE.value_counts().index[3:]})
|
188 |
+
crashes=crashes.dropna(subset=period_data[period][1])
|
189 |
+
|
190 |
+
# Order categories in descending order by frequency
|
191 |
+
category_orders = {time_cat_data[cat][1]:list(crashes[time_cat_data[cat][1]].value_counts().index) for cat in time_cat_data}
|
192 |
+
|
193 |
+
# Define cohort
|
194 |
+
if cohort == 'inj':
|
195 |
+
crashes = crashes[crashes.BICYCLE_SUSP_SERIOUS_INJ_COUNT > 0]
|
196 |
+
elif cohort == 'fat':
|
197 |
+
crashes = crashes[crashes.BICYCLE_DEATH_COUNT > 0]
|
198 |
+
|
199 |
+
# Replace day,month numbers with string labels
|
200 |
+
if period in ['day','month']:
|
201 |
+
crashes[period_data[period][1]] = crashes[period_data[period][1]].apply(lambda x:period_data[period][2][x-1])
|
202 |
+
|
203 |
+
# Plot title addendum
|
204 |
+
if len(title_add)>0:
|
205 |
+
title_add = '<br>with'+title_add.lstrip(',')
|
206 |
+
|
207 |
+
# Category stratification plot settings
|
208 |
+
if stratify=='no':
|
209 |
+
color,legend_title = None,None
|
210 |
+
else:
|
211 |
+
color,legend_title=time_cat_data[stratify][1],time_cat_data[stratify][2]
|
212 |
+
title_add += f'<br>stratified {time_cat_data[stratify][0]}'
|
213 |
+
|
214 |
+
### Build and display plot ###
|
215 |
+
|
216 |
+
with time_plot_container:
|
217 |
+
# Plot samples if any, else report no samples remain
|
218 |
+
if crashes.shape[0]>0:
|
219 |
+
fig = px.histogram(crashes,
|
220 |
+
x=period_data[period][1],
|
221 |
+
color=color,
|
222 |
+
nbins=len(period_data[period][2]),
|
223 |
+
title=f'PA bicycle crashes 2002-2021 by {period_data[period][0]} - {cohort_data[cohort]}'+title_add,
|
224 |
+
category_orders = category_orders,
|
225 |
+
)
|
226 |
+
fig.update_layout(bargap=0.2,
|
227 |
+
xaxis_title=period_data[period][0],
|
228 |
+
legend_title_text=legend_title,
|
229 |
+
)
|
230 |
+
fig.update_xaxes(categoryorder="array",
|
231 |
+
categoryarray=period_data[period][2],
|
232 |
+
dtick=1,
|
233 |
+
)
|
234 |
+
st.plotly_chart(fig,use_container_width=True)
|
235 |
+
else:
|
236 |
+
st.markdown('#### No samples meet these criteria. Please remove some factors.')
|
237 |
+
|
238 |
+
####################################
|
239 |
+
### Populate mapping crashes tab ###
|
240 |
+
####################################
|
241 |
+
|
242 |
+
### Intro text ###
|
243 |
+
|
244 |
+
with map_intro_container:
|
245 |
+
st.subheader('Mapping bicycle crashes in PA')
|
246 |
+
|
247 |
+
st.markdown("""
|
248 |
+
This tool provides interactive maps of crash events, either statewide or in one of the more populous counties. Crash event dots are color-coded based on whether the crash involved serious cyclist injury, cyclist fatality, or neither.
|
249 |
+
|
250 |
+
Expand the menu below to adjust map options.
|
251 |
+
""")
|
252 |
+
|
253 |
+
### User input - settings for map plot ###
|
254 |
+
|
255 |
+
with map_settings_container:
|
256 |
+
# Expander containing plot option user input
|
257 |
+
with st.expander('Click here to expand or collapse map options menu'):
|
258 |
+
# Locale selectbox
|
259 |
+
geo = st.selectbox(
|
260 |
+
'Select either statewide or a particular county to plot:',
|
261 |
+
['Statewide']+[county+' County' for county in county_data],
|
262 |
+
key = 'map_geo_select',
|
263 |
+
)
|
264 |
+
# Animation status selectbox
|
265 |
+
animate = st.selectbox(
|
266 |
+
'Select how to animate the map:',
|
267 |
+
['do not animate','by year','by month'],
|
268 |
+
key = 'map_animate_select',
|
269 |
+
)
|
270 |
+
|
271 |
+
### Post-process user-selected setting data ###
|
272 |
+
|
273 |
+
if geo == 'Statewide':
|
274 |
+
county = None
|
275 |
+
else:
|
276 |
+
geo = geo.split(' ')[0]
|
277 |
+
county = (county_data[geo],geo)
|
278 |
+
color_dots=True
|
279 |
+
if animate == 'do not animate':
|
280 |
+
animate = False
|
281 |
+
animate_by=None
|
282 |
+
else:
|
283 |
+
animate_by = animate.split(' ')[1]
|
284 |
+
animate = True
|
285 |
+
# If county is not None and animating, check whether first frame has all
|
286 |
+
# injury/fatality status categories. If not, then we will not color dots
|
287 |
+
# by injury/fatality status.
|
288 |
+
# This is to account for bug/feature in plotly 'animation_frame' and 'color' functionality
|
289 |
+
# which yields unexpected results when all color categories not present in first frame
|
290 |
+
# see e.g. https://github.com/plotly/plotly.py/issues/2259
|
291 |
+
|
292 |
+
if county is not None:
|
293 |
+
if animate_by == 'year':
|
294 |
+
color_dots = len(crashes.query('COUNTY==@county[0] and CRASH_YEAR==2002')\
|
295 |
+
.BICYCLE_DEATH_COUNT.unique())+\
|
296 |
+
len(crashes.query('COUNTY==@county[0] and CRASH_YEAR==2002')\
|
297 |
+
.BICYCLE_SUSP_SERIOUS_INJ_COUNT.unique()) > 3
|
298 |
+
else:
|
299 |
+
color_dots = len(crashes.query('COUNTY==@county[0] and CRASH_YEAR==2002 and CRASH_MONTH==1')\
|
300 |
+
.BICYCLE_DEATH_COUNT.unique())+\
|
301 |
+
len(crashes.query('COUNTY==@county[0] and CRASH_YEAR==2002 and CRASH_MONTH==1')\
|
302 |
+
.BICYCLE_SUSP_SERIOUS_INJ_COUNT.unique()) > 3
|
303 |
+
if color_dots==False:
|
304 |
+
st.markdown("""
|
305 |
+
**Warning:** color-coding by injury/death status is disabled; this feature gives unexpected results
|
306 |
+
when not all classes appear in the first animation frame due to bug/feature in Plotly animate functionality.
|
307 |
+
Injury/death status is still visible in hover-text box.
|
308 |
+
""")
|
309 |
+
|
310 |
+
### Build and display map plot ###
|
311 |
+
|
312 |
+
from lib.vis_data import plot_map
|
313 |
+
|
314 |
+
with map_plot_container:
|
315 |
+
fig = plot_map(
|
316 |
+
df=crashes,county=county,animate=animate,
|
317 |
+
color_dots=color_dots,animate_by=animate_by,
|
318 |
+
show_fig=False,return_fig=True,
|
319 |
+
)
|
320 |
+
st.plotly_chart(fig,use_container_width=True)
|
321 |
+
|
322 |
+
##########################################
|
323 |
+
### Populate feature distributions tab ###
|
324 |
+
##########################################
|
325 |
+
|
326 |
+
### Intro text ###
|
327 |
+
|
328 |
+
# Feature-distributions tab: static introductory text.
# Fixes: corrected "Philadelpha" -> "Philadelphia" and the garbled phrase
# "show its distribution of its values" -> "show the distribution of its
# values" in the user-facing markdown.
with feature_intro_container:
    st.subheader('Visualizing crash feature distributions')

    st.markdown("""
    The tools on this page will demonstrate how distributions of values of various crash and cyclist features vary between two groups:
    - all cyclists involved in crashes, and
    - those cyclists who suffered serious injury or fatality

    Expand the following menu to choose a feature, and the graph will show the distribution of its values (via percentages) over the two groups. Again you may restrict to Philadelphia county only, or the PA counties in the greater Philadelphia area (Bucks, Chester, Delaware, Montgomery, and Philadelphia).

    Pay particular attention to feature values which become more or less prevalent among cyclists suffering serious injury or death - for instance, 6.2% of all cyclists statewide were involved in a head-on collision, whereas 11.8% of those with serious injury or fatality were in a head-on collision.
    """)
|
340 |
+
|
341 |
+
### User input - settings for plot ###
|
342 |
+
|
343 |
+
with feature_settings_container:
|
344 |
+
# Expander containing plot option user input
|
345 |
+
with st.expander('Click here to expand or collapse feature selection menu'):
|
346 |
+
# Geographic restriction selectbox
|
347 |
+
geo = st.selectbox(
|
348 |
+
'Geographic scope:',
|
349 |
+
list(geo_data.keys()),index=0,
|
350 |
+
format_func = lambda x:geo_data[x][0],
|
351 |
+
key = 'feature_geo_select',
|
352 |
+
)
|
353 |
+
# Feature selectbox
|
354 |
+
feature = st.selectbox('Show distributions of this feature:',
|
355 |
+
features,format_func = lambda x:feature_names[x],
|
356 |
+
key = 'feature_select',
|
357 |
+
)
|
358 |
+
|
359 |
+
### Post-process user-selected settings data ###
|
360 |
+
|
361 |
+
from lib.vis_data import feat_perc,feat_perc_bar
|
362 |
+
# Geographic restriction
|
363 |
+
if geo != 'statewide':
|
364 |
+
cyclists = cyclists[cyclists.COUNTY.isin(geo_data[geo][1])]
|
365 |
+
|
366 |
+
# Recast binary and day of week data
|
367 |
+
if feature not in ord_features:
|
368 |
+
cyclists[feature]=cyclists[feature].replace({1:'yes',0:'no'})
|
369 |
+
if feature == 'DAY_OF_WEEK':
|
370 |
+
cyclists[feature]=cyclists[feature].astype(str)
|
371 |
+
|
372 |
+
### Build and display plot ###
|
373 |
+
|
374 |
+
# Feature-distributions tab: build and display the percentage bar plot.
with feature_plot_container:

    # Generate plot. Ordinal features keep their natural order;
    # all other features are sorted (by frequency) in the plot.
    sort = feature not in ord_features
    fig = feat_perc_bar(
        feature, cyclists, feat_name=feature_names[feature],
        return_fig=True, show_fig=False, sort=sort,
    )

    # Adjust some colorscale and display settings
    if feature == 'SPEED_LIMIT':
        fig.update_coloraxes(colorscale='YlOrRd', cmid=35)
    if feature == 'HOUR_OF_DAY':
        fig.update_coloraxes(colorscale='balance')
    if feature == 'DAY_OF_WEEK':
        # Relabel the 7 traces with day abbreviations, Sunday first
        # (PENNDOT codes DAY_OF_WEEK starting at Sunday).
        # BUG FIX: the original built this list from `cal.day_abbr`, but
        # the `calendar` module is never imported in this file, so
        # selecting DAY_OF_WEEK raised a NameError; an explicit list
        # produces the identical labels without the missing import.
        days = ['Sun', 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat']
        for idx, day in enumerate(days):
            fig.data[idx].name = day
            fig.data[idx].hovertemplate = day

    # Display plot
    st.plotly_chart(fig, use_container_width=True)

    st.markdown('See [this Jupyter notebook](https://e-tweedy.github.io/2_BikeSaferPA_vis.html) for an in-depth data exploration and visualization process.')
|
398 |
+
|
399 |
+
######################################
|
400 |
+
### Populate model predictions tab ###
|
401 |
+
######################################
|
402 |
+
|
403 |
+
from lib.study_classif import ClassifierStudy
|
404 |
+
|
405 |
+
### Intro text ###
|
406 |
+
|
407 |
+
with model_intro_container:
|
408 |
+
st.subheader('Predicting cyclist outcome with BikeSaferPA')
|
409 |
+
|
410 |
+
st.markdown("""
|
411 |
+
An instance of the BikeSaferPA predictive model has been trained in advance on all cyclist samples in the PENNDOT dataset. This model is a gradient-boosted decision tree classifier model, and the model selection and evaluation process is covered in detail in [this Jupyter notebook](https://e-tweedy.github.io/3_BikeSaferPA_models.html).
|
412 |
+
|
413 |
+
The purpose of this tool is to allow the user to simulate a model prediction on a hypothetical sample, and then explain the model's prediction using SHAP values.
|
414 |
+
|
415 |
+
Expand the following sections to adjust the factors in a hypothetical cyclist crash, and the model will provide a predicted probability that the cyclist involved suffers serious injury or fatality. You'll find that some factors influence the prediction significantly, and others very little.
|
416 |
+
""")
|
417 |
+
|
418 |
+
### User inputs for model prediction ###
|
419 |
+
|
420 |
+
# Load the trained classifier study object
|
421 |
+
study = load_study()
|
422 |
+
|
423 |
+
# Initialize input sample. User inputs will update values.
|
424 |
+
sample = pd.DataFrame(columns = study.pipe['col'].feature_names_in_)
|
425 |
+
|
426 |
+
with model_settings_container:
|
427 |
+
# Expander for numerical inputs
|
428 |
+
with st.expander('Click here to expand or collapse numerical features'):
|
429 |
+
cols = st.columns(3)
|
430 |
+
with cols[0]:
|
431 |
+
sample.loc[0,'AGE'] = st.number_input('Cyclist age (yrs):',
|
432 |
+
min_value=0,step=1,value=30)
|
433 |
+
sample.loc[0,'SPEED_LIMIT'] = st.number_input('Posted speed limit (mph):',
|
434 |
+
min_value=0,max_value=100,step=5,value=25)
|
435 |
+
sample.loc[0,'CRASH_YEAR'] = st.number_input('Year crash took place:',
|
436 |
+
min_value=2002,max_value=2023,step=1)
|
437 |
+
with cols[1]:
|
438 |
+
for k in [0,1,2]:
|
439 |
+
sample.loc[0,f'{veh_data[k][1]}_COUNT']=st.number_input(
|
440 |
+
f'# {veh_data[k][0]}s involved:',
|
441 |
+
min_value=0,step=1,max_value=3
|
442 |
+
)
|
443 |
+
with cols[2]:
|
444 |
+
for k in [3,4]:
|
445 |
+
sample.loc[0,f'{veh_data[k][1]}_COUNT']=st.number_input(
|
446 |
+
f'# {veh_data[k][0]}s involved:',
|
447 |
+
min_value=0,step=1,max_value=3
|
448 |
+
)
|
449 |
+
# Expander for categorical inputs
|
450 |
+
with st.expander('Click here to expand or collapse categorical features'):
|
451 |
+
cols = st.columns(3)
|
452 |
+
with cols[0]:
|
453 |
+
sample.loc[0,'ILLUMINATION'] = st.selectbox(
|
454 |
+
'Illumination status:',
|
455 |
+
model_cat_data['ILLUMINATION'],
|
456 |
+
format_func= lambda x:x.replace('_',' and '),
|
457 |
+
)
|
458 |
+
sample.loc[0,'URBAN_RURAL'] = st.selectbox(
|
459 |
+
'Collision setting:',
|
460 |
+
model_cat_data['URBAN_RURAL'],
|
461 |
+
)
|
462 |
+
sample.loc[0,'TCD_TYPE'] = st.selectbox(
|
463 |
+
'Traffic control device:',
|
464 |
+
model_cat_data['TCD_TYPE'],
|
465 |
+
format_func= lambda x:x.replace('_',' '),
|
466 |
+
)
|
467 |
+
with cols[1]:
|
468 |
+
sample.loc[0,'VEH_ROLE'] = st.selectbox(
|
469 |
+
'Bicycle role in collision:',
|
470 |
+
model_cat_data['VEH_ROLE'],
|
471 |
+
format_func= lambda x:x.replace('_',' and '),
|
472 |
+
)
|
473 |
+
sample.loc[0,'IMPACT_SIDE'] = st.selectbox(
|
474 |
+
'Bicycle impact side:',
|
475 |
+
model_cat_data['IMPACT_SIDE'],
|
476 |
+
format_func= lambda x:x.replace('_',' '),
|
477 |
+
)
|
478 |
+
sample.loc[0,'GRADE'] = st.selectbox(
|
479 |
+
'Roadway grade:',
|
480 |
+
model_cat_data['GRADE'],
|
481 |
+
format_func= lambda x:x.replace('_',' '),
|
482 |
+
)
|
483 |
+
with cols[2]:
|
484 |
+
sample.loc[0,'RESTRAINT_HELMET'] = st.selectbox(
|
485 |
+
'Cyclist helmet status:',
|
486 |
+
model_cat_data['RESTRAINT_HELMET'],
|
487 |
+
format_func= lambda x:x.replace('_',' ')\
|
488 |
+
.replace('restraint','helmet'),
|
489 |
+
)
|
490 |
+
sample.loc[0,'COLLISION_TYPE'] = st.selectbox(
|
491 |
+
'Collision type:',
|
492 |
+
model_cat_data['COLLISION_TYPE'],
|
493 |
+
format_func= lambda x:x.replace('_',' ')\
|
494 |
+
.replace('dir','direction'),
|
495 |
+
)
|
496 |
+
sample.loc[0,'FEMALE'] = st.selectbox(
|
497 |
+
'Cyclist sex:*',[1,0],
|
498 |
+
format_func = lambda x:'F' if x==1 else 'M',
|
499 |
+
)
|
500 |
+
st.markdown('*Note: the PENNDOT dataset only has a binary sex feature.')
|
501 |
+
|
502 |
+
# Expander for binary inputs
|
503 |
+
with st.expander('Click here to expand or collapse binary features'):
|
504 |
+
cols = st.columns(len(model_bin_data))
|
505 |
+
for k,col in enumerate(cols):
|
506 |
+
with col:
|
507 |
+
for feat in model_bin_data[k]:
|
508 |
+
sample.loc[0,model_bin_data[k][feat][1]]=int(st.checkbox(model_bin_data[k][feat][0],
|
509 |
+
key=f'model_{feat}'))
|
510 |
+
|
511 |
+
### Model prediction and reporting result ###
|
512 |
+
|
513 |
+
with model_result_container:
|
514 |
+
# Fill these columns arbitrarily - they won't affect inference
|
515 |
+
# COUNTY, MUNICIPALITY, HOUR_OF_DAY, CRASH_MONTH used in pipeline for NaN imputation
|
516 |
+
# This version of model doesn't use temporal features as we set cyc_method=None
|
517 |
+
for feat in ['HOUR_OF_DAY','DAY_OF_WEEK','CRASH_MONTH','COUNTY','MUNICIPALITY']:
|
518 |
+
sample.loc[0,feat]=1
|
519 |
+
|
520 |
+
# Predict and report result
|
521 |
+
study.predict_proba_pipeline(X_test=sample)
|
522 |
+
|
523 |
+
st.write(f'**BikeSaferPA predicts a :red[{100*float(study.y_predict_proba):.2f}%] probability that a cyclist suffers serious injury or fatality under these conditions.**')
|
524 |
+
|
525 |
+
### SHAP values ####
|
526 |
+
|
527 |
+
with model_shap_container:
|
528 |
+
st.subheader('SHAP analysis for this hypothetical prediction')
|
529 |
+
|
530 |
+
st.markdown("""
|
531 |
+
SHAP (SHapley Additive exPlainer) values provide an excellent method for assessing how various input features influence a model's predictions. One significant advantage is that SHAP values are 'model agnostic' - they effectively explain the predictions made by many different types of machine learning classifiers.
|
532 |
+
|
533 |
+
The following 'force plot' shows the influence of each feature's SHAP value on the model's predicted probability that the cyclist suffers serious injury or fatality. A feature with a positive (resp. negative) SHAP value indicates that the feature's value pushes the predicted probability higher (resp. lower), which in the force plot corresponds to a push to the right (resp. left).
|
534 |
+
|
535 |
+
The force plot will update as you adjust input features in the menu above.
|
536 |
+
""")
|
537 |
+
|
538 |
+
# SHAP will just explain classifier, so need transformed X_train and X_test
|
539 |
+
pipe = study.pipe_fitted
|
540 |
+
sample_trans = pipe[:-1].transform(sample)
|
541 |
+
|
542 |
+
# # Need masker for linear model
|
543 |
+
# masker = shap.maskers.Independent(data=X_train_trans)
|
544 |
+
|
545 |
+
# Initialize explainer and compute and store SHAP values as an explainer object
|
546 |
+
explainer = shap.TreeExplainer(pipe[-1], feature_names = pipe['col'].get_feature_names_out())
|
547 |
+
shap_values = explainer(sample_trans)
|
548 |
+
sample_trans = pd.DataFrame(sample_trans,columns=pipe['col'].get_feature_names_out())
|
549 |
+
|
550 |
+
# def st_shap(plot, height=None):
|
551 |
+
# shap_html = f"<head>{shap.getjs()}</head><body>{plot.html()}</body>"
|
552 |
+
# components.html(shap_html, height=height)
|
553 |
+
fig=shap.plots.force(explainer.expected_value[1],shap_values.values[0][:,1],sample_trans,
|
554 |
+
figsize=(20,3),show=False,matplotlib=True)
|
555 |
+
st.pyplot(fig)
|
app_data.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c8882e6c8ec43e8a4e96724b21f6f1c11347cc9e18317d1a0dbbd5621bd93812
|
3 |
+
size 4990
|
crashes.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
cyclists.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
lib/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
lib/.ipynb_checkpoints/__init__-checkpoint.py
ADDED
File without changes
|
lib/.ipynb_checkpoints/get_data-checkpoint.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
|
3 |
+
def extract_data(year):
    """
    Load and preprocess one year of statewide PENNDOT crash data.

    Reads the raw VEHICLE, CYCLE, PERSON, CRASH, ROADWAY, and FLAG CSV
    files for the given year from data/raw_csv/ and restricts each to
    records connected to bicycle crashes. The following dataframes are
    returned:
    - 'bicycles': samples are bicycle vehicles which were
    involved in crashes.
    - 'persons': samples are all individuals involved in
    crashes involving bicycles.
    - 'crashes': samples are crash events involving bicycles.
    - 'roadway': additional features for
    crash events, related to roadway attributes and conditions.

    Parameters
    ----------
    year : int or str
        Year suffix used to locate the raw CSV files
        (e.g. VEHICLE_{year}_Statewide.csv).

    Returns
    -------
    tuple of pandas.DataFrame
        (bicycles, persons, crashes, roadway)
    """

    # Retrieve vehicle samples corresponding to bicycles.
    # Note that in some samples VEH_TYPE is string, others float,
    # hence the mixed int/str membership list below.
    vehicles = pd.read_csv(f'data/raw_csv/VEHICLE_{year}_Statewide.csv',encoding='latin')
    bicycle_filter = vehicles.VEH_TYPE.isin([20,21,'20','21'])
    cols = ['CRN', 'GRADE', 'IMPACT_POINT',
            'RDWY_ALIGNMENT','UNIT_NUM',
            'VEH_MOVEMENT', 'VEH_POSITION','VEH_ROLE', 'VEH_TYPE']
    bicycles = vehicles[bicycle_filter][cols]
    # Release the full vehicle table immediately - it can be large.
    del vehicles

    # Merge onto bicycles dataframe some additional features from cycle
    cycles = pd.read_csv(f'data/raw_csv/CYCLE_{year}_Statewide.csv',encoding='latin')
    cols = ['CRN','UNIT_NUM','PC_HDLGHT_IND', 'PC_HLMT_IND','PC_REAR_RFLTR_IND']
    bicycles = bicycles.merge(cycles[cols],how='left',on=['CRN','UNIT_NUM'])
    del cycles

    # Retrieve information about persons involved in crashes involving bikes
    # (not just the persons riding the bikes)
    persons = pd.read_csv(f'data/raw_csv/PERSON_{year}_Statewide.csv',encoding='latin')
    cols = ['AGE','CRN','INJ_SEVERITY','PERSON_TYPE',
            'RESTRAINT_HELMET','SEX', 'TRANSPORTED', 'UNIT_NUM']
    persons = persons[persons.CRN.isin(bicycles.CRN)][cols]

    # Retrieve crash samples involving bikes
    crashes = pd.read_csv(f'data/raw_csv/CRASH_{year}_Statewide.csv',encoding='latin')
    cols = ['CRN','ARRIVAL_TM','DISPATCH_TM','COUNTY','MUNICIPALITY','DEC_LAT','DEC_LONG',
            'BICYCLE_DEATH_COUNT','BICYCLE_SUSP_SERIOUS_INJ_COUNT',
            'BUS_COUNT','COMM_VEH_COUNT','HEAVY_TRUCK_COUNT','SMALL_TRUCK_COUNT','SUV_COUNT','VAN_COUNT',
            'CRASH_MONTH', 'CRASH_YEAR','DAY_OF_WEEK','HOUR_OF_DAY',
            'COLLISION_TYPE','ILLUMINATION','INTERSECT_TYPE',
            'LOCATION_TYPE','RELATION_TO_ROAD','TIME_OF_DAY',
            'ROAD_CONDITION','TCD_TYPE','TCD_FUNC_CD','URBAN_RURAL',
            'WEATHER1','WEATHER2']
    crashes = crashes[crashes.CRN.isin(bicycles.CRN)][cols]

    # Retrieve roadway data involving bikes
    roadway = pd.read_csv(f'data/raw_csv/ROADWAY_{year}_Statewide.csv',encoding='latin')
    cols = ['CRN','SPEED_LIMIT','RDWY_COUNTY']
    roadway = roadway[roadway.CRN.isin(bicycles.CRN)][cols]

    # Merge onto out bicycle_crashes and ped_crashes dataframe
    # some additional flag features.
    # Include flag features corresponding to driver impairment,
    # driver inattention, other driver attributes,relevant road conditions, etc.
    flags = pd.read_csv(f'data/raw_csv/FLAG_{year}_Statewide.csv',encoding='latin')
    cols = ['AGGRESSIVE_DRIVING','ALCOHOL_RELATED','ANGLE_CRASH','CELL_PHONE','COMM_VEHICLE',
            'CRN','CROSS_MEDIAN','CURVED_ROAD','CURVE_DVR_ERROR','DISTRACTED','DRINKING_DRIVER',
            'DRUGGED_DRIVER','DRUG_RELATED','FATIGUE_ASLEEP','HO_OPPDIR_SDSWP','ICY_ROAD',
            'ILLUMINATION_DARK','IMPAIRED_DRIVER','INTERSECTION','LANE_DEPARTURE',
            'NHTSA_AGG_DRIVING','NO_CLEARANCE',
            'NON_INTERSECTION','REAR_END','RUNNING_RED_LT','RUNNING_STOP_SIGN',
            'RURAL','SNOW_SLUSH_ROAD','SPEEDING','SPEEDING_RELATED',
            'SUDDEN_DEER','TAILGATING','URBAN','WET_ROAD','WORK_ZONE',
            'MATURE_DRIVER','YOUNG_DRIVER']
    crashes = crashes.merge(flags[cols],how='left',on='CRN')
    del flags

    return bicycles, persons, crashes, roadway
|
lib/.ipynb_checkpoints/study_classif-checkpoint.py
ADDED
@@ -0,0 +1,787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import seaborn as sns
|
5 |
+
import shap
|
6 |
+
from sklearn.feature_selection import chi2, SelectKBest, mutual_info_classif, f_classif
|
7 |
+
from sklearn.metrics import accuracy_score, log_loss, confusion_matrix, f1_score, fbeta_score, roc_auc_score
|
8 |
+
from sklearn.metrics import ConfusionMatrixDisplay, RocCurveDisplay, classification_report, precision_recall_curve
|
9 |
+
from sklearn.model_selection import train_test_split, RepeatedStratifiedKFold, cross_val_score, RandomizedSearchCV, StratifiedKFold
|
10 |
+
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler, FunctionTransformer, SplineTransformer, PolynomialFeatures
|
11 |
+
from sklearn.decomposition import PCA
|
12 |
+
from sklearn.linear_model import LogisticRegression
|
13 |
+
from sklearn.ensemble import HistGradientBoostingClassifier, GradientBoostingClassifier
|
14 |
+
# from lightgbm import LGBMClassifier
|
15 |
+
from sklearn.base import BaseEstimator, TransformerMixin, clone
|
16 |
+
from sklearn.utils.validation import check_is_fitted
|
17 |
+
from sklearn.impute import SimpleImputer
|
18 |
+
from sklearn.pipeline import Pipeline, make_pipeline
|
19 |
+
from sklearn.compose import ColumnTransformer, make_column_transformer
|
20 |
+
from lib.transform_data import *
|
21 |
+
|
22 |
+
class ClassifierStudy():
|
23 |
+
"""
|
24 |
+
A class that contains tools for studying a classifier pipeline
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
-----------
|
28 |
+
classifier : a scikit-learn compatible binary classifier
|
29 |
+
X : pd.DataFrame
|
30 |
+
dataframe of features
|
31 |
+
y : pd.Series
|
32 |
+
series of binary target values corresponding to X
|
33 |
+
classifier_name : str or None
|
34 |
+
if provided, will use as classifier name in pipeline
|
35 |
+
if not, will use 'clf' as name
|
36 |
+
features : dict
|
37 |
+
a dictionary whose keys are the feature types
|
38 |
+
'cyc','cat','ord','num','bin' and whose values
|
39 |
+
are lists of features of each type.
|
40 |
+
|
41 |
+
Methods:
|
42 |
+
-------
|
43 |
+
set_data, set_features, set_state
|
44 |
+
sets or resets attributes of self
|
45 |
+
build_pipeline
|
46 |
+
builds out pipeline based on supplied specs
|
47 |
+
cv_score
|
48 |
+
runs k-fold cross validation and reports scores
|
49 |
+
randomized_search
|
50 |
+
runs randomized search with cross validation
|
51 |
+
and reports results
|
52 |
+
fit_pipeline
|
53 |
+
fits the model pipeline and stores as
|
54 |
+
self.pipe_fitted
|
55 |
+
predict_proba_pipeline
|
56 |
+
uses a fitted pipeline to compute predicted
|
57 |
+
probabilities for test or validation set
|
58 |
+
score_pipeline
|
59 |
+
scores predicted probabilities
|
60 |
+
|
61 |
+
"""
|
62 |
+
def __init__(self, classifier=None, X = None, y = None,
|
63 |
+
features = None,classifier_name = None,
|
64 |
+
random_state=42):
|
65 |
+
self.classifier = classifier
|
66 |
+
if X is not None:
|
67 |
+
self.X = X.copy()
|
68 |
+
if y is not None:
|
69 |
+
self.y = y.copy()
|
70 |
+
if features is not None:
|
71 |
+
self.features = features.copy()
|
72 |
+
self.random_state=random_state
|
73 |
+
self.pipe, self.pipe_fitted = None, None
|
74 |
+
self.classifier_name = classifier_name
|
75 |
+
self.X_val, self.y_val = None, None
|
76 |
+
self.y_predict_proba = None
|
77 |
+
self.best_params, self.best_n_components = None, None
|
78 |
+
self.shap_vals = None
|
79 |
+
|
80 |
+
def set_data(self,X=None,y=None):
|
81 |
+
"""Method to set or reset feature and/or target data"""
|
82 |
+
if X is not None:
|
83 |
+
self.X = X.copy()
|
84 |
+
if y is not None:
|
85 |
+
self.y = y.copy()
|
86 |
+
|
87 |
+
def set_features(self,features):
|
88 |
+
"""Method to set or reset the feature dictionary"""
|
89 |
+
if features is not None:
|
90 |
+
self.features = features.copy()
|
91 |
+
|
92 |
+
def set_state(self,random_state):
|
93 |
+
"""Method to set or reset the random_state"""
|
94 |
+
self.random_state = random_state
|
95 |
+
|
96 |
+
def build_pipeline(self, cat_method = 'onehot',cyc_method = 'spline',num_ss=True,
|
97 |
+
over_sample = False, pca=False,n_components=None,
|
98 |
+
select_features = False,score_func=None,k='all',
|
99 |
+
poly_features = False, degree=2, interaction_only=False):
|
100 |
+
"""
|
101 |
+
Method to build the model pipeline
|
102 |
+
Parameters:
|
103 |
+
-----------
|
104 |
+
cat_method : str
|
105 |
+
specifies whether to encode categorical
|
106 |
+
variables as one-hot vectors or ordinals
|
107 |
+
must be either 'onehot' or 'ord'
|
108 |
+
cyc_method : str
|
109 |
+
specifies whether to encode cyclical features
|
110 |
+
with sine/cosine encoding or periodic splines
|
111 |
+
must be one of 'trig', 'spline', 'interact-trig',
|
112 |
+
'interact-spline','onehot', 'ord', or None
|
113 |
+
- If 'trig' or 'spline', will set up periodic encoder
|
114 |
+
with desired method
|
115 |
+
- If 'onehot' or 'ord', will set up appropriate
|
116 |
+
categorical encoder
|
117 |
+
- If 'interact-{method}', will use <method> encoding for HOUR_OF_DAY,
|
118 |
+
encode DAY_OF_WEEK as a binary feature expressing whether
|
119 |
+
the day is a weekend day, and then include interaction
|
120 |
+
features among this set via PolynomialFeatures.
|
121 |
+
- If None, will leave out cyclical features altogether
|
122 |
+
num_ss : bool
|
123 |
+
Whether or not to apply StandardScaler on the numerical features
|
124 |
+
over_sample : bool
|
125 |
+
set to True to include imblearn.over_sampling.RandomOverSampler step
|
126 |
+
pca : bool
|
127 |
+
set to True to include sklearn.decomposition.PCA step
|
128 |
+
n_components : int or None
|
129 |
+
number of components for sklearn.decomposition.PCA
|
130 |
+
select_features : bool
|
131 |
+
set to True to include sklearn.feature_selection.SelectKBest step
|
132 |
+
score_func : callable
|
133 |
+
score function to use for sklearn.feature_selection.SelectKBest
|
134 |
+
recommended: chi2, f_classif, or mutual_info_classif
|
135 |
+
k : int or 'all'
|
136 |
+
number of features for sklearn.feature_selection.SelectKBest
|
137 |
+
poly_features : bool
|
138 |
+
set to True to include sklearn.preprocessing.PolynomialFeatures step
|
139 |
+
degree : int
|
140 |
+
max degree for sklearn.preprocessing.PolynomialFeatures
|
141 |
+
interaction_only : bool
|
142 |
+
whether or not sklearn.preprocessing.PolynomialFeatures will be limited
|
143 |
+
to interaction terms only
|
144 |
+
"""
|
145 |
+
|
146 |
+
# Define transformer for categorical features
|
147 |
+
if cat_method == 'onehot':
|
148 |
+
cat_encoder = ('ohe',OneHotEncoder(handle_unknown='infrequent_if_exist'))
|
149 |
+
|
150 |
+
elif cat_method == 'ord':
|
151 |
+
cat_encoder = ('oe',OrdinalEncoder(handle_unknown='use_encoded_value',unknown_value=np.nan))
|
152 |
+
else:
|
153 |
+
raise ValueError("cat_method must be either 'onehot' or 'ord'")
|
154 |
+
|
155 |
+
cat_transform = Pipeline([('si',SimpleImputer(strategy='most_frequent')),cat_encoder])
|
156 |
+
|
157 |
+
# Define transformer for cyclic features
|
158 |
+
cyc_dict = {'HOUR_OF_DAY':24,'DAY_OF_WEEK':7}
|
159 |
+
if cyc_method == 'trig':
|
160 |
+
cyc_transform = [(f'{feat}_cos',cos_transformer(cyc_dict[feat]),[feat]) for feat in self.features['cyc']]+\
|
161 |
+
[(f'{feat}_sin',sin_transformer(cyc_dict[feat]),[feat]) for feat in self.features['cyc']]
|
162 |
+
elif cyc_method =='spline':
|
163 |
+
cyc_transform = [(f'{feat}_cyclic',
|
164 |
+
periodic_spline_transformer(cyc_dict[feat],n_splines=cyc_dict[feat]//2),
|
165 |
+
[feat]) for feat in self.features['cyc']]
|
166 |
+
elif cyc_method == 'onehot':
|
167 |
+
cyc_encoder = ('ohe_cyc',OneHotEncoder(handle_unknown='infrequent_if_exist'))
|
168 |
+
cyc_transform = [('cyc',Pipeline([cyc_encoder]),self.features['cyc'])]
|
169 |
+
elif cyc_method == 'ord':
|
170 |
+
cyc_encoder = ('oe_cyc',OrdinalEncoder(handle_unknown='use_encoded_value',unknown_value=np.nan))
|
171 |
+
cyc_transform = [('cyc',Pipeline([cyc_encoder]),self.features['cyc'])]
|
172 |
+
elif cyc_method == 'interact-spline':
|
173 |
+
hour_transform = (f'hour_cyc',periodic_spline_transformer(cyc_dict['HOUR_OF_DAY'],n_splines=12),['HOUR_OF_DAY'])
|
174 |
+
wkend_transform = ('wkend',FunctionTransformer(lambda x: (x.isin([1,7])).astype(int)),['DAY_OF_WEEK'])
|
175 |
+
cyc_transform = [('cyc',Pipeline([('cyc_col',ColumnTransformer([hour_transform, wkend_transform],
|
176 |
+
remainder='drop',verbose_feature_names_out=False)),
|
177 |
+
('cyc_poly',PolynomialFeatures(degree=2,interaction_only=True,
|
178 |
+
include_bias=False))]),
|
179 |
+
self.features['cyc'])]
|
180 |
+
elif cyc_method == 'interact-trig':
|
181 |
+
hour_transform = [(f'HOUR_cos',cos_transformer(cyc_dict['HOUR_OF_DAY']),['HOUR_OF_DAY']),
|
182 |
+
(f'HOUR_sin',sin_transformer(cyc_dict['HOUR_OF_DAY']),['HOUR_OF_DAY'])]
|
183 |
+
wkend_transform = ('wkend',FunctionTransformer(lambda x: (x.isin([1,7])).astype(int)),['DAY_OF_WEEK'])
|
184 |
+
cyc_transform = [('cyc',Pipeline([('cyc_col',ColumnTransformer(hour_transform+[wkend_transform],
|
185 |
+
remainder='drop',verbose_feature_names_out=False)),
|
186 |
+
('cyc_poly',PolynomialFeatures(degree=2,interaction_only=True,
|
187 |
+
include_bias=False))]),
|
188 |
+
self.features['cyc'])]
|
189 |
+
elif cyc_method is None:
|
190 |
+
cyc_transform = [('cyc','passthrough',[])]
|
191 |
+
else:
|
192 |
+
raise ValueError("cyc_method must be one of 'trig','spline','interact','onehot','ord',or None")
|
193 |
+
|
194 |
+
# Define numerical transform
|
195 |
+
num_transform = ('num',StandardScaler(),self.features['num']) if num_ss else\
|
196 |
+
('num','passthrough',self.features['num'])
|
197 |
+
|
198 |
+
# Define column transformer
|
199 |
+
col_transform = ColumnTransformer([('cat',cat_transform,self.features['cat']),
|
200 |
+
('ord','passthrough',self.features['ord']),
|
201 |
+
num_transform,
|
202 |
+
('bin',SimpleImputer(strategy='most_frequent'),
|
203 |
+
self.features['bin'])]+\
|
204 |
+
cyc_transform,
|
205 |
+
remainder='drop',verbose_feature_names_out=False)
|
206 |
+
|
207 |
+
steps = [('col',col_transform)]
|
208 |
+
|
209 |
+
if 'AGE' in self.features['num']:
|
210 |
+
steps.insert(0,('gi_age',GroupImputer(target = 'AGE', group_cols=['COUNTY'],strategy='median')))
|
211 |
+
if 'HOUR_OF_DAY' in self.features['cyc']:
|
212 |
+
steps.insert(0,('gi_hour',GroupImputer(target = 'HOUR_OF_DAY', group_cols=['ILLUMINATION','CRASH_MONTH'],strategy='mode')))
|
213 |
+
# Insert optional steps as needed
|
214 |
+
if over_sample:
|
215 |
+
steps.insert(0,('os',RandomOverSampler(random_state=self.random_state)))
|
216 |
+
if poly_features:
|
217 |
+
steps.append(('pf',PolynomialFeatures(degree=degree,interaction_only=interaction_only)))
|
218 |
+
if select_features:
|
219 |
+
steps.append(('fs',SelectKBest(score_func = score_func, k = k)))
|
220 |
+
if pca:
|
221 |
+
steps.append(('pca',PCA(n_components=n_components,random_state=self.random_state)))
|
222 |
+
# Append classifier if provided
|
223 |
+
if self.classifier is not None:
|
224 |
+
if self.classifier_name is not None:
|
225 |
+
steps.append((f'{self.classifier_name}_clf',self.classifier))
|
226 |
+
else:
|
227 |
+
steps.append(('clf',self.classifier))
|
228 |
+
|
229 |
+
# Initialize pipeline
|
230 |
+
self.pipe = Pipeline(steps)
|
231 |
+
|
232 |
+
    def cv_score(self, scoring = 'roc_auc', n_splits = 5, n_repeats=3, thresh = 0.5, beta = 1,
                 return_mean_score=False,print_mean_score=True,print_scores=False, n_jobs=-1,
                 eval_size=0.1,eval_metric='auc'):
        """
        Method for performing cross validation via RepeatedStratifiedKFold

        Parameters:
        -----------
        scoring : str
            scoring function to use. must be one of
            'roc_auc','acc','f1','fb','f1w'
        thresh : float
            the classification threshold for computing y_pred
            from y_pred_proba
        beta : float
            the beta-value to use in the f_beta score, if chosen
        n_splits, n_repeats : int, int
            number of splits and number of repeat iterations
            for sklearn.model_selection.RepeatedStratifiedKFold
        return_mean_score : bool
            whether or not to return the mean score
        print_mean_score : bool
            whether to print out a report of the mean score
        print_scores : bool
            whether to print out a report of CV scores for all folds
        n_jobs : int or None
            number of CPU cores to use for parallel processing
            -1 uses all available cores, and None defaults to 1
            (NOTE(review): currently unused — the CV loop below is sequential)
        eval_size : float
            Fraction of the training set to use for early stopping eval set
            (only relevant to the disabled LightGBM path below)
        eval_metric : str
            eval metric to use in early stopping
            (only relevant to the disabled LightGBM path below)
        Returns: None or mean_score, depending on return_mean_score setting
        --------
        """
        assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
        assert (self.X is not None)&(self.y is not None), 'X and/or y does not exist. First supply X and y using set_data.'
        assert 'clf' in self.pipe.steps[-1][0], 'The pipeline has no classifier. Build a pipeline with a classifier first.'
        assert scoring in ['roc_auc','acc','f1','fb','f1w'],"scoring must be one of 'roc_auc','acc','f1','fb','f1w'"

        # Initialize CV iterator
        kf = RepeatedStratifiedKFold(n_splits = n_splits, n_repeats=n_repeats,
                                     random_state=self.random_state)
        # Restrict to features supplied in self.features
        X = self.X[[feat for feat_type in self.features for feat in self.features[feat_type]]]

        # Early-stopping flag; always False while LightGBM support is disabled
        lgb_es=False
        # if isinstance(self.pipe[-1],LGBMClassifier):
        #     if 'early_stopping_round' in self.pipe[-1].get_params():
        #         if self.pipe[-1].get_params()['early_stopping_rounds'] is not None:
        #             lgb_es=True

        scores = []
        # Iterate over folds and train, predict, score
        for i,(train_idx,test_idx) in enumerate(kf.split(X,self.y)):
            fold_X_train = X.iloc[train_idx,:]
            fold_X_test = X.iloc[test_idx,:]
            fold_y_train = self.y.iloc[train_idx]
            fold_y_test = self.y.iloc[test_idx]

            # Fresh, unfitted copy of the pipeline for every fold
            pipe=clone(self.pipe)
            if lgb_es:
                # Split off an eval set, transform it with the (fitted) feature
                # steps, and pass it to the classifier for early stopping
                fold_X_train,fold_X_es,fold_y_train,fold_y_es = train_test_split(fold_X_train,fold_y_train,
                                                                                 stratify=fold_y_train,test_size=eval_size,
                                                                                 random_state=self.random_state)
                trans_pipe = pipe[:-1]
                trans_pipe.fit_transform(fold_X_train)
                fold_X_es = trans_pipe.transform(fold_X_es)
                clf_name = pipe.steps[-1][0]
                fit_params = {f'{clf_name}__eval_set':[(fold_X_es,fold_y_es)],
                              f'{clf_name}__eval_metric':eval_metric,
                              f'{clf_name}__verbose':0}
            else:
                fit_params = {}

            pipe.fit(fold_X_train,fold_y_train,**fit_params)
            # Probability of the positive class
            fold_y_pred_proba = pipe.predict_proba(fold_X_test)[:,1]

            # roc_auc scores probabilities directly; the rest need hard labels
            if scoring == 'roc_auc':
                fold_score = roc_auc_score(fold_y_test, fold_y_pred_proba)
            else:
                fold_y_pred = (fold_y_pred_proba >= thresh).astype('int')
                if scoring == 'acc':
                    fold_score = accuracy_score(fold_y_test,fold_y_pred)
                elif scoring == 'f1':
                    fold_score = f1_score(fold_y_test,fold_y_pred)
                elif scoring == 'f1w':
                    fold_score = f1_score(fold_y_test,fold_y_pred,average='weighted')
                else:
                    fold_score = fbeta_score(fold_y_test,fold_y_pred,beta=beta)
            scores.append(fold_score)

        # Average and report
        mean_score = np.mean(scores)
        if print_scores:
            print(f'CV scores using {scoring} score: {scores} \nMean score: {mean_score}')
        if print_mean_score:
            print(f'Mean CV {scoring} score: {mean_score}')
        if return_mean_score:
            return mean_score
+
    def randomized_search(self, params, n_components = None, n_iter=10,
                          scoring='roc_auc',cv=5,refit=False,top_n=10, n_jobs=-1):
        """
        Method for performing randomized search with cross validation on a given dictionary of parameter distributions
        Also displays a table of results the best top_n iterations

        Parameters:
        ----------
        params : dict
            parameter distributions to use for RandomizedSearchCV;
            keys are bare hyperparameter names (the classifier step
            prefix is added automatically below)
        n_components : int, or list, or None
            number of components for sklearn.decomposition.PCA
            - if int, will reset the PCA layer in self.pipe with provided value
            - if list, must be list of ints, which will be included in
              RandomizedSearchCV parameter distribution
        scoring : str
            scoring function for sklearn.model_selection.cross_val_score
        n_iter : int
            number of iterations to use in RandomizedSearchCV
        refit : bool
            whether to refit a final classifier with best parameters
            - if False, will only set self.best_params and self.best_score
            - if True, will set self.best_estimator in addition
        top_n : int or None
            if int, will display results from top_n best iterations only
            if None, will display all results
        n_jobs : int or None
            number of CPU cores to use for parallel processing
            -1 uses all available cores, and None defaults to 1
        """
        assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
        assert (self.X is not None)&(self.y is not None), 'X and/or y does not exist. First supply X and y using set_data.'
        assert 'clf' in self.pipe.steps[-1][0], 'The pipeline has no classifier. Build a pipeline with a classifier first.'
        assert (n_components is None)|('pca' in self.pipe.named_steps), 'Your pipeline has no PCA step. Build a pipeline with PCA first.'
        assert (len(params)>0)|(type(n_components)==list), 'Either pass a parameter distribution or a list of n_components values.'

        # Add estimator name prefix to hyperparams (sklearn nested-param syntax)
        params = {self.pipe.steps[-1][0]+'__'+key:params[key] for key in params}

        # Process supplied n_components
        if type(n_components)==list:
            params['pca__n_components']=n_components
        elif type(n_components)==int:
            self.pipe['pca'].set_params(n_components=n_components)

        # Restrict to features supplied in self.features
        X = self.X[[feat for feat_type in self.features for feat in self.features[feat_type]]]

        # Initialize rs and fit
        rs = RandomizedSearchCV(self.pipe, param_distributions = params,
                                n_iter=n_iter, scoring = scoring, cv = cv,refit=refit,
                                random_state=self.random_state, n_jobs=n_jobs)

        rs.fit(X,self.y)

        # Display top n scores, best first
        # NOTE(review): `display` is an IPython builtin — this method assumes a
        # Jupyter/IPython environment and raises NameError in plain Python; confirm.
        results = rs.cv_results_
        results_df = pd.DataFrame(results['params'])
        param_names = list(results_df.columns)
        results_df[f'mean cv score ({scoring})']=pd.Series(results['mean_test_score'])
        results_df = results_df.set_index(param_names).sort_values(by=f'mean cv score ({scoring})',ascending=False)
        if top_n is not None:
            display(results_df.head(top_n).style\
                    .highlight_max(axis=0, props='color:white; font-weight:bold; background-color:seagreen;'))
        else:
            display(results_df.style\
                    .highlight_max(axis=0, props='color:white; font-weight:bold; background-color:seagreen;'))
        if refit:
            self.best_estimator = rs.best_estimator_
            best_params = rs.best_params_
            # Strip the step prefix back off; keep PCA's n_components separate
            self.best_params = {key.split('__')[-1]:best_params[key] for key in best_params if key.split('__')[0]!='pca'}
            self.best_n_components = next((best_params[key] for key in best_params if key.split('__')[0]=='pca'), None)
            self.best_score = rs.best_score_
+
def fit_pipeline(self,split_first=False, eval_size=0.1,eval_metric='auc'):
|
408 |
+
"""
|
409 |
+
Method for fitting self.pipeline on self.X,self.y
|
410 |
+
Parameters:
|
411 |
+
-----------
|
412 |
+
split_first : bool
|
413 |
+
if True, a train_test_split will be performed first
|
414 |
+
and the validation set will be stored
|
415 |
+
early_stopping : bool
|
416 |
+
Indicates whether we will use early_stopping for lightgbm.
|
417 |
+
If true, will split off an eval set prior to k-fold split
|
418 |
+
eval_size : float
|
419 |
+
Fraction of the training set to use for early stopping eval set
|
420 |
+
eval_metric : str
|
421 |
+
eval metric to use in early stopping
|
422 |
+
"""
|
423 |
+
# Need pipe and X to fit
|
424 |
+
assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
|
425 |
+
assert self.X is not None, 'X does not exist. First set X.'
|
426 |
+
|
427 |
+
# If no y provided, then no pipeline steps should require y
|
428 |
+
step_list = [step[0] for step in self.pipe.steps]
|
429 |
+
assert (('clf' not in step_list[-1])&('kf' not in step_list))|(self.y is not None), 'You must provide targets y if pipeline has a classifier step or feature selection step.'
|
430 |
+
|
431 |
+
# Don't need to do a train-test split without a classifier
|
432 |
+
assert (split_first==False)|('clf' in step_list[-1]), 'Only need train-test split if you have a classifier.'
|
433 |
+
|
434 |
+
if split_first:
|
435 |
+
X_train,X_val,y_train,y_val = train_test_split(self.X,self.y,stratify=self.y,
|
436 |
+
test_size=0.2,random_state=self.random_state)
|
437 |
+
self.X_val = X_val
|
438 |
+
self.y_val = y_val
|
439 |
+
else:
|
440 |
+
X_train = self.X.copy()
|
441 |
+
if self.y is not None:
|
442 |
+
y_train = self.y.copy()
|
443 |
+
|
444 |
+
# Restrict to features supplied in self.features
|
445 |
+
X_train = X_train[[feat for feat_type in self.features for feat in self.features[feat_type]]]
|
446 |
+
|
447 |
+
# If LGBM early stopping, then need to split off eval_set and define fit_params
|
448 |
+
# if isinstance(self.pipe[-1],LGBMClassifier):
|
449 |
+
# if self.pipe[-1].get_params()['early_stopping_rounds'] is not None:
|
450 |
+
# X_train,X_es,y_train,y_es = train_test_split(X_train,y_train,
|
451 |
+
# test_size=eval_size,
|
452 |
+
# stratify=y_train,
|
453 |
+
# random_state=self.random_state)
|
454 |
+
# trans_pipe = self.pipe[:-1]
|
455 |
+
# trans_pipe.fit_transform(X_train)
|
456 |
+
# X_es = trans_pipe.transform(X_es)
|
457 |
+
# clf_name = self.pipe.steps[-1][0]
|
458 |
+
# fit_params = {f'{clf_name}__eval_set':[(X_es,y_es)],
|
459 |
+
# f'{clf_name}__eval_metric':eval_metric,
|
460 |
+
# f'{clf_name}__verbose':0}
|
461 |
+
# else:
|
462 |
+
# fit_params = {}
|
463 |
+
# else:
|
464 |
+
# fit_params = {}
|
465 |
+
fit_params = {}
|
466 |
+
# Fit and store fitted pipeline. If no classifier, fit_transform X_train and store transformed version
|
467 |
+
pipe = self.pipe
|
468 |
+
if 'clf' in step_list[-1]:
|
469 |
+
pipe.fit(X_train,y_train,**fit_params)
|
470 |
+
else:
|
471 |
+
X_transformed = pipe.fit_transform(X_train)
|
472 |
+
# X_transformed = pd.DataFrame(X_transformed,columns=pipe[-1].get_column_names_out())
|
473 |
+
self.X_transformed = X_transformed
|
474 |
+
self.pipe_fitted = pipe
|
475 |
+
|
476 |
+
def predict_proba_pipeline(self, X_test = None):
|
477 |
+
"""
|
478 |
+
Method for using a fitted pipeline to compute predicted
|
479 |
+
probabilities for X_test (if supplied) or self.X_val
|
480 |
+
Parameters:
|
481 |
+
-----------
|
482 |
+
X_test : pd.DataFrame or None
|
483 |
+
test data input features (if None, will use self.X_val)
|
484 |
+
"""
|
485 |
+
assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
|
486 |
+
assert 'clf' in self.pipe.steps[-1][0], 'The pipeline has no classifier. Build a pipeline with a classifier first.'
|
487 |
+
assert self.pipe_fitted is not None, 'Pipeline is not fitted. First fit pipeline using fit_pipeline.'
|
488 |
+
assert (X_test is not None)|(self.X_val is not None), 'Must either provide X_test and y_test or fit the pipeline with split_first=True.'
|
489 |
+
|
490 |
+
if X_test is None:
|
491 |
+
X_test = self.X_val
|
492 |
+
|
493 |
+
# Restrict to features supplied in self.features
|
494 |
+
X_test = X_test[[feat for feat_type in self.features for feat in self.features[feat_type]]]
|
495 |
+
|
496 |
+
# Save prediction
|
497 |
+
self.y_predict_proba = self.pipe_fitted.predict_proba(X_test)[:,1]
|
498 |
+
|
499 |
+
def score_pipeline(self,y_test=None,scoring='roc_auc',thresh=0.5, beta = 1,
|
500 |
+
normalize = None, print_score = True):
|
501 |
+
"""
|
502 |
+
Method for scoring self.pipe_fitted on supplied test data and reporting score
|
503 |
+
Parameters:
|
504 |
+
-----------
|
505 |
+
y_test : pd.Series or None
|
506 |
+
true binary targets (if None, will use self.y_val)
|
507 |
+
scoring : str
|
508 |
+
specifies the metric to use for scoring
|
509 |
+
must be one of
|
510 |
+
'roc_auc', 'roc_plot', 'acc', 'f1', 'f1w', 'fb','mcc','kappa','conf','classif_report'
|
511 |
+
thresh : float
|
512 |
+
threshhold value for computing y_pred
|
513 |
+
from y_predict_proba
|
514 |
+
beta : float
|
515 |
+
the beta parameter in the fb score
|
516 |
+
normalize : str or None
|
517 |
+
the normalize parameter for the
|
518 |
+
confusion_matrix. must be one of
|
519 |
+
'true','pred','all',None
|
520 |
+
print_score : bool
|
521 |
+
if True, will print a message reporting the score
|
522 |
+
if False, will return the score as a float
|
523 |
+
"""
|
524 |
+
assert (y_test is not None)|(self.y_val is not None), 'Must either provide X_test and y_test or fit the pipeline with split_first=True.'
|
525 |
+
assert self.y_predict_proba is not None, 'Predicted probabilities do not exist. Run predict_proba_pipeline first.'
|
526 |
+
|
527 |
+
if y_test is None:
|
528 |
+
y_test = self.y_val
|
529 |
+
|
530 |
+
# Score and report
|
531 |
+
if scoring == 'roc_plot':
|
532 |
+
fig = plt.figure(figsize=(4,4))
|
533 |
+
ax = fig.add_subplot(111)
|
534 |
+
RocCurveDisplay.from_predictions(y_test,self.y_predict_proba,ax=ax)
|
535 |
+
plt.show()
|
536 |
+
elif scoring == 'roc_auc':
|
537 |
+
score = roc_auc_score(y_test, self.y_predict_proba)
|
538 |
+
else:
|
539 |
+
y_pred = (self.y_predict_proba >= thresh).astype('int')
|
540 |
+
if scoring == 'acc':
|
541 |
+
score = accuracy_score(y_test,y_pred)
|
542 |
+
elif scoring == 'f1':
|
543 |
+
score = f1_score(y_test,y_pred)
|
544 |
+
elif scoring == 'f1w':
|
545 |
+
score = f1_score(y_test,y_pred,average='weighted')
|
546 |
+
elif scoring == 'fb':
|
547 |
+
score = fbeta_score(y_test,y_pred,beta=beta)
|
548 |
+
elif scoring == 'mcc':
|
549 |
+
score = matthews_coffcoeff(y_test,y_pred)
|
550 |
+
elif scoring == 'kappa':
|
551 |
+
score = cohen_kappa_score(y_test,y_pred)
|
552 |
+
elif scoring == 'conf':
|
553 |
+
fig = plt.figure(figsize=(3,3))
|
554 |
+
ax = fig.add_subplot(111)
|
555 |
+
ConfusionMatrixDisplay.from_predictions(y_test,y_pred,ax=ax,colorbar=False)
|
556 |
+
plt.show()
|
557 |
+
elif scoring == 'classif_report':
|
558 |
+
target_names=['neither seriously injured nor killed','seriously injured or killed']
|
559 |
+
print(classification_report(y_test, y_pred,target_names=target_names))
|
560 |
+
else:
|
561 |
+
raise ValueError("scoring must be one of 'roc_auc', 'roc_plot','acc', 'f1', 'f1w', 'fb','mcc','kappa','conf','classif_report'")
|
562 |
+
if scoring not in ['conf','roc_plot','classif_report']:
|
563 |
+
if print_score:
|
564 |
+
print(f'The {scoring} score is: {score}')
|
565 |
+
else:
|
566 |
+
return score
|
567 |
+
|
568 |
+
def shap_values(self, X_test = None, eval_size=0.1,eval_metric='auc'):
|
569 |
+
"""
|
570 |
+
Method for computing and SHAP values for features
|
571 |
+
stratifiedtrain/test split
|
572 |
+
A copy of self.pipe is fitted on the training set
|
573 |
+
and then SHAP values are computed on test set samples
|
574 |
+
Parameters:
|
575 |
+
-----------
|
576 |
+
X_test : pd.DataFrame
|
577 |
+
The test set; if provided, will not perform
|
578 |
+
a train/test split before fitting
|
579 |
+
eval_size : float
|
580 |
+
Fraction of the training set to use for early stopping eval set
|
581 |
+
eval_metric : str
|
582 |
+
eval metric to use in early stopping
|
583 |
+
Returns: None (stores results in self.shap_vals)
|
584 |
+
--------
|
585 |
+
"""
|
586 |
+
assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
|
587 |
+
assert (self.X is not None)&(self.y is not None), 'X and/or y does not exist. First supply X and y using set_data.'
|
588 |
+
assert 'clf' in self.pipe.steps[-1][0], 'The pipeline has no classifier. Build a pipeline with a classifier first.'
|
589 |
+
|
590 |
+
|
591 |
+
# Clone pipeline, do train/test split if X_test not provided
|
592 |
+
pipe = clone(self.pipe)
|
593 |
+
X_train = self.X.copy()
|
594 |
+
y_train = self.y.copy()
|
595 |
+
if X_test is None:
|
596 |
+
X_train,X_test,y_train,y_test = train_test_split(X_train,y_train,stratify=y_train,
|
597 |
+
test_size=0.2,random_state=self.random_state)
|
598 |
+
# Restrict to features provided in self.features, and fit
|
599 |
+
X_train = X_train[[feat for feat_type in self.features for feat in self.features[feat_type]]]
|
600 |
+
X_test = X_test[[feat for feat_type in self.features for feat in self.features[feat_type]]]
|
601 |
+
|
602 |
+
# If LGBM early stopping, then need to split off eval_set and define fit_params
|
603 |
+
# if isinstance(self.pipe[-1],LGBMClassifier):
|
604 |
+
# if 'early_stopping_round' in self.pipe[-1].get_params():
|
605 |
+
# if self.pipe[-1].get_params()['early_stopping_rounds'] is not None:
|
606 |
+
# X_train,X_es,y_train,y_es = train_test_split(X_train,y_train,
|
607 |
+
# test_size=eval_size,
|
608 |
+
# stratify=y_train,
|
609 |
+
# random_state=self.random_state)
|
610 |
+
# trans_pipe = self.pipe[:-1]
|
611 |
+
# trans_pipe.fit_transform(X_train)
|
612 |
+
# X_es = trans_pipe.transform(X_es)
|
613 |
+
# clf_name = self.pipe.steps[-1][0]
|
614 |
+
# fit_params = {f'{clf_name}__eval_set':[(X_es,y_es)],
|
615 |
+
# f'{clf_name}__eval_metric':eval_metric,
|
616 |
+
# f'{clf_name}__verbose':0}
|
617 |
+
# else:
|
618 |
+
# fit_params = {}
|
619 |
+
# else:
|
620 |
+
# fit_params = {}
|
621 |
+
fit_params = {}
|
622 |
+
pipe.fit(X_train,y_train,**fit_params)
|
623 |
+
|
624 |
+
# SHAP will just explain classifier, so need transformed X_train and X_test
|
625 |
+
X_train_trans, X_test_trans = pipe[:-1].transform(X_train), pipe[:-1].transform(X_test)
|
626 |
+
|
627 |
+
# Need masker for linear model
|
628 |
+
masker = shap.maskers.Independent(data=X_train_trans)
|
629 |
+
|
630 |
+
# Initialize explainer and compute and store SHAP values as an explainer object
|
631 |
+
explainer = shap.Explainer(pipe[-1], masker = masker, feature_names = pipe['col'].get_feature_names_out())
|
632 |
+
self.shap_vals = explainer(X_test_trans)
|
633 |
+
self.X_shap = X_train_trans
|
634 |
+
self.y_shap = y_train
|
635 |
+
|
636 |
+
def shap_plot(self,max_display='all'):
|
637 |
+
"""
|
638 |
+
Method for generating plots of SHAP value results
|
639 |
+
SHAP values should be already computed previously
|
640 |
+
Generates two plots side by side:
|
641 |
+
- a beeswarm plot of SHAP values of all samples
|
642 |
+
- a barplot of mean absolute SHAP values
|
643 |
+
Parameters:
|
644 |
+
-----------
|
645 |
+
max_display : int or 'all'
|
646 |
+
The number of features to show in the plot, in descending
|
647 |
+
order by mean absolute SHAP value. If 'all', then
|
648 |
+
all features will be included.
|
649 |
+
|
650 |
+
Returns: None (plots displayed)
|
651 |
+
--------
|
652 |
+
"""
|
653 |
+
assert self.shap_vals is not None, 'No shap values exist. First compute shap values.'
|
654 |
+
assert (isinstance(max_display,int))|(max_display=='all'), "'max_display' must be 'all' or an integer"
|
655 |
+
|
656 |
+
if max_display=='all':
|
657 |
+
title_add = ', all features'
|
658 |
+
max_display = self.shap_vals.shape[1]
|
659 |
+
else:
|
660 |
+
title_add = f', top {max_display} features'
|
661 |
+
|
662 |
+
# Plot
|
663 |
+
fig=plt.figure()
|
664 |
+
ax1=fig.add_subplot(121)
|
665 |
+
shap.summary_plot(self.shap_vals,plot_type='bar',max_display=max_display,
|
666 |
+
show=False,plot_size=0.2)
|
667 |
+
ax2=fig.add_subplot(122)
|
668 |
+
shap.summary_plot(self.shap_vals,plot_type='violin',max_display=max_display,
|
669 |
+
show=False,plot_size=0.2)
|
670 |
+
fig.set_size_inches(12,max_display/3)
|
671 |
+
|
672 |
+
ax1.set_title(f'Mean absolute SHAP values'+title_add,fontsize='small')
|
673 |
+
ax1.set_xlabel('mean(|SHAP value|)',fontsize='x-small')
|
674 |
+
ax2.set_title(f'SHAP values'+title_add,fontsize='small')
|
675 |
+
ax2.set_xlabel('SHAP value', fontsize='x-small')
|
676 |
+
for ax in [ax1,ax2]:
|
677 |
+
ax.set_ylabel('feature name',fontsize='x-small')
|
678 |
+
ax.tick_params(axis='y', labelsize='xx-small')
|
679 |
+
plt.tight_layout()
|
680 |
+
plt.show()
|
681 |
+
|
682 |
+
def find_best_threshold(self,beta=1,conf=True,report=True, print_result=True):
|
683 |
+
"""
|
684 |
+
Computes the classification threshold which gives the
|
685 |
+
best F_beta score from classifier predictions,
|
686 |
+
prints the best threshold and the corresponding F_beta score,
|
687 |
+
and displays a confusion matrix and classification report
|
688 |
+
corresponding to that threshold
|
689 |
+
|
690 |
+
Parameters:
|
691 |
+
-----------
|
692 |
+
beta : float
|
693 |
+
the desired beta value in the F_beta score
|
694 |
+
conf : bool
|
695 |
+
whether to display confusion matrix
|
696 |
+
report : bool
|
697 |
+
whether to display classification report
|
698 |
+
print_result : bool
|
699 |
+
whether to print a line reporting the best threshold
|
700 |
+
and resulting F_beta score
|
701 |
+
|
702 |
+
Returns: None (prints results and stores self.best_thresh)
|
703 |
+
--------
|
704 |
+
"""
|
705 |
+
prec,rec,threshs = precision_recall_curve(self.y_val,
|
706 |
+
self.y_predict_proba)
|
707 |
+
F_betas = (1+beta**2)*(prec*rec)/((beta**2*prec)+rec)
|
708 |
+
# Above formula is valid when TP!=0. When TP==0
|
709 |
+
# it gives np.nan whereas F_beta should be 0
|
710 |
+
F_betas = np.nan_to_num(F_betas)
|
711 |
+
idx = np.argmax(F_betas)
|
712 |
+
best_thresh = threshs[idx]
|
713 |
+
if print_result:
|
714 |
+
print(f'Threshold optimizing F_{beta} score: {best_thresh}\nBest F_{beta} score: {F_betas[idx]}')
|
715 |
+
if conf:
|
716 |
+
self.score_pipeline(scoring='conf',thresh=best_thresh,beta=beta)
|
717 |
+
if report:
|
718 |
+
self.score_pipeline(scoring='classif_report',thresh=best_thresh,beta=beta)
|
719 |
+
self.best_thresh = best_thresh
|
720 |
+
|
721 |
+
class LRStudy(ClassifierStudy):
    """
    A child class of ClassifierStudy which has an additional method specific to logistic regression
    """
    def __init__(self, classifier=None, X=None, y=None,
                 features=None, classifier_name='LR',
                 random_state=42):
        super().__init__(classifier, X, y, features, classifier_name, random_state)

    def plot_coeff(self, print_score=True, print_zero=False, title_add=None):
        """
        Do a train/validation split, fit the classifier, predict and
        score on the validation set, and plot a bar chart of the
        logistic regression coefficients of the model features.

        Features with coefficient zero and periodic spline features
        are excluded from the chart.

        Parameters:
        -----------
        print_score : bool
            if True, the validation score is printed
        print_zero : bool
            if True, the list of features with zero coefficients is printed
        title_add : str or None
            an addendum that is added to the end of the plot title

        Returns: None (stores self.score, self.coeff,
                 self.coeff_zero_features)
        --------
        """
        assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
        assert isinstance(self.classifier, LogisticRegression), 'Your classifier is not an instance of Logistic Regression.'

        # fit and score on the validation split
        self.fit_pipeline(split_first=True)
        self.predict_proba_pipeline()
        score = roc_auc_score(self.y_val, self.y_predict_proba)

        # Retrieve coefficient values from the fitted pipeline
        coeff = pd.DataFrame({'feature name': self.pipe_fitted['col'].get_feature_names_out(),
                              'coeff value': self.pipe_fitted[-1].coef_.reshape(-1)})\
                  .sort_values(by='coeff value')
        # Drop periodic spline basis features - their individual
        # log-odds coefficients are not interpretable on their own
        coeff = coeff[~coeff['feature name']
                      .isin([f'HOUR_OF_DAY_sp_{n}' for n in range(12)]
                            + [f'DAY_OF_WEEK_sp_{n}' for n in range(3)])]\
                  .set_index('feature name')
        coeff_zero_features = coeff[coeff['coeff value'] == 0].index
        coeff = coeff[coeff['coeff value'] != 0]

        # Plot feature coefficients
        fig = plt.figure(figsize=(30, 4))
        ax = fig.add_subplot(111)
        coeff['coeff value'].plot(kind='bar', ylabel='coeff value', ax=ax)
        ax.axhline(y=0, color='red', linewidth=2)
        plot_title = 'PA bicycle collisions, 2002-2021\nLogistic regression model log-odds coefficients'
        if title_add is not None:
            plot_title += f': {title_add}'
        ax.set_title(plot_title)
        ax.tick_params(axis='x', labelsize='x-small')
        plt.show()

        if print_score:
            print(f'Score on validation set: {score}')
        if print_zero:
            # BUG FIX: previously referenced the undefined name 'coeff_zero',
            # which raised NameError whenever print_zero=True
            print(f'Features with zero coefficients in trained model: {list(coeff_zero_features)}')

        self.score = score
        self.coeff = coeff
        self.coeff_zero_features = coeff_zero_features
|
786 |
+
|
787 |
+
|
lib/.ipynb_checkpoints/transform_data-checkpoint.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
from sklearn.preprocessing import FunctionTransformer, SplineTransformer
|
4 |
+
from sklearn.base import BaseEstimator, TransformerMixin
|
5 |
+
from sklearn.utils.validation import check_is_fitted
|
6 |
+
|
7 |
+
class GroupImputer(BaseEstimator, TransformerMixin):
    """
    Impute missing values in a pd.DataFrame using mean, median, or
    mode by groupwise aggregation, or a constant.

    Parameters:
    -----------
    target : str
        - The name of the column to be imputed
    group_cols : list
        - List of name(s) of columns on which to groupby
    strategy : str
        - The method for replacement; can be any of
          ['mean', 'median', 'mode']

    Returns:
    --------
    X : pd.DataFrame
        - The dataframe with imputed values in the target column
    """
    def __init__(self, target, group_cols=None, strategy='median'):
        # isinstance is the idiomatic type check; the original used the
        # type(x)==T anti-pattern. Message typo (stray quote) also fixed.
        assert strategy in ['mean', 'median', 'mode'], "strategy must be in ['mean', 'median', 'mode']"
        assert isinstance(group_cols, list), 'group_cols must be a list of column names'
        assert isinstance(target, str), 'target must be a string'

        self.group_cols = group_cols
        self.strategy = strategy
        self.target = target

    def fit(self, X, y=None):
        """Learn the groupwise aggregate of the target column."""
        if self.strategy == 'mode':
            # pd.Series.mode may return several values; take the first
            impute_map = X.groupby(self.group_cols)[self.target]\
                          .agg(lambda x: pd.Series.mode(x, dropna=False)[0])\
                          .reset_index(drop=False)
        else:
            impute_map = X.groupby(self.group_cols)[self.target]\
                          .agg(self.strategy).reset_index(drop=False)
        self.impute_map_ = impute_map

        return self

    def transform(self, X, y=None):
        """Fill missing target values using the fitted group aggregates."""
        check_is_fitted(self, 'impute_map_')

        # Never mutate the caller's frame
        X = X.copy()

        for index, row in self.impute_map_.iterrows():
            # Mask of rows of X belonging to this group
            ind = (X[self.group_cols] == row[self.group_cols]).all(axis=1)
            X.loc[ind, self.target] = X.loc[ind, self.target].fillna(row[self.target])
        return X
|
61 |
+
|
62 |
+
# Sine and consine transformations
|
63 |
+
def sin_feature_names(transformer, feature_names):
    """Output-feature naming callback: prefix each column name with 'SIN_'."""
    return ['SIN_' + col for col in feature_names]
|
65 |
+
def cos_feature_names(transformer, feature_names):
    """Output-feature naming callback: prefix each column name with 'COS_'."""
    return ['COS_' + col for col in feature_names]
|
67 |
+
def sin_transformer(period):
    """Return a FunctionTransformer encoding x as sin(2*pi*x/period)."""
    def _to_sin(x):
        return np.sin(2*np.pi*x/period)
    return FunctionTransformer(_to_sin, feature_names_out=sin_feature_names)
|
69 |
+
def cos_transformer(period):
    """Return a FunctionTransformer encoding x as cos(2*pi*x/period)."""
    def _to_cos(x):
        return np.cos(2*np.pi*x/period)
    return FunctionTransformer(_to_cos, feature_names_out=cos_feature_names)
|
71 |
+
|
72 |
+
# Periodic spline transformation
|
73 |
+
def periodic_spline_transformer(period, n_splines=None, degree=3):
    """
    Build a SplineTransformer with periodic extrapolation.

    Parameters:
    -----------
    period : numeric
        length of one full cycle of the feature
    n_splines : int or None
        number of spline basis functions; defaults to `period`
    degree : int
        degree of the spline basis

    Returns: sklearn.preprocessing.SplineTransformer
    --------
    """
    if n_splines is None:
        n_splines = period
    # With periodic extrapolation and include_bias=True, the number of
    # knots is one more than the number of splines
    n_knots = n_splines + 1
    knot_positions = np.linspace(0, period, n_knots).reshape(n_knots, 1)
    return SplineTransformer(
        degree=degree,
        n_knots=n_knots,
        knots=knot_positions,
        extrapolation="periodic",
        include_bias=True,
    )
|
lib/.ipynb_checkpoints/vis_data-checkpoint.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import plotly.express as px
|
4 |
+
from scipy import stats
|
5 |
+
|
6 |
+
def plot_map(df,city=None,county=None,animate=True,color_dots=True,animate_by='year',show_fig=True,return_fig=False):
    """
    Displays a plotly.express.scatter_mapbox interactive map
    of crashes in a municipality if specified, or otherwise
    statewide. Can be animated over time or static.

    Parameters:
    -----------
    df : pd.DataFrame
        dataframe of crash samples
    city or county : tuple or None
        if provided, must be a tuple (code,name)
        - code : str
            the code corresponding to the desired municipality/county
            (see the data dictionary)
        - name : str
            the name you want to use for the municipality/county
            in plot title
        * At most one of these can be not None!
    animate : bool
        if animate==True, then the map will animate using
        the frequency provided in animate_by
    color_dots : bool
        if color_dots==True, then dots will be color-coded by
        'serious injury or death' status.
        WARNING: if color_dots and animate, then all frames
        will be missing samples in 'serious injury or death'
        classes which aren't present in first frame - due to
        bug in plotly animation_frame implementation.
        Recommend only using both when geographic
        area is statewide or at least has all values of
        'serious injury or death' in first frame
    animate_by : str
        the desired animation frequency, must be
        either 'year' or 'month'
    show_fig : bool
        whether to display figure using fig.show()
    return_fig : bool
        whether to return the figure object

    Returns: Either figure or None
    --------
    """
    assert (city is None)|(county is None), 'A city and county cannot both be provided.'
    # Copy df and create new column for color coding event type
    # (death takes precedence over serious injury; everything else 'neither')
    df = df.copy()
    df.loc[df.BICYCLE_SUSP_SERIOUS_INJ_COUNT>0,'Serious cyclist injury or death']='serious injury'
    df.loc[df.BICYCLE_DEATH_COUNT>0,'Serious cyclist injury or death']='death'
    df['Serious cyclist injury or death']=df['Serious cyclist injury or death'].fillna('neither')

    # Set animation parameters
    if animate:
        if animate_by == 'year':
            animation_frame = 'CRASH_YEAR'
            title_animate = ' by year'
        elif animate_by == 'month':
            # Build a 'YYYY-MM' string column so frames sort chronologically
            df['DATE'] = pd.to_datetime((df['CRASH_MONTH'].astype('str')\
                                         +'-'+df['CRASH_YEAR'].astype('str')),
                                        format = "%m-%Y")
            df=df.sort_values(by='DATE')
            # rsplit drops the '-DD' day portion, leaving 'YYYY-MM'
            df['DATE']=df['DATE'].astype('str').apply(lambda x: x.rsplit('-',1)[0])
            animation_frame = 'DATE'
            title_animate = ' by month'
        else:
            raise ValueError("animate_by must be 'year' or 'month'")
    else:
        animation_frame = None
        title_animate = ''

    if color_dots:
        color='Serious cyclist injury or death'
    else:
        color=None

    # Adjustments for when city or county are provided
    if city is not None:
        df = df[df.MUNICIPALITY==city[0]]
        # Ignore extreme outlier samples - lat,lon may be incorrect
        df = df[np.abs(stats.zscore(df.DEC_LAT))<=4]
        df = df[np.abs(stats.zscore(df.DEC_LONG))<=4]
        title_place = city[1]+', PA'
    elif county is not None:
        df = df[df.COUNTY==county[0]]
        # Ignore extreme outlier samples - lat,lon may be incorrect
        df = df[np.abs(stats.zscore(df.DEC_LAT))<=4]
        df = df[np.abs(stats.zscore(df.DEC_LONG))<=4]
        title_place = county[1]+' county, PA'
    else:
        title_place = 'PA'

    # Compute default zoom level based on lat,lon ranges.
    # open-street-map uses
    max_lat, min_lat = df.DEC_LAT.max(), df.DEC_LAT.min()
    max_lon, min_lon = df.DEC_LONG.max(), df.DEC_LONG.min()

    # 2^(zoom) = 360/(longitude width of 1 tile)
    zoom = np.log2(360/max(max_lon-min_lon,max_lat-min_lat))

    lat_center = (max_lat+min_lat)/2
    lon_center = (max_lon+min_lon)/2

    # Adjust width so that aspect ratio matches shape of state
    width_mult = (max_lon-min_lon)/(max_lat-min_lat)
    # NOTE(review): 'cols' is built but never passed to px.scatter_mapbox -
    # appears to be dead code; confirm before removing
    cols = ['CRN','DEC_LAT','DEC_LONG','Serious cyclist injury or death','CRASH_YEAR','CRASH_MONTH']
    if animate_by=='month':
        cols.append('DATE')
    # Plot mapbox
    fig = px.scatter_mapbox(df, lat='DEC_LAT',lon='DEC_LONG',
                            color=color,
                            color_discrete_map={'neither':'royalblue','serious injury':'orange','death':'crimson'},
                            mapbox_style='open-street-map',
                            animation_frame = animation_frame,
                            animation_group='CRN',
                            hover_data = {'DEC_LAT':False,'DEC_LONG':False,
                                          'CRASH_YEAR':True,'CRASH_MONTH':True,
                                          'Serious cyclist injury or death':True},
                            width = width_mult*500,height=700,zoom=zoom,
                            center={'lat':lat_center,'lon':lon_center},
                            title=f'Crashes involving bicycles{title_animate}<br> in {title_place}, 2002-2021')
    # Horizontal legend below the map
    fig.update_layout(legend=dict(orientation='h',xanchor='right',yanchor='bottom',x=1,y=-0.12),
                      legend_title_side='top')
    if show_fig:
        fig.show()
    if return_fig:
        return fig
|
131 |
+
|
132 |
+
def feat_perc(feat, df, col_name='percentage', feat_name=None):
    """
    Build a single-column dataframe of the value counts of df[feat]
    expressed as fractions of the whole.

    - 'df' is the input dataframe.
    - 'feat' is the desired column of df.
    - 'col_name' names the column of the output dataframe.
    - 'feat_name', if provided, is used as the output index name;
      otherwise 'feat' is used.
    """
    fractions = df[feat].value_counts(normalize=True).sort_index()
    perc = pd.DataFrame({col_name: fractions})
    perc.index.name = feat_name if feat_name else feat
    return perc
|
151 |
+
|
152 |
+
def feat_perc_bar(feat, df, feat_name=None, cohort_name=None, show_fig=True, return_fig=False, sort=False):
    """
    Plot a stacked barplot comparing two distributions of a feature:
      - among all cyclists
      - among cyclists with serious injury or fatality

    Parameters:
    -----------
    feat : str
        The column name of the desired feature
    df : pd.DataFrame
        The input dataframe
    feat_name : str or None
        The feature name to use in the x-axis label; if None, `feat`
        is used
    cohort_name : str or None
        qualifier to use in front of 'cyclists' in titles, if
        provided, e.g. 'rural cyclists'
    show_fig : bool
        whether to finish with fig.show()
    return_fig : bool
        whether to return the fig object
    sort : bool
        whether to sort bars. If False, default sorting by category
        name or feature value is used. If True, bars are resorted in
        descending order by percentage.

    Returns: figure or None
    --------
    """
    if feat_name is None:
        feat_name = feat
    injured = df.query('SERIOUS_OR_FATALITY==1')

    # Percentage table for the full cohort
    table_all = feat_perc(feat, df)
    table_all.loc[:, 'cohort'] = 'all'
    ordering = list(table_all['percentage'].sort_values(ascending=False).index) if sort else None

    # Percentage table for the injured/killed cohort, then stack them
    table_inj = feat_perc(feat, injured)
    table_inj.loc[:, 'cohort'] = 'seriously injured or killed'
    combined = pd.concat([table_all, table_inj], axis=0).reset_index()

    category_orders = {'cohort': ['all', 'seriously injured or killed']}
    if sort:
        category_orders[feat] = ordering
    fig = px.bar(combined, y='cohort', x='percentage', color=feat,
                 barmode='stack', text_auto='.1%',
                 category_orders=category_orders,
                 title=f'Distributions of {feat} values within cyclist cohorts')
    fig.update_yaxes(tickangle=-90)
    fig.update_xaxes(tickformat=".0%")
    if show_fig:
        fig.show()
    if return_fig:
        return fig
|
204 |
+
|
205 |
+
# def feat_perc_comp(feat,df,feat_name=None,cohort_name = None,merge_inj_death=True):
|
206 |
+
# """
|
207 |
+
# Returns a styled dataframe (Styler object)
|
208 |
+
# whose underlying dataframe has three columns
|
209 |
+
# containing value counts of 'feat' among:
|
210 |
+
# - all cyclists involved in crashes
|
211 |
+
# - cyclists suffering serious injury or fatality
|
212 |
+
# each formatted as percentages of the series sum.
|
213 |
+
# Styled with bars comparing percentages
|
214 |
+
|
215 |
+
# Parameters:
|
216 |
+
# -----------
|
217 |
+
# feat : str
|
218 |
+
# The column name of the desired feature
|
219 |
+
# df : pd.DataFrame
|
220 |
+
# The input dataframe
|
221 |
+
# feat_name : str or None
|
222 |
+
# The feature name to use in the output dataframe
|
223 |
+
# index name. If None, will use feat
|
224 |
+
# cohort_name : str or None
|
225 |
+
# qualifier to use in front of 'cyclists'
|
226 |
+
# in titles, if provided, e.g. 'rural cyclists'
|
227 |
+
# merge_inj_death : bool
|
228 |
+
# whether to merge seriously injured and killed cohorts
|
229 |
+
# Returns:
|
230 |
+
# --------
|
231 |
+
# perc_comp : pd.Styler object
|
232 |
+
# """
|
233 |
+
# # Need qualifier for titles if restricting cyclist cohort
|
234 |
+
# qualifier = cohort_name if cohort_name is not None else ''
|
235 |
+
|
236 |
+
# # Two columns or three, depending on merge_inj_death
|
237 |
+
# if merge_inj_death:
|
238 |
+
# perc_comp = feat_perc(feat,df=df,feat_name=feat_name,
|
239 |
+
# col_name='all cyclists',)\
|
240 |
+
# .merge(feat_perc(feat,feat_name=feat_name,
|
241 |
+
# df=df.query('SERIOUS_OR_FATALITY==1'),
|
242 |
+
# col_name=qualifier+'cyclists with serious injury or fatality'),
|
243 |
+
# on=feat,how='left')
|
244 |
+
# perc_comp = perc_comp[perc_comp.max(axis=1)>=0.005]
|
245 |
+
# else:
|
246 |
+
# perc_comp = feat_perc(feat,df=df,feat_name=feat_name,
|
247 |
+
# col_name='all cyclists')\
|
248 |
+
# .merge(feat_perc(feat,feat_name=feat_name,
|
249 |
+
# df=df.query('INJ_SEVERITY=="susp_serious_injury"'),
|
250 |
+
# col_name=qualifier+'cyclists with serious injury'),
|
251 |
+
# on=feat,how='left')\
|
252 |
+
# .merge(feat_perc(feat,feat_name=feat_name,
|
253 |
+
# df=df.query('INJ_SEVERITY=="killed"'),
|
254 |
+
# col_name=qualifier+'cyclists with fatality'),
|
255 |
+
# on=feat,how='left')
|
256 |
+
|
257 |
+
# # If feature is not ordinal, sort rows descending by crash counts
|
258 |
+
# if feat not in ['AGE_BINS','SPEED_LIMIT','DAY_OF_WEEK','HOUR_OF_DAY']:
|
259 |
+
# perc_comp=perc_comp.sort_values(by='all cyclists',ascending=False)
|
260 |
+
|
261 |
+
# # Relabel day numbers with strings
|
262 |
+
# if feat == 'DAY_OF_WEEK':
|
263 |
+
# perc_comp.index=['Sun','Mon','Tues','Wed','Thurs','Fri','Sat']
|
264 |
+
# perc_comp.index.name='DAY_OF_WEEK'
|
265 |
+
# perc_comp=perc_comp.fillna(0)
|
266 |
+
# table_columns = list(perc_comp.columns)
|
267 |
+
|
268 |
+
# # Define format for displaying floats
|
269 |
+
# format_dict={col:'{:.2%}' for col in perc_comp.columns}
|
270 |
+
|
271 |
+
|
272 |
+
# # Define table styles
|
273 |
+
# styles = [dict(selector="caption",
|
274 |
+
# props=[("text-align", "center"),
|
275 |
+
# ("font-size", "100%"),
|
276 |
+
# ("color", 'black'),
|
277 |
+
# ("text-decoration","underline"),
|
278 |
+
# ("font-weight","bold")])]
|
279 |
+
|
280 |
+
# # Return formatted dataframe
|
281 |
+
# if feat_name is None:
|
282 |
+
# feat_name=feat
|
283 |
+
# caption = f'Breakdown of {feat_name} among cyclist groups'
|
284 |
+
# return perc_comp.reset_index().style.set_table_attributes("style='display:inline'")\
|
285 |
+
# .format(format_dict).bar(color='powderblue',
|
286 |
+
# subset=table_columns).hide().set_caption(caption)\
|
287 |
+
# .set_table_styles(styles)
|
lib/__init__.py
ADDED
File without changes
|
lib/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (222 Bytes). View file
|
|
lib/__pycache__/get_data.cpython-310.pyc
ADDED
Binary file (2.81 kB). View file
|
|
lib/__pycache__/study_class.cpython-310.pyc
ADDED
Binary file (28.3 kB). View file
|
|
lib/__pycache__/study_classif.cpython-310.pyc
ADDED
Binary file (28.9 kB). View file
|
|
lib/__pycache__/transform_data.cpython-310.pyc
ADDED
Binary file (3.88 kB). View file
|
|
lib/__pycache__/vis_data.cpython-310.pyc
ADDED
Binary file (6.64 kB). View file
|
|
lib/study_classif.py
ADDED
@@ -0,0 +1,787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import shap

from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.compose import ColumnTransformer, make_column_transformer
from sklearn.decomposition import PCA
from sklearn.ensemble import HistGradientBoostingClassifier, GradientBoostingClassifier
from sklearn.feature_selection import chi2, SelectKBest, mutual_info_classif, f_classif
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, log_loss, confusion_matrix, f1_score, fbeta_score, roc_auc_score
from sklearn.metrics import ConfusionMatrixDisplay, RocCurveDisplay, classification_report, precision_recall_curve
from sklearn.metrics import matthews_corrcoef, cohen_kappa_score
from sklearn.model_selection import train_test_split, RepeatedStratifiedKFold, cross_val_score, RandomizedSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler, FunctionTransformer, SplineTransformer, PolynomialFeatures
from sklearn.utils.validation import check_is_fitted
# from lightgbm import LGBMClassifier

from lib.transform_data import *
21 |
+
|
22 |
+
class ClassifierStudy():
|
23 |
+
"""
|
24 |
+
A class that contains tools for studying a classifier pipeline
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
-----------
|
28 |
+
classifier : a scikit-learn compatible binary classifier
|
29 |
+
X : pd.DataFrame
|
30 |
+
dataframe of features
|
31 |
+
y : pd.Series
|
32 |
+
series of binary target values corresponding to X
|
33 |
+
classifier_name : str or None
|
34 |
+
if provided, will use as classifier name in pipeline
|
35 |
+
if not, will use 'clf' as name
|
36 |
+
features : dict
|
37 |
+
a dictionary whose keys are the feature types
|
38 |
+
'cyc','cat','ord','num','bin' and whose values
|
39 |
+
are lists of features of each type.
|
40 |
+
|
41 |
+
Methods:
|
42 |
+
-------
|
43 |
+
set_data, set_features, set_state
|
44 |
+
sets or resets attributes of self
|
45 |
+
build_pipeline
|
46 |
+
builds out pipeline based on supplied specs
|
47 |
+
cv_score
|
48 |
+
runs k-fold cross validation and reports scores
|
49 |
+
randomized_search
|
50 |
+
runs randomized search with cross validation
|
51 |
+
and reports results
|
52 |
+
fit_pipeline
|
53 |
+
fits the model pipeline and stores as
|
54 |
+
self.pipe_fitted
|
55 |
+
predict_proba_pipeline
|
56 |
+
uses a fitted pipeline to compute predicted
|
57 |
+
probabilities for test or validation set
|
58 |
+
score_pipeline
|
59 |
+
scores predicted probabilities
|
60 |
+
|
61 |
+
"""
|
62 |
+
def __init__(self, classifier=None, X = None, y = None,
|
63 |
+
features = None,classifier_name = None,
|
64 |
+
random_state=42):
|
65 |
+
self.classifier = classifier
|
66 |
+
if X is not None:
|
67 |
+
self.X = X.copy()
|
68 |
+
if y is not None:
|
69 |
+
self.y = y.copy()
|
70 |
+
if features is not None:
|
71 |
+
self.features = features.copy()
|
72 |
+
self.random_state=random_state
|
73 |
+
self.pipe, self.pipe_fitted = None, None
|
74 |
+
self.classifier_name = classifier_name
|
75 |
+
self.X_val, self.y_val = None, None
|
76 |
+
self.y_predict_proba = None
|
77 |
+
self.best_params, self.best_n_components = None, None
|
78 |
+
self.shap_vals = None
|
79 |
+
|
80 |
+
def set_data(self,X=None,y=None):
|
81 |
+
"""Method to set or reset feature and/or target data"""
|
82 |
+
if X is not None:
|
83 |
+
self.X = X.copy()
|
84 |
+
if y is not None:
|
85 |
+
self.y = y.copy()
|
86 |
+
|
87 |
+
def set_features(self,features):
|
88 |
+
"""Method to set or reset the feature dictionary"""
|
89 |
+
if features is not None:
|
90 |
+
self.features = features.copy()
|
91 |
+
|
92 |
+
def set_state(self,random_state):
|
93 |
+
"""Method to set or reset the random_state"""
|
94 |
+
self.random_state = random_state
|
95 |
+
|
96 |
+
def build_pipeline(self, cat_method = 'onehot',cyc_method = 'spline',num_ss=True,
|
97 |
+
over_sample = False, pca=False,n_components=None,
|
98 |
+
select_features = False,score_func=None,k='all',
|
99 |
+
poly_features = False, degree=2, interaction_only=False):
|
100 |
+
"""
|
101 |
+
Method to build the model pipeline
|
102 |
+
Parameters:
|
103 |
+
-----------
|
104 |
+
cat_method : str
|
105 |
+
specifies whether to encode categorical
|
106 |
+
variables as one-hot vectors or ordinals
|
107 |
+
must be either 'onehot' or 'ord'
|
108 |
+
cyc_method : str
|
109 |
+
specifies whether to encode cyclical features
|
110 |
+
with sine/cosine encoding or periodic splines
|
111 |
+
must be one of 'trig', 'spline', 'interact-trig',
|
112 |
+
'interact-spline','onehot', 'ord', or None
|
113 |
+
- If 'trig' or 'spline', will set up periodic encoder
|
114 |
+
with desired method
|
115 |
+
- If 'onehot' or 'ord', will set up appropriate
|
116 |
+
categorical encoder
|
117 |
+
- If 'interact-{method}', will use <method> encoding for HOUR_OF_DAY,
|
118 |
+
encode DAY_OF_WEEK as a binary feature expressing whether
|
119 |
+
the day is a weekend day, and then include interaction
|
120 |
+
features among this set via PolynomialFeatures.
|
121 |
+
- If None, will leave out cyclical features altogether
|
122 |
+
num_ss : bool
|
123 |
+
Whether or not to apply StandardScaler on the numerical features
|
124 |
+
over_sample : bool
|
125 |
+
set to True to include imblearn.over_sampling.RandomOverSampler step
|
126 |
+
pca : bool
|
127 |
+
set to True to include sklearn.decomposition.PCA step
|
128 |
+
n_components : int or None
|
129 |
+
number of components for sklearn.decomposition.PCA
|
130 |
+
select_features : bool
|
131 |
+
set to True to include sklearn.feature_selection.SelectKBest step
|
132 |
+
score_func : callable
|
133 |
+
score function to use for sklearn.feature_selection.SelectKBest
|
134 |
+
recommended: chi2, f_classif, or mutual_info_classif
|
135 |
+
k : int or 'all'
|
136 |
+
number of features for sklearn.feature_selection.SelectKBest
|
137 |
+
poly_features : bool
|
138 |
+
set to True to include sklearn.preprocessing.PolynomialFeatures step
|
139 |
+
degree : int
|
140 |
+
max degree for sklearn.preprocessing.PolynomialFeatures
|
141 |
+
interaction_only : bool
|
142 |
+
whether or not sklearn.preprocessing.PolynomialFeatures will be limited
|
143 |
+
to interaction terms only
|
144 |
+
"""
|
145 |
+
|
146 |
+
# Define transformer for categorical features
|
147 |
+
if cat_method == 'onehot':
|
148 |
+
cat_encoder = ('ohe',OneHotEncoder(handle_unknown='infrequent_if_exist'))
|
149 |
+
|
150 |
+
elif cat_method == 'ord':
|
151 |
+
cat_encoder = ('oe',OrdinalEncoder(handle_unknown='use_encoded_value',unknown_value=np.nan))
|
152 |
+
else:
|
153 |
+
raise ValueError("cat_method must be either 'onehot' or 'ord'")
|
154 |
+
|
155 |
+
cat_transform = Pipeline([('si',SimpleImputer(strategy='most_frequent')),cat_encoder])
|
156 |
+
|
157 |
+
# Define transformer for cyclic features
|
158 |
+
cyc_dict = {'HOUR_OF_DAY':24,'DAY_OF_WEEK':7}
|
159 |
+
if cyc_method == 'trig':
|
160 |
+
cyc_transform = [(f'{feat}_cos',cos_transformer(cyc_dict[feat]),[feat]) for feat in self.features['cyc']]+\
|
161 |
+
[(f'{feat}_sin',sin_transformer(cyc_dict[feat]),[feat]) for feat in self.features['cyc']]
|
162 |
+
elif cyc_method =='spline':
|
163 |
+
cyc_transform = [(f'{feat}_cyclic',
|
164 |
+
periodic_spline_transformer(cyc_dict[feat],n_splines=cyc_dict[feat]//2),
|
165 |
+
[feat]) for feat in self.features['cyc']]
|
166 |
+
elif cyc_method == 'onehot':
|
167 |
+
cyc_encoder = ('ohe_cyc',OneHotEncoder(handle_unknown='infrequent_if_exist'))
|
168 |
+
cyc_transform = [('cyc',Pipeline([cyc_encoder]),self.features['cyc'])]
|
169 |
+
elif cyc_method == 'ord':
|
170 |
+
cyc_encoder = ('oe_cyc',OrdinalEncoder(handle_unknown='use_encoded_value',unknown_value=np.nan))
|
171 |
+
cyc_transform = [('cyc',Pipeline([cyc_encoder]),self.features['cyc'])]
|
172 |
+
elif cyc_method == 'interact-spline':
|
173 |
+
hour_transform = (f'hour_cyc',periodic_spline_transformer(cyc_dict['HOUR_OF_DAY'],n_splines=12),['HOUR_OF_DAY'])
|
174 |
+
wkend_transform = ('wkend',FunctionTransformer(lambda x: (x.isin([1,7])).astype(int)),['DAY_OF_WEEK'])
|
175 |
+
cyc_transform = [('cyc',Pipeline([('cyc_col',ColumnTransformer([hour_transform, wkend_transform],
|
176 |
+
remainder='drop',verbose_feature_names_out=False)),
|
177 |
+
('cyc_poly',PolynomialFeatures(degree=2,interaction_only=True,
|
178 |
+
include_bias=False))]),
|
179 |
+
self.features['cyc'])]
|
180 |
+
elif cyc_method == 'interact-trig':
|
181 |
+
hour_transform = [(f'HOUR_cos',cos_transformer(cyc_dict['HOUR_OF_DAY']),['HOUR_OF_DAY']),
|
182 |
+
(f'HOUR_sin',sin_transformer(cyc_dict['HOUR_OF_DAY']),['HOUR_OF_DAY'])]
|
183 |
+
wkend_transform = ('wkend',FunctionTransformer(lambda x: (x.isin([1,7])).astype(int)),['DAY_OF_WEEK'])
|
184 |
+
cyc_transform = [('cyc',Pipeline([('cyc_col',ColumnTransformer(hour_transform+[wkend_transform],
|
185 |
+
remainder='drop',verbose_feature_names_out=False)),
|
186 |
+
('cyc_poly',PolynomialFeatures(degree=2,interaction_only=True,
|
187 |
+
include_bias=False))]),
|
188 |
+
self.features['cyc'])]
|
189 |
+
elif cyc_method is None:
|
190 |
+
cyc_transform = [('cyc','passthrough',[])]
|
191 |
+
else:
|
192 |
+
raise ValueError("cyc_method must be one of 'trig','spline','interact','onehot','ord',or None")
|
193 |
+
|
194 |
+
# Define numerical transform
|
195 |
+
num_transform = ('num',StandardScaler(),self.features['num']) if num_ss else\
|
196 |
+
('num','passthrough',self.features['num'])
|
197 |
+
|
198 |
+
# Define column transformer
|
199 |
+
col_transform = ColumnTransformer([('cat',cat_transform,self.features['cat']),
|
200 |
+
('ord','passthrough',self.features['ord']),
|
201 |
+
num_transform,
|
202 |
+
('bin',SimpleImputer(strategy='most_frequent'),
|
203 |
+
self.features['bin'])]+\
|
204 |
+
cyc_transform,
|
205 |
+
remainder='drop',verbose_feature_names_out=False)
|
206 |
+
|
207 |
+
steps = [('col',col_transform)]
|
208 |
+
|
209 |
+
if 'AGE' in self.features['num']:
|
210 |
+
steps.insert(0,('gi_age',GroupImputer(target = 'AGE', group_cols=['COUNTY'],strategy='median')))
|
211 |
+
if 'HOUR_OF_DAY' in self.features['cyc']:
|
212 |
+
steps.insert(0,('gi_hour',GroupImputer(target = 'HOUR_OF_DAY', group_cols=['ILLUMINATION','CRASH_MONTH'],strategy='mode')))
|
213 |
+
# Insert optional steps as needed
|
214 |
+
if over_sample:
|
215 |
+
steps.insert(0,('os',RandomOverSampler(random_state=self.random_state)))
|
216 |
+
if poly_features:
|
217 |
+
steps.append(('pf',PolynomialFeatures(degree=degree,interaction_only=interaction_only)))
|
218 |
+
if select_features:
|
219 |
+
steps.append(('fs',SelectKBest(score_func = score_func, k = k)))
|
220 |
+
if pca:
|
221 |
+
steps.append(('pca',PCA(n_components=n_components,random_state=self.random_state)))
|
222 |
+
# Append classifier if provided
|
223 |
+
if self.classifier is not None:
|
224 |
+
if self.classifier_name is not None:
|
225 |
+
steps.append((f'{self.classifier_name}_clf',self.classifier))
|
226 |
+
else:
|
227 |
+
steps.append(('clf',self.classifier))
|
228 |
+
|
229 |
+
# Initialize pipeline
|
230 |
+
self.pipe = Pipeline(steps)
|
231 |
+
|
232 |
+
def cv_score(self, scoring = 'roc_auc', n_splits = 5, n_repeats=3, thresh = 0.5, beta = 1,
|
233 |
+
return_mean_score=False,print_mean_score=True,print_scores=False, n_jobs=-1,
|
234 |
+
eval_size=0.1,eval_metric='auc'):
|
235 |
+
"""
|
236 |
+
Method for performing cross validation via RepeatedStratifiedKFold
|
237 |
+
|
238 |
+
Parameters:
|
239 |
+
-----------
|
240 |
+
scoring : str
|
241 |
+
scoring function to use. must be one of
|
242 |
+
'roc_auc','acc','f1','','f1w'
|
243 |
+
thresh : float
|
244 |
+
the classification threshold for computing y_pred
|
245 |
+
from y_pred_proba
|
246 |
+
beta : float
|
247 |
+
the beta-value to use in the f_beta score, if chosen
|
248 |
+
n_splits, n_repeats : int, int
|
249 |
+
number of splits and number of repeat iterations
|
250 |
+
for sklearn.model_selection.RepeatedStratifiedKFold
|
251 |
+
return_mean_score : bool
|
252 |
+
whether or not to return the mean score
|
253 |
+
print_mean_score : bool
|
254 |
+
whether to print out a report of the mean score
|
255 |
+
print_scores : bool
|
256 |
+
whether to print out a report of CV scores for all folds
|
257 |
+
n_jobs : int or None
|
258 |
+
number of CPU cores to use for parallel processing
|
259 |
+
-1 uses all available cores, and None defaults to 1
|
260 |
+
eval_size : float
|
261 |
+
Fraction of the training set to use for early stopping eval set
|
262 |
+
eval_metric : str
|
263 |
+
eval metric to use in early stopping
|
264 |
+
Returns: None or mean_score, depending on return_mean_score setting
|
265 |
+
--------
|
266 |
+
"""
|
267 |
+
assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
|
268 |
+
assert (self.X is not None)&(self.y is not None), 'X and/or y does not exist. First supply X and y using set_data.'
|
269 |
+
assert 'clf' in self.pipe.steps[-1][0], 'The pipeline has no classifier. Build a pipeline with a classifier first.'
|
270 |
+
assert scoring in ['roc_auc','acc','f1','fb','f1w'],"scoring must be one of 'roc_auc','acc','f1','fb','f1w'"
|
271 |
+
|
272 |
+
# Initialize CV iterator
|
273 |
+
kf = RepeatedStratifiedKFold(n_splits = n_splits, n_repeats=n_repeats,
|
274 |
+
random_state=self.random_state)
|
275 |
+
# Restrict to features supplied in self.features
|
276 |
+
X = self.X[[feat for feat_type in self.features for feat in self.features[feat_type]]]
|
277 |
+
|
278 |
+
lgb_es=False
|
279 |
+
# if isinstance(self.pipe[-1],LGBMClassifier):
|
280 |
+
# if 'early_stopping_round' in self.pipe[-1].get_params():
|
281 |
+
# if self.pipe[-1].get_params()['early_stopping_rounds'] is not None:
|
282 |
+
# lgb_es=True
|
283 |
+
|
284 |
+
scores = []
|
285 |
+
# Iterate over folds and train, predict, score
|
286 |
+
for i,(train_idx,test_idx) in enumerate(kf.split(X,self.y)):
|
287 |
+
fold_X_train = X.iloc[train_idx,:]
|
288 |
+
fold_X_test = X.iloc[test_idx,:]
|
289 |
+
fold_y_train = self.y.iloc[train_idx]
|
290 |
+
fold_y_test = self.y.iloc[test_idx]
|
291 |
+
|
292 |
+
pipe=clone(self.pipe)
|
293 |
+
if lgb_es:
|
294 |
+
fold_X_train,fold_X_es,fold_y_train,fold_y_es = train_test_split(fold_X_train,fold_y_train,
|
295 |
+
stratify=fold_y_train,test_size=eval_size,
|
296 |
+
random_state=self.random_state)
|
297 |
+
trans_pipe = pipe[:-1]
|
298 |
+
trans_pipe.fit_transform(fold_X_train)
|
299 |
+
fold_X_es = trans_pipe.transform(fold_X_es)
|
300 |
+
clf_name = pipe.steps[-1][0]
|
301 |
+
fit_params = {f'{clf_name}__eval_set':[(fold_X_es,fold_y_es)],
|
302 |
+
f'{clf_name}__eval_metric':eval_metric,
|
303 |
+
f'{clf_name}__verbose':0}
|
304 |
+
else:
|
305 |
+
fit_params = {}
|
306 |
+
|
307 |
+
pipe.fit(fold_X_train,fold_y_train,**fit_params)
|
308 |
+
fold_y_pred_proba = pipe.predict_proba(fold_X_test)[:,1]
|
309 |
+
|
310 |
+
if scoring == 'roc_auc':
|
311 |
+
fold_score = roc_auc_score(fold_y_test, fold_y_pred_proba)
|
312 |
+
else:
|
313 |
+
fold_y_pred = (fold_y_pred_proba >= thresh).astype('int')
|
314 |
+
if scoring == 'acc':
|
315 |
+
fold_score = accuracy_score(fold_y_test,fold_y_pred)
|
316 |
+
elif scoring == 'f1':
|
317 |
+
fold_score = f1_score(fold_y_test,fold_y_pred)
|
318 |
+
elif scoring == 'f1w':
|
319 |
+
fold_score = f1_score(fold_y_test,fold_y_pred,average='weighted')
|
320 |
+
else:
|
321 |
+
fold_score = fbeta_score(fold_y_test,fold_y_pred,beta=beta)
|
322 |
+
scores.append(fold_score)
|
323 |
+
|
324 |
+
# Average and report
|
325 |
+
mean_score = np.mean(scores)
|
326 |
+
if print_scores:
|
327 |
+
print(f'CV scores using {scoring} score: {scores} \nMean score: {mean_score}')
|
328 |
+
if print_mean_score:
|
329 |
+
print(f'Mean CV {scoring} score: {mean_score}')
|
330 |
+
if return_mean_score:
|
331 |
+
return mean_score
|
332 |
+
|
333 |
+
def randomized_search(self, params, n_components = None, n_iter=10,
|
334 |
+
scoring='roc_auc',cv=5,refit=False,top_n=10, n_jobs=-1):
|
335 |
+
"""
|
336 |
+
Method for performing randomized search with cross validation on a given dictionary of parameter distributions
|
337 |
+
Also displays a table of results the best top_n iterations
|
338 |
+
|
339 |
+
Parameters:
|
340 |
+
----------
|
341 |
+
params : dict
|
342 |
+
parameter distributions to use for RandomizedSearchCV
|
343 |
+
n_components : int, or list, or None
|
344 |
+
number of components for sklearn.decomposition.PCA
|
345 |
+
- if int, will reset the PCA layer in self.pipe with provided value
|
346 |
+
- if list, must be list of ints, which will be included in
|
347 |
+
RandomizedSearchCV parameter distribution
|
348 |
+
scoring : str
|
349 |
+
scoring function for sklearn.model_selection.cross_val_score
|
350 |
+
n_iter : int
|
351 |
+
number of iterations to use in RandomizedSearchCV
|
352 |
+
refit : bool
|
353 |
+
whether to refit a final classifier with best parameters
|
354 |
+
- if False, will only set self.best_params and self.best_score
|
355 |
+
- if True, will set self.best_estimator in addition
|
356 |
+
top_n : int or None
|
357 |
+
if int, will display results from top_n best iterations only
|
358 |
+
if None, will display all results
|
359 |
+
n_jobs : int or None
|
360 |
+
number of CPU cores to use for parallel processing
|
361 |
+
-1 uses all available cores, and None defaults to 1
|
362 |
+
"""
|
363 |
+
assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
|
364 |
+
assert (self.X is not None)&(self.y is not None), 'X and/or y does not exist. First supply X and y using set_data.'
|
365 |
+
assert 'clf' in self.pipe.steps[-1][0], 'The pipeline has no classifier. Build a pipeline with a classifier first.'
|
366 |
+
assert (n_components is None)|('pca' in self.pipe.named_steps), 'Your pipeline has no PCA step. Build a pipeline with PCA first.'
|
367 |
+
assert (len(params)>0)|(type(n_components)==list), 'Either pass a parameter distribution or a list of n_components values.'
|
368 |
+
|
369 |
+
# Add estimator name prefix to hyperparams
|
370 |
+
params = {self.pipe.steps[-1][0]+'__'+key:params[key] for key in params}
|
371 |
+
|
372 |
+
# Process supplied n_components
|
373 |
+
if type(n_components)==list:
|
374 |
+
params['pca__n_components']=n_components
|
375 |
+
elif type(n_components)==int:
|
376 |
+
self.pipe['pca'].set_params(n_components=n_components)
|
377 |
+
|
378 |
+
# Restrict to features supplied in self.features
|
379 |
+
X = self.X[[feat for feat_type in self.features for feat in self.features[feat_type]]]
|
380 |
+
|
381 |
+
# Initialize rs and fit
|
382 |
+
rs = RandomizedSearchCV(self.pipe, param_distributions = params,
|
383 |
+
n_iter=n_iter, scoring = scoring, cv = cv,refit=refit,
|
384 |
+
random_state=self.random_state, n_jobs=n_jobs)
|
385 |
+
|
386 |
+
rs.fit(X,self.y)
|
387 |
+
|
388 |
+
# Display top n scores
|
389 |
+
results = rs.cv_results_
|
390 |
+
results_df = pd.DataFrame(results['params'])
|
391 |
+
param_names = list(results_df.columns)
|
392 |
+
results_df[f'mean cv score ({scoring})']=pd.Series(results['mean_test_score'])
|
393 |
+
results_df = results_df.set_index(param_names).sort_values(by=f'mean cv score ({scoring})',ascending=False)
|
394 |
+
if top_n is not None:
|
395 |
+
display(results_df.head(top_n).style\
|
396 |
+
.highlight_max(axis=0, props='color:white; font-weight:bold; background-color:seagreen;'))
|
397 |
+
else:
|
398 |
+
display(results_df.style\
|
399 |
+
.highlight_max(axis=0, props='color:white; font-weight:bold; background-color:seagreen;'))
|
400 |
+
if refit:
|
401 |
+
self.best_estimator = rs.best_estimator_
|
402 |
+
best_params = rs.best_params_
|
403 |
+
self.best_params = {key.split('__')[-1]:best_params[key] for key in best_params if key.split('__')[0]!='pca'}
|
404 |
+
self.best_n_components = next((best_params[key] for key in best_params if key.split('__')[0]=='pca'), None)
|
405 |
+
self.best_score = rs.best_score_
|
406 |
+
|
407 |
+
def fit_pipeline(self,split_first=False, eval_size=0.1,eval_metric='auc'):
|
408 |
+
"""
|
409 |
+
Method for fitting self.pipeline on self.X,self.y
|
410 |
+
Parameters:
|
411 |
+
-----------
|
412 |
+
split_first : bool
|
413 |
+
if True, a train_test_split will be performed first
|
414 |
+
and the validation set will be stored
|
415 |
+
early_stopping : bool
|
416 |
+
Indicates whether we will use early_stopping for lightgbm.
|
417 |
+
If true, will split off an eval set prior to k-fold split
|
418 |
+
eval_size : float
|
419 |
+
Fraction of the training set to use for early stopping eval set
|
420 |
+
eval_metric : str
|
421 |
+
eval metric to use in early stopping
|
422 |
+
"""
|
423 |
+
# Need pipe and X to fit
|
424 |
+
assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
|
425 |
+
assert self.X is not None, 'X does not exist. First set X.'
|
426 |
+
|
427 |
+
# If no y provided, then no pipeline steps should require y
|
428 |
+
step_list = [step[0] for step in self.pipe.steps]
|
429 |
+
assert (('clf' not in step_list[-1])&('kf' not in step_list))|(self.y is not None), 'You must provide targets y if pipeline has a classifier step or feature selection step.'
|
430 |
+
|
431 |
+
# Don't need to do a train-test split without a classifier
|
432 |
+
assert (split_first==False)|('clf' in step_list[-1]), 'Only need train-test split if you have a classifier.'
|
433 |
+
|
434 |
+
if split_first:
|
435 |
+
X_train,X_val,y_train,y_val = train_test_split(self.X,self.y,stratify=self.y,
|
436 |
+
test_size=0.2,random_state=self.random_state)
|
437 |
+
self.X_val = X_val
|
438 |
+
self.y_val = y_val
|
439 |
+
else:
|
440 |
+
X_train = self.X.copy()
|
441 |
+
if self.y is not None:
|
442 |
+
y_train = self.y.copy()
|
443 |
+
|
444 |
+
# Restrict to features supplied in self.features
|
445 |
+
X_train = X_train[[feat for feat_type in self.features for feat in self.features[feat_type]]]
|
446 |
+
|
447 |
+
# If LGBM early stopping, then need to split off eval_set and define fit_params
|
448 |
+
# if isinstance(self.pipe[-1],LGBMClassifier):
|
449 |
+
# if self.pipe[-1].get_params()['early_stopping_rounds'] is not None:
|
450 |
+
# X_train,X_es,y_train,y_es = train_test_split(X_train,y_train,
|
451 |
+
# test_size=eval_size,
|
452 |
+
# stratify=y_train,
|
453 |
+
# random_state=self.random_state)
|
454 |
+
# trans_pipe = self.pipe[:-1]
|
455 |
+
# trans_pipe.fit_transform(X_train)
|
456 |
+
# X_es = trans_pipe.transform(X_es)
|
457 |
+
# clf_name = self.pipe.steps[-1][0]
|
458 |
+
# fit_params = {f'{clf_name}__eval_set':[(X_es,y_es)],
|
459 |
+
# f'{clf_name}__eval_metric':eval_metric,
|
460 |
+
# f'{clf_name}__verbose':0}
|
461 |
+
# else:
|
462 |
+
# fit_params = {}
|
463 |
+
# else:
|
464 |
+
# fit_params = {}
|
465 |
+
fit_params = {}
|
466 |
+
# Fit and store fitted pipeline. If no classifier, fit_transform X_train and store transformed version
|
467 |
+
pipe = self.pipe
|
468 |
+
if 'clf' in step_list[-1]:
|
469 |
+
pipe.fit(X_train,y_train,**fit_params)
|
470 |
+
else:
|
471 |
+
X_transformed = pipe.fit_transform(X_train)
|
472 |
+
# X_transformed = pd.DataFrame(X_transformed,columns=pipe[-1].get_column_names_out())
|
473 |
+
self.X_transformed = X_transformed
|
474 |
+
self.pipe_fitted = pipe
|
475 |
+
|
476 |
+
def predict_proba_pipeline(self, X_test = None):
|
477 |
+
"""
|
478 |
+
Method for using a fitted pipeline to compute predicted
|
479 |
+
probabilities for X_test (if supplied) or self.X_val
|
480 |
+
Parameters:
|
481 |
+
-----------
|
482 |
+
X_test : pd.DataFrame or None
|
483 |
+
test data input features (if None, will use self.X_val)
|
484 |
+
"""
|
485 |
+
assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
|
486 |
+
assert 'clf' in self.pipe.steps[-1][0], 'The pipeline has no classifier. Build a pipeline with a classifier first.'
|
487 |
+
assert self.pipe_fitted is not None, 'Pipeline is not fitted. First fit pipeline using fit_pipeline.'
|
488 |
+
assert (X_test is not None)|(self.X_val is not None), 'Must either provide X_test and y_test or fit the pipeline with split_first=True.'
|
489 |
+
|
490 |
+
if X_test is None:
|
491 |
+
X_test = self.X_val
|
492 |
+
|
493 |
+
# Restrict to features supplied in self.features
|
494 |
+
X_test = X_test[[feat for feat_type in self.features for feat in self.features[feat_type]]]
|
495 |
+
|
496 |
+
# Save prediction
|
497 |
+
self.y_predict_proba = self.pipe_fitted.predict_proba(X_test)[:,1]
|
498 |
+
|
499 |
+
def score_pipeline(self,y_test=None,scoring='roc_auc',thresh=0.5, beta = 1,
|
500 |
+
normalize = None, print_score = True):
|
501 |
+
"""
|
502 |
+
Method for scoring self.pipe_fitted on supplied test data and reporting score
|
503 |
+
Parameters:
|
504 |
+
-----------
|
505 |
+
y_test : pd.Series or None
|
506 |
+
true binary targets (if None, will use self.y_val)
|
507 |
+
scoring : str
|
508 |
+
specifies the metric to use for scoring
|
509 |
+
must be one of
|
510 |
+
'roc_auc', 'roc_plot', 'acc', 'f1', 'f1w', 'fb','mcc','kappa','conf','classif_report'
|
511 |
+
thresh : float
|
512 |
+
threshhold value for computing y_pred
|
513 |
+
from y_predict_proba
|
514 |
+
beta : float
|
515 |
+
the beta parameter in the fb score
|
516 |
+
normalize : str or None
|
517 |
+
the normalize parameter for the
|
518 |
+
confusion_matrix. must be one of
|
519 |
+
'true','pred','all',None
|
520 |
+
print_score : bool
|
521 |
+
if True, will print a message reporting the score
|
522 |
+
if False, will return the score as a float
|
523 |
+
"""
|
524 |
+
assert (y_test is not None)|(self.y_val is not None), 'Must either provide X_test and y_test or fit the pipeline with split_first=True.'
|
525 |
+
assert self.y_predict_proba is not None, 'Predicted probabilities do not exist. Run predict_proba_pipeline first.'
|
526 |
+
|
527 |
+
if y_test is None:
|
528 |
+
y_test = self.y_val
|
529 |
+
|
530 |
+
# Score and report
|
531 |
+
if scoring == 'roc_plot':
|
532 |
+
fig = plt.figure(figsize=(4,4))
|
533 |
+
ax = fig.add_subplot(111)
|
534 |
+
RocCurveDisplay.from_predictions(y_test,self.y_predict_proba,ax=ax)
|
535 |
+
plt.show()
|
536 |
+
elif scoring == 'roc_auc':
|
537 |
+
score = roc_auc_score(y_test, self.y_predict_proba)
|
538 |
+
else:
|
539 |
+
y_pred = (self.y_predict_proba >= thresh).astype('int')
|
540 |
+
if scoring == 'acc':
|
541 |
+
score = accuracy_score(y_test,y_pred)
|
542 |
+
elif scoring == 'f1':
|
543 |
+
score = f1_score(y_test,y_pred)
|
544 |
+
elif scoring == 'f1w':
|
545 |
+
score = f1_score(y_test,y_pred,average='weighted')
|
546 |
+
elif scoring == 'fb':
|
547 |
+
score = fbeta_score(y_test,y_pred,beta=beta)
|
548 |
+
elif scoring == 'mcc':
|
549 |
+
score = matthews_coffcoeff(y_test,y_pred)
|
550 |
+
elif scoring == 'kappa':
|
551 |
+
score = cohen_kappa_score(y_test,y_pred)
|
552 |
+
elif scoring == 'conf':
|
553 |
+
fig = plt.figure(figsize=(3,3))
|
554 |
+
ax = fig.add_subplot(111)
|
555 |
+
ConfusionMatrixDisplay.from_predictions(y_test,y_pred,ax=ax,colorbar=False)
|
556 |
+
plt.show()
|
557 |
+
elif scoring == 'classif_report':
|
558 |
+
target_names=['neither seriously injured nor killed','seriously injured or killed']
|
559 |
+
print(classification_report(y_test, y_pred,target_names=target_names))
|
560 |
+
else:
|
561 |
+
raise ValueError("scoring must be one of 'roc_auc', 'roc_plot','acc', 'f1', 'f1w', 'fb','mcc','kappa','conf','classif_report'")
|
562 |
+
if scoring not in ['conf','roc_plot','classif_report']:
|
563 |
+
if print_score:
|
564 |
+
print(f'The {scoring} score is: {score}')
|
565 |
+
else:
|
566 |
+
return score
|
567 |
+
|
568 |
+
def shap_values(self, X_test = None, eval_size=0.1,eval_metric='auc'):
|
569 |
+
"""
|
570 |
+
Method for computing and SHAP values for features
|
571 |
+
stratifiedtrain/test split
|
572 |
+
A copy of self.pipe is fitted on the training set
|
573 |
+
and then SHAP values are computed on test set samples
|
574 |
+
Parameters:
|
575 |
+
-----------
|
576 |
+
X_test : pd.DataFrame
|
577 |
+
The test set; if provided, will not perform
|
578 |
+
a train/test split before fitting
|
579 |
+
eval_size : float
|
580 |
+
Fraction of the training set to use for early stopping eval set
|
581 |
+
eval_metric : str
|
582 |
+
eval metric to use in early stopping
|
583 |
+
Returns: None (stores results in self.shap_vals)
|
584 |
+
--------
|
585 |
+
"""
|
586 |
+
assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
|
587 |
+
assert (self.X is not None)&(self.y is not None), 'X and/or y does not exist. First supply X and y using set_data.'
|
588 |
+
assert 'clf' in self.pipe.steps[-1][0], 'The pipeline has no classifier. Build a pipeline with a classifier first.'
|
589 |
+
|
590 |
+
|
591 |
+
# Clone pipeline, do train/test split if X_test not provided
|
592 |
+
pipe = clone(self.pipe)
|
593 |
+
X_train = self.X.copy()
|
594 |
+
y_train = self.y.copy()
|
595 |
+
if X_test is None:
|
596 |
+
X_train,X_test,y_train,y_test = train_test_split(X_train,y_train,stratify=y_train,
|
597 |
+
test_size=0.2,random_state=self.random_state)
|
598 |
+
# Restrict to features provided in self.features, and fit
|
599 |
+
X_train = X_train[[feat for feat_type in self.features for feat in self.features[feat_type]]]
|
600 |
+
X_test = X_test[[feat for feat_type in self.features for feat in self.features[feat_type]]]
|
601 |
+
|
602 |
+
# If LGBM early stopping, then need to split off eval_set and define fit_params
|
603 |
+
# if isinstance(self.pipe[-1],LGBMClassifier):
|
604 |
+
# if 'early_stopping_round' in self.pipe[-1].get_params():
|
605 |
+
# if self.pipe[-1].get_params()['early_stopping_rounds'] is not None:
|
606 |
+
# X_train,X_es,y_train,y_es = train_test_split(X_train,y_train,
|
607 |
+
# test_size=eval_size,
|
608 |
+
# stratify=y_train,
|
609 |
+
# random_state=self.random_state)
|
610 |
+
# trans_pipe = self.pipe[:-1]
|
611 |
+
# trans_pipe.fit_transform(X_train)
|
612 |
+
# X_es = trans_pipe.transform(X_es)
|
613 |
+
# clf_name = self.pipe.steps[-1][0]
|
614 |
+
# fit_params = {f'{clf_name}__eval_set':[(X_es,y_es)],
|
615 |
+
# f'{clf_name}__eval_metric':eval_metric,
|
616 |
+
# f'{clf_name}__verbose':0}
|
617 |
+
# else:
|
618 |
+
# fit_params = {}
|
619 |
+
# else:
|
620 |
+
# fit_params = {}
|
621 |
+
fit_params = {}
|
622 |
+
pipe.fit(X_train,y_train,**fit_params)
|
623 |
+
|
624 |
+
# SHAP will just explain classifier, so need transformed X_train and X_test
|
625 |
+
X_train_trans, X_test_trans = pipe[:-1].transform(X_train), pipe[:-1].transform(X_test)
|
626 |
+
|
627 |
+
# Need masker for linear model
|
628 |
+
masker = shap.maskers.Independent(data=X_train_trans)
|
629 |
+
|
630 |
+
# Initialize explainer and compute and store SHAP values as an explainer object
|
631 |
+
explainer = shap.Explainer(pipe[-1], masker = masker, feature_names = pipe['col'].get_feature_names_out())
|
632 |
+
self.shap_vals = explainer(X_test_trans)
|
633 |
+
self.X_shap = X_train_trans
|
634 |
+
self.y_shap = y_train
|
635 |
+
|
636 |
+
def shap_plot(self,max_display='all'):
|
637 |
+
"""
|
638 |
+
Method for generating plots of SHAP value results
|
639 |
+
SHAP values should be already computed previously
|
640 |
+
Generates two plots side by side:
|
641 |
+
- a beeswarm plot of SHAP values of all samples
|
642 |
+
- a barplot of mean absolute SHAP values
|
643 |
+
Parameters:
|
644 |
+
-----------
|
645 |
+
max_display : int or 'all'
|
646 |
+
The number of features to show in the plot, in descending
|
647 |
+
order by mean absolute SHAP value. If 'all', then
|
648 |
+
all features will be included.
|
649 |
+
|
650 |
+
Returns: None (plots displayed)
|
651 |
+
--------
|
652 |
+
"""
|
653 |
+
assert self.shap_vals is not None, 'No shap values exist. First compute shap values.'
|
654 |
+
assert (isinstance(max_display,int))|(max_display=='all'), "'max_display' must be 'all' or an integer"
|
655 |
+
|
656 |
+
if max_display=='all':
|
657 |
+
title_add = ', all features'
|
658 |
+
max_display = self.shap_vals.shape[1]
|
659 |
+
else:
|
660 |
+
title_add = f', top {max_display} features'
|
661 |
+
|
662 |
+
# Plot
|
663 |
+
fig=plt.figure()
|
664 |
+
ax1=fig.add_subplot(121)
|
665 |
+
shap.summary_plot(self.shap_vals,plot_type='bar',max_display=max_display,
|
666 |
+
show=False,plot_size=0.2)
|
667 |
+
ax2=fig.add_subplot(122)
|
668 |
+
shap.summary_plot(self.shap_vals,plot_type='violin',max_display=max_display,
|
669 |
+
show=False,plot_size=0.2)
|
670 |
+
fig.set_size_inches(12,max_display/3)
|
671 |
+
|
672 |
+
ax1.set_title(f'Mean absolute SHAP values'+title_add,fontsize='small')
|
673 |
+
ax1.set_xlabel('mean(|SHAP value|)',fontsize='x-small')
|
674 |
+
ax2.set_title(f'SHAP values'+title_add,fontsize='small')
|
675 |
+
ax2.set_xlabel('SHAP value', fontsize='x-small')
|
676 |
+
for ax in [ax1,ax2]:
|
677 |
+
ax.set_ylabel('feature name',fontsize='x-small')
|
678 |
+
ax.tick_params(axis='y', labelsize='xx-small')
|
679 |
+
plt.tight_layout()
|
680 |
+
plt.show()
|
681 |
+
|
682 |
+
def find_best_threshold(self, beta=1, conf=True, report=True, print_result=True):
    """
    Find the classification threshold maximizing the F_beta score.

    Sweeps the precision-recall curve of the stored validation-set
    probability predictions, picks the threshold with the highest
    F_beta score, optionally prints it, and optionally shows a
    confusion matrix and classification report at that threshold.

    Parameters:
    -----------
    beta : float
        the desired beta value in the F_beta score
    conf : bool
        whether to display confusion matrix
    report : bool
        whether to display classification report
    print_result : bool
        whether to print a line reporting the best threshold
        and resulting F_beta score

    Returns: None (prints results and stores self.best_thresh)
    --------
    """
    prec, rec, threshs = precision_recall_curve(self.y_val, self.y_predict_proba)
    f_scores = (1 + beta ** 2) * (prec * rec) / ((beta ** 2 * prec) + rec)
    # The closed-form F_beta is valid only when TP != 0; when TP == 0 it
    # evaluates to nan, but the correct score there is 0.
    f_scores = np.nan_to_num(f_scores)
    best_idx = np.argmax(f_scores)
    best_thresh = threshs[best_idx]
    if print_result:
        print(f'Threshold optimizing F_{beta} score: {best_thresh}\nBest F_{beta} score: {f_scores[best_idx]}')
    if conf:
        self.score_pipeline(scoring='conf', thresh=best_thresh, beta=beta)
    if report:
        self.score_pipeline(scoring='classif_report', thresh=best_thresh, beta=beta)
    self.best_thresh = best_thresh
|
720 |
+
|
721 |
+
class LRStudy(ClassifierStudy):
    """
    A child class of ClassifierStudy which has an additional method
    specific to logistic regression.
    """
    def __init__(self, classifier=None, X=None, y=None,
                 features=None, classifier_name='LR',
                 random_state=42):
        super().__init__(classifier, X, y, features, classifier_name, random_state)

    def plot_coeff(self, print_score=True, print_zero=False, title_add=None):
        """
        Method for doing a train/validation split, fitting the classifier,
        predicting and scoring on the validation set, and plotting
        a bar chart of the logistic regression coefficients corresponding
        to various model features.
        Features with coefficient zero and periodic spline features
        will be excluded from the chart.

        Parameters:
        -----------
        print_score : bool
            if True, the validation score is printed
        print_zero : bool
            if True, the list of features with zero coefficients is printed
        title_add : str or None
            an addendum that is added to the end of the plot title

        Returns: None (stores self.score, self.coeff, self.coeff_zero_features)
        --------
        """
        assert self.pipe is not None, 'No pipeline exists; first build a pipeline using build_pipeline.'
        assert isinstance(self.classifier, LogisticRegression), 'Your classifier is not an instance of Logistic Regression.'

        # fit and score
        self.fit_pipeline(split_first=True)
        self.predict_proba_pipeline()
        score = roc_auc_score(self.y_val, self.y_predict_proba)

        # Retrieve coefficient values from the fitted pipeline
        coeff = pd.DataFrame({'feature name': self.pipe_fitted['col'].get_feature_names_out(),
                              'coeff value': self.pipe_fitted[-1].coef_.reshape(-1)})\
                  .sort_values(by='coeff value')
        # Exclude periodic spline features - their individual log-odds
        # coefficients have no direct interpretation
        spline_features = [f'HOUR_OF_DAY_sp_{n}' for n in range(12)] \
                          + [f'DAY_OF_WEEK_sp_{n}' for n in range(3)]
        coeff = coeff[~coeff['feature name'].isin(spline_features)]\
                    .set_index('feature name')
        coeff_zero_features = coeff[coeff['coeff value'] == 0].index
        coeff = coeff[coeff['coeff value'] != 0]

        # Plot feature coefficients
        fig = plt.figure(figsize=(30, 4))
        ax = fig.add_subplot(111)
        coeff['coeff value'].plot(kind='bar', ylabel='coeff value', ax=ax)
        ax.axhline(y=0, color='red', linewidth=2)
        plot_title = 'PA bicycle collisions, 2002-2021\nLogistic regression model log-odds coefficients'
        if title_add is not None:
            plot_title += f': {title_add}'
        ax.set_title(plot_title)
        ax.tick_params(axis='x', labelsize='x-small')
        plt.show()

        if print_score:
            print(f'Score on validation set: {score}')
        if print_zero:
            # BUG FIX: original referenced undefined name 'coeff_zero'
            # (NameError when print_zero=True); the intended variable
            # is coeff_zero_features.
            print(f'Features with zero coefficients in trained model: {list(coeff_zero_features)}')

        self.score = score
        self.coeff = coeff
        self.coeff_zero_features = coeff_zero_features
|
786 |
+
|
787 |
+
|
lib/transform_data.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
from sklearn.preprocessing import FunctionTransformer, SplineTransformer
|
4 |
+
from sklearn.base import BaseEstimator, TransformerMixin
|
5 |
+
from sklearn.utils.validation import check_is_fitted
|
6 |
+
|
7 |
+
class GroupImputer(BaseEstimator, TransformerMixin):
    """
    Transformer for imputing missing values in a pd.DataFrame column
    using the groupwise mean, median, or mode.

    Parameters:
    -----------
    target : str
        The name of the column to be imputed
    group_cols : list
        List of name(s) of columns on which to groupby (required)
    strategy : str
        The method for replacement; can be any of
        ['mean', 'median', 'mode']

    Returns:
    --------
    X : pd.DataFrame
        The dataframe with imputed values in the target column
    """
    def __init__(self, target, group_cols=None, strategy='median'):
        # Raise real exceptions rather than assert: asserts are stripped
        # when Python runs with -O, silently disabling validation.
        if strategy not in ('mean', 'median', 'mode'):
            raise ValueError("strategy must be in ['mean', 'median', 'mode']")
        if not isinstance(group_cols, list):
            raise TypeError('group_cols must be a list of column names')
        if not isinstance(target, str):
            raise TypeError('target must be a string')

        self.group_cols = group_cols
        self.strategy = strategy
        self.target = target

    def fit(self, X, y=None):
        """Learn the per-group replacement values from X."""
        if self.strategy == 'mode':
            # pd.Series.mode may return several values; keep the first.
            # dropna=False lets an all-NaN group yield NaN rather than error.
            agg_func = lambda s: pd.Series.mode(s, dropna=False)[0]
        else:
            agg_func = self.strategy
        self.impute_map_ = X.groupby(self.group_cols)[self.target]\
                            .agg(agg_func).reset_index(drop=False)
        return self

    def transform(self, X, y=None):
        """Fill missing target values using the learned group map."""
        check_is_fitted(self, 'impute_map_')

        X = X.copy()

        for _, row in self.impute_map_.iterrows():
            # Mask selecting the rows belonging to this group
            mask = (X[self.group_cols] == row[self.group_cols]).all(axis=1)
            X.loc[mask, self.target] = X.loc[mask, self.target].fillna(row[self.target])
        return X
|
61 |
+
|
62 |
+
# Sine and consine transformations
|
63 |
+
def sin_feature_names(transformer, feature_names):
    """Feature-name callback for FunctionTransformer: prefix names with 'SIN_'."""
    return ['SIN_' + name for name in feature_names]
|
65 |
+
def cos_feature_names(transformer, feature_names):
    """Feature-name callback for FunctionTransformer: prefix names with 'COS_'."""
    return ['COS_' + name for name in feature_names]
|
67 |
+
def sin_transformer(period):
    """Return a FunctionTransformer mapping x to sin(2*pi*x/period)."""
    def _sine(x):
        return np.sin(2 * np.pi * x / period)
    return FunctionTransformer(_sine, feature_names_out=sin_feature_names)
|
69 |
+
def cos_transformer(period):
    """Return a FunctionTransformer mapping x to cos(2*pi*x/period)."""
    def _cosine(x):
        return np.cos(2 * np.pi * x / period)
    return FunctionTransformer(_cosine, feature_names_out=cos_feature_names)
|
71 |
+
|
72 |
+
# Periodic spline transformation
|
73 |
+
def periodic_spline_transformer(period, n_splines=None, degree=3):
    """
    Build a SplineTransformer with periodic extrapolation over [0, period].

    n_splines defaults to `period`; with extrapolation='periodic' and
    include_bias=True, n_splines + 1 knots are used.
    """
    if n_splines is None:
        n_splines = period
    n_knots = n_splines + 1  # periodic and include_bias is True
    knot_positions = np.linspace(0, period, n_knots).reshape(n_knots, 1)
    return SplineTransformer(degree=degree,
                             n_knots=n_knots,
                             knots=knot_positions,
                             extrapolation='periodic',
                             include_bias=True)
|
lib/vis_data.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import plotly.express as px
|
4 |
+
from scipy import stats
|
5 |
+
|
6 |
+
def plot_map(df,city=None,county=None,animate=True,color_dots=True,animate_by='year',show_fig=True,return_fig=False):
    """
    Displays a plotly.express.scatter_mapbox interactive map
    of crashes in a municipality if specified, or otherwise
    statewide. Can be animated over time or static.

    Parameters:
    -----------
    df : pd.DataFrame
        dataframe of crash samples
    city or county : tuple or None
        if provided, must be a tuple (code,name)
        - code : str
            the code corresponding to the desired municipality/county
            (see the data dictionary)
        - name : str
            the name you want to use for the municipality/county
            in plot title
        * At most one of these can be not None!
    animate : bool
        if animate==True, then the map will animate using
        the frequency provided in animate_by
    color_dots : bool
        if color_dots==True, then dots will be color-coded by
        'serious injury or death' status.
        WARNING: if color_dots and animate, then all frames
        will be missing samples in 'serious injury or death'
        classes which aren't present in first frame - due to
        bug in plotly animation_frame implementation.
        Recommend only using both when geographic
        area is statewide or at least has all values of
        'serious injury or death' in first frame
    animate_by : str
        the desired animation frequency, must be
        either 'year' or 'month'
    show_fig : bool
        whether to display figure using fig.show()
    return_fig : bool
        whether to return the figure object

    Returns: Either figure or None
    --------
    """
    assert (city is None)|(county is None), 'A city and county cannot both be provided.'
    # Copy df and create new column for color coding event type
    df = df.copy()
    df.loc[df.BICYCLE_SUSP_SERIOUS_INJ_COUNT>0,'Serious cyclist injury or death']='serious injury'
    df.loc[df.BICYCLE_DEATH_COUNT>0,'Serious cyclist injury or death']='death'
    df['Serious cyclist injury or death']=df['Serious cyclist injury or death'].fillna('neither')

    # Set animation parameters
    if animate:
        if animate_by == 'year':
            animation_frame = 'CRASH_YEAR'
            title_animate = ' by year'
        elif animate_by == 'month':
            df['DATE'] = pd.to_datetime((df['CRASH_MONTH'].astype('str')
                                         +'-'+df['CRASH_YEAR'].astype('str')),
                                        format = "%m-%Y")
            df=df.sort_values(by='DATE')
            # Frame labels keep only the 'YYYY-MM' part of the date string
            df['DATE']=df['DATE'].astype('str').apply(lambda x: x.rsplit('-',1)[0])
            animation_frame = 'DATE'
            title_animate = ' by month'
        else:
            raise ValueError("animate_by must be 'year' or 'month'")
    else:
        animation_frame = None
        title_animate = ''

    color = 'Serious cyclist injury or death' if color_dots else None

    # Adjustments for when city or county are provided
    if city is not None:
        df = df[df.MUNICIPALITY==city[0]]
        # Ignore extreme outlier samples - lat,lon may be incorrect
        df = df[np.abs(stats.zscore(df.DEC_LAT))<=4]
        df = df[np.abs(stats.zscore(df.DEC_LONG))<=4]
        title_place = city[1]+', PA'
    elif county is not None:
        df = df[df.COUNTY==county[0]]
        # Ignore extreme outlier samples - lat,lon may be incorrect
        df = df[np.abs(stats.zscore(df.DEC_LAT))<=4]
        df = df[np.abs(stats.zscore(df.DEC_LONG))<=4]
        title_place = county[1]+' county, PA'
    else:
        title_place = 'PA'

    # Compute default zoom level based on lat,lon ranges,
    # using open-street-map tiling:
    # 2^(zoom) = 360/(longitude width of 1 tile)
    max_lat, min_lat = df.DEC_LAT.max(), df.DEC_LAT.min()
    max_lon, min_lon = df.DEC_LONG.max(), df.DEC_LONG.min()
    zoom = np.log2(360/max(max_lon-min_lon,max_lat-min_lat))

    lat_center = (max_lat+min_lat)/2
    lon_center = (max_lon+min_lon)/2

    # Adjust width so that aspect ratio matches shape of state
    width_mult = (max_lon-min_lon)/(max_lat-min_lat)

    # NOTE: removed an unused local 'cols' list that was built (and had
    # 'DATE' appended) but never passed to px.scatter_mapbox.
    # Plot mapbox
    fig = px.scatter_mapbox(df, lat='DEC_LAT',lon='DEC_LONG',
                            color=color,
                            color_discrete_map={'neither':'royalblue','serious injury':'orange','death':'crimson'},
                            mapbox_style='open-street-map',
                            animation_frame = animation_frame,
                            animation_group='CRN',
                            hover_data = {'DEC_LAT':False,'DEC_LONG':False,
                                          'CRASH_YEAR':True,'CRASH_MONTH':True,
                                          'Serious cyclist injury or death':True},
                            width = width_mult*500,height=700,zoom=zoom,
                            center={'lat':lat_center,'lon':lon_center},
                            title=f'Crashes involving bicycles{title_animate}<br> in {title_place}, 2002-2021')
    fig.update_layout(legend=dict(orientation='h',xanchor='right',yanchor='bottom',x=1,y=-0.12),
                      legend_title_side='top')
    if show_fig:
        fig.show()
    if return_fig:
        return fig
|
131 |
+
|
132 |
+
def feat_perc(feat, df, col_name = 'percentage', feat_name = None):
    """
    Build a one-column dataframe of the value counts of df[feat],
    normalized to percentages of the whole.

    - 'df' is the input dataframe.
    - 'feat' is the desired column of df.
    - 'col_name' names the single output column.
    - 'feat_name', if given, is used as the output index name;
      otherwise 'feat' is used.
    """
    shares = df[feat].value_counts(normalize=True).sort_index()
    result = shares.to_frame(name=col_name)
    result.index.name = feat_name if feat_name else feat
    return result
|
151 |
+
|
152 |
+
def feat_perc_bar(feat,df,feat_name=None,cohort_name=None,show_fig=True,return_fig=False,sort=False):
    """
    Makes barplot of two series:
    - distribution of feature among all cyclists
    - distribution of feature among cyclists with serious injury or fatality

    Parameters:
    -----------
    feat : str
        The column name of the desired feature
    df : pd.DataFrame
        The input dataframe
    feat_name : str or None
        The feature name to use in the plot title.
        If None, will use feat
    cohort_name : str or None
        qualifier to use in front of 'cyclists'
        in titles, if provided, e.g. 'rural cyclists'
    show_fig : bool
        whether to finish with fig.show()
    return_fig : bool
        whether to return the fig object
    sort : bool
        whether to sort bars. If False, will use default sorting
        by category name or feature value. If True, will resort
        in descending order by percentage

    Returns: figure or None
    --------
    """
    if feat_name is None:
        feat_name=feat
    df_inj = df.query('SERIOUS_OR_FATALITY==1')
    table = feat_perc(feat,df)
    table.loc[:,'cohort']='all'
    # Capture the descending-percentage ordering before concatenation
    ordering = list(table['percentage'].sort_values(ascending=False).index) if sort else None
    table_inj = feat_perc(feat,df_inj)
    table_inj.loc[:,'cohort']='seriously injured or killed'
    table = pd.concat([table,table_inj],axis=0).reset_index()
    category_orders = {'cohort':['all','seriously injured or killed']}
    if sort:
        category_orders[feat]=ordering
    # BUG FIX: feat_name was computed but never used; honor it in the title.
    # Output is identical to before when feat_name is None (the default).
    # NOTE(review): cohort_name is still unused - confirm intended titles.
    fig = px.bar(table,y='cohort',x='percentage',color=feat,
                 barmode='stack',text_auto='.1%',
                 category_orders=category_orders,
                 title=f'Distributions of {feat_name} values within cyclist cohorts')
    fig.update_yaxes(tickangle=-90)
    fig.update_xaxes(tickformat=".0%")
    if show_fig:
        fig.show()
    if return_fig:
        return fig
|
204 |
+
|
205 |
+
# def feat_perc_comp(feat,df,feat_name=None,cohort_name = None,merge_inj_death=True):
|
206 |
+
# """
|
207 |
+
# Returns a styled dataframe (Styler object)
|
208 |
+
# whose underlying dataframe has three columns
|
209 |
+
# containing value counts of 'feat' among:
|
210 |
+
# - all cyclists involved in crashes
|
211 |
+
# - cyclists suffering serious injury or fatality
|
212 |
+
# each formatted as percentages of the series sum.
|
213 |
+
# Styled with bars comparing percentages
|
214 |
+
|
215 |
+
# Parameters:
|
216 |
+
# -----------
|
217 |
+
# feat : str
|
218 |
+
# The column name of the desired feature
|
219 |
+
# df : pd.DataFrame
|
220 |
+
# The input dataframe
|
221 |
+
# feat_name : str or None
|
222 |
+
# The feature name to use in the output dataframe
|
223 |
+
# index name. If None, will use feat
|
224 |
+
# cohort_name : str or None
|
225 |
+
# qualifier to use in front of 'cyclists'
|
226 |
+
# in titles, if provided, e.g. 'rural cyclists'
|
227 |
+
# merge_inj_death : bool
|
228 |
+
# whether to merge seriously injured and killed cohorts
|
229 |
+
# Returns:
|
230 |
+
# --------
|
231 |
+
# perc_comp : pd.Styler object
|
232 |
+
# """
|
233 |
+
# # Need qualifier for titles if restricting cyclist cohort
|
234 |
+
# qualifier = cohort_name if cohort_name is not None else ''
|
235 |
+
|
236 |
+
# # Two columns or three, depending on merge_inj_death
|
237 |
+
# if merge_inj_death:
|
238 |
+
# perc_comp = feat_perc(feat,df=df,feat_name=feat_name,
|
239 |
+
# col_name='all cyclists',)\
|
240 |
+
# .merge(feat_perc(feat,feat_name=feat_name,
|
241 |
+
# df=df.query('SERIOUS_OR_FATALITY==1'),
|
242 |
+
# col_name=qualifier+'cyclists with serious injury or fatality'),
|
243 |
+
# on=feat,how='left')
|
244 |
+
# perc_comp = perc_comp[perc_comp.max(axis=1)>=0.005]
|
245 |
+
# else:
|
246 |
+
# perc_comp = feat_perc(feat,df=df,feat_name=feat_name,
|
247 |
+
# col_name='all cyclists')\
|
248 |
+
# .merge(feat_perc(feat,feat_name=feat_name,
|
249 |
+
# df=df.query('INJ_SEVERITY=="susp_serious_injury"'),
|
250 |
+
# col_name=qualifier+'cyclists with serious injury'),
|
251 |
+
# on=feat,how='left')\
|
252 |
+
# .merge(feat_perc(feat,feat_name=feat_name,
|
253 |
+
# df=df.query('INJ_SEVERITY=="killed"'),
|
254 |
+
# col_name=qualifier+'cyclists with fatality'),
|
255 |
+
# on=feat,how='left')
|
256 |
+
|
257 |
+
# # If feature is not ordinal, sort rows descending by crash counts
|
258 |
+
# if feat not in ['AGE_BINS','SPEED_LIMIT','DAY_OF_WEEK','HOUR_OF_DAY']:
|
259 |
+
# perc_comp=perc_comp.sort_values(by='all cyclists',ascending=False)
|
260 |
+
|
261 |
+
# # Relabel day numbers with strings
|
262 |
+
# if feat == 'DAY_OF_WEEK':
|
263 |
+
# perc_comp.index=['Sun','Mon','Tues','Wed','Thurs','Fri','Sat']
|
264 |
+
# perc_comp.index.name='DAY_OF_WEEK'
|
265 |
+
# perc_comp=perc_comp.fillna(0)
|
266 |
+
# table_columns = list(perc_comp.columns)
|
267 |
+
|
268 |
+
# # Define format for displaying floats
|
269 |
+
# format_dict={col:'{:.2%}' for col in perc_comp.columns}
|
270 |
+
|
271 |
+
|
272 |
+
# # Define table styles
|
273 |
+
# styles = [dict(selector="caption",
|
274 |
+
# props=[("text-align", "center"),
|
275 |
+
# ("font-size", "100%"),
|
276 |
+
# ("color", 'black'),
|
277 |
+
# ("text-decoration","underline"),
|
278 |
+
# ("font-weight","bold")])]
|
279 |
+
|
280 |
+
# # Return formatted dataframe
|
281 |
+
# if feat_name is None:
|
282 |
+
# feat_name=feat
|
283 |
+
# caption = f'Breakdown of {feat_name} among cyclist groups'
|
284 |
+
# return perc_comp.reset_index().style.set_table_attributes("style='display:inline'")\
|
285 |
+
# .format(format_dict).bar(color='powderblue',
|
286 |
+
# subset=table_columns).hide().set_caption(caption)\
|
287 |
+
# .set_table_styles(styles)
|
requirements.txt
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
altair==5.0.1
|
2 |
+
appnope==0.1.3
|
3 |
+
asttokens==2.2.1
|
4 |
+
attrs==23.1.0
|
5 |
+
backcall==0.2.0
|
6 |
+
blinker==1.6.2
|
7 |
+
cachetools==5.3.1
|
8 |
+
certifi==2023.5.7
|
9 |
+
charset-normalizer==3.1.0
|
10 |
+
click==8.1.3
|
11 |
+
cloudpickle==2.2.1
|
12 |
+
contourpy==1.1.0
|
13 |
+
cycler==0.11.0
|
14 |
+
decorator==5.1.1
|
15 |
+
executing==1.2.0
|
16 |
+
fonttools==4.40.0
|
17 |
+
gitdb==4.0.10
|
18 |
+
GitPython==3.1.31
|
19 |
+
importlib-metadata==6.7.0
|
20 |
+
ipython==8.14.0
|
21 |
+
jedi==0.18.2
|
22 |
+
Jinja2==3.1.2
|
23 |
+
joblib==1.2.0
|
24 |
+
jsonschema==4.17.3
|
25 |
+
kiwisolver==1.4.4
|
26 |
+
lightgbm==4.0.0
|
27 |
+
llvmlite==0.40.1
|
28 |
+
markdown-it-py==3.0.0
|
29 |
+
MarkupSafe==2.1.3
|
30 |
+
matplotlib==3.7.1
|
31 |
+
matplotlib-inline==0.1.6
|
32 |
+
mdurl==0.1.2
|
33 |
+
numba==0.57.1
|
34 |
+
numpy==1.24.1
|
35 |
+
pandas==2.0.2
|
36 |
+
parso==0.8.3
|
37 |
+
pexpect==4.8.0
|
38 |
+
pickleshare==0.7.5
|
39 |
+
Pillow==9.5.0
|
40 |
+
plotly==5.15.0
|
41 |
+
prompt-toolkit==3.0.38
|
42 |
+
protobuf==4.23.3
|
43 |
+
ptyprocess==0.7.0
|
44 |
+
pure-eval==0.2.2
|
45 |
+
pyarrow==12.0.1
|
46 |
+
pydeck==0.8.1b0
|
47 |
+
Pygments==2.15.1
|
48 |
+
Pympler==1.0.1
|
49 |
+
pyparsing==3.1.0
|
50 |
+
pyrsistent==0.19.3
|
51 |
+
python-dateutil==2.8.2
|
52 |
+
pytz==2023.3
|
53 |
+
pytz-deprecation-shim==0.1.0.post0
|
54 |
+
rich==13.4.2
|
55 |
+
scikit-learn==1.2.2
|
56 |
+
scipy==1.10.1
|
57 |
+
seaborn==0.12.2
|
58 |
+
shap==0.42.1
|
59 |
+
six==1.16.0
|
60 |
+
slicer==0.0.7
|
61 |
+
smmap==5.0.0
|
62 |
+
stack-data==0.6.2
|
63 |
+
streamlit==1.23.1
|
64 |
+
tenacity==8.2.2
|
65 |
+
threadpoolctl==3.1.0
|
66 |
+
toml==0.10.2
|
67 |
+
toolz==0.12.0
|
68 |
+
tornado==6.3.2
|
69 |
+
tqdm==4.65.0
|
70 |
+
traitlets==5.9.0
|
71 |
+
typing_extensions==4.6.3
|
72 |
+
tzdata==2023.3
|
73 |
+
tzlocal==4.3.1
|
74 |
+
urllib3==2.0.3
|
75 |
+
validators==0.20.0
|
76 |
+
wcwidth==0.2.6
|
77 |
+
zipp==3.15.0
|
study.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b1f563b55025766dc82eb2f00b0e4ae468caa7368f87dd309be9d8047eace6d9
|
3 |
+
size 9935040
|