LethallyHealthy commited on
Commit
f1197fe
·
1 Parent(s): 50acdcd

Update predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +23 -4
predictor.py CHANGED
@@ -134,8 +134,8 @@ def make_a_prediction(X):
134
  predictions = make_a_prediction(X_test)
135
  print(predictions)
136
 
137
- #to be called when needed
138
- def create_shap_models(data):
139
  explainer = shap.TreeExplainer(model)
140
  shap_values = explainer.shap_values(data)
141
  shap.initjs()
@@ -152,6 +152,25 @@ def create_shap_models(data):
152
  interaction_values = explainer.shap_interaction_values(data)
153
  interaction_values[0].round(2)
154
  st.write(pd.DataFrame(interaction_values[0].round(2)).head(60))
155
-
156
 
157
- return obj1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  predictions = make_a_prediction(X_test)
135
  print(predictions)
136
 
137
+ #to be called when needed for optimized results
138
+ def create_opt_shap_models(data):
139
  explainer = shap.TreeExplainer(model)
140
  shap_values = explainer.shap_values(data)
141
  shap.initjs()
 
152
  interaction_values = explainer.shap_interaction_values(data)
153
  interaction_values[0].round(2)
154
  st.write(pd.DataFrame(interaction_values[0].round(2)).head(60))
 
155
 
156
+ return obj1
157
+ #to be called when needed for optimized results
158
+ def create_unopt_shap_models(data):
159
+ explainer = shap.TreeExplainer(lgbm)
160
+ shap_values = explainer.shap_values(data)
161
+ shap.initjs()
162
+ obj2 = shap.force_plot(explainer.expected_value, shap_values=shap_values, feature_names=data.columns)
163
+
164
+ shap.initjs()
165
+ shap.decision_plot(explainer.expected_value, shap_values, feature_names=np.array(data.columns))
166
+ st.pyplot(bbox_inches='tight')
167
+
168
+ shap.initjs()
169
+ shap.summary_plot(shap_values=shap_values, feature_names=data.columns)
170
+ st.pyplot(bbox_inches='tight')
171
+
172
+ interaction_values = explainer.shap_interaction_values(data)
173
+ interaction_values[0].round(2)
174
+ st.write(pd.DataFrame(interaction_values[0].round(2)).head(60))
175
+
176
+ return obj2