emilylearning committed
Commit 08879a1 · 1 Parent(s): a43a76a
format file and remove share=True

app.py CHANGED
@@ -31,6 +31,7 @@ for bert_like in MODEL_NAMES:
 
 # %%
 
+
 def clean_tokens(tokens):
     return [token.strip() for token in tokens]
 
@@ -61,8 +62,6 @@ def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token, num_pre
     return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)
 
 
-
-
 def get_figure(df, gender, n_fit=1):
     df = df.set_index("x-axis")
     cols = df.columns
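Editor's note (not part of the commit): the context line at the top of the hunk above shows how get_avg_prob_from_pipeline_outputs averages fill-mask scores across predictions, with EPS guarding the division by num_preds. A standalone sketch of that averaging pattern; the model choice, sentences, and constant values here are illustrative assumptions, not taken from the app:

# Sketch only: average fill-mask probability of a target token across sentences.
from transformers import pipeline

EPS = 1e-5          # assumed placeholder; the app defines its own constant
DECIMAL_PLACES = 1  # assumed placeholder

fill_mask = pipeline("fill-mask", model="bert-base-uncased")
sentences = ["She drove her [MASK] to work.", "He drove his [MASK] to work."]
target = "truck"

preds = []
for outputs in (fill_mask(s) for s in sentences):
    # each prediction dict carries "token_str" and "score"
    preds.append(sum(o["score"] for o in outputs if o["token_str"].strip() == target))

avg_pct = round(sum(preds) / (EPS + len(preds)) * 100, DECIMAL_PLACES)
print(avg_pct)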
@@ -75,16 +74,16 @@ def get_figure(df, gender, n_fit=1):
 
     # find stackoverflow reference
     p, C_p = np.polyfit(xs, ys, n_fit, cov=1)
-    t = np.linspace(min(xs)-1, max(xs)+1, 10*len(xs))
-    TT = np.vstack([t**(n_fit-i) for i in range(n_fit+1)]).T
+    t = np.linspace(min(xs) - 1, max(xs) + 1, 10 * len(xs))
+    TT = np.vstack([t ** (n_fit - i) for i in range(n_fit + 1)]).T
 
     # matrix multiplication calculates the polynomial values
     yi = np.dot(TT, p)
     C_yi = np.dot(TT, np.dot(C_p, TT.T))  # C_y = TT*C_z*TT.T
     sig_yi = np.sqrt(np.diag(C_yi))  # Standard deviations are sqrt of diagonal
 
-    ax.fill_between(t, yi+sig_yi, yi-sig_yi, alpha=0.25)
-    ax.plot(t, yi, "-")
+    ax.fill_between(t, yi + sig_yi, yi - sig_yi, alpha=0.25)
+    ax.plot(t, yi, "-")
     ax.plot(df, "ro")
     ax.legend(list(df.columns))
 
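Editor's note (not part of the commit): the hunk above only reformats the uncertainty-band code in get_figure. The underlying recipe is a polynomial fit whose coefficient covariance from np.polyfit(..., cov=True) is propagated to the fitted curve as TT @ C_p @ TT.T, giving a 1-sigma band. A standalone sketch under that assumption; the sample data is made up and names mirror the hunk:

# Sketch only: polynomial fit with a 1-sigma uncertainty band.
import numpy as np
import matplotlib.pyplot as plt

xs = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
ys = np.array([2.1, 2.9, 4.2, 4.8, 6.1])  # made-up sample data
n_fit = 1

p, C_p = np.polyfit(xs, ys, n_fit, cov=True)  # coefficients and their covariance
t = np.linspace(min(xs) - 1, max(xs) + 1, 10 * len(xs))
TT = np.vstack([t ** (n_fit - i) for i in range(n_fit + 1)]).T  # design matrix in t

yi = TT @ p                      # fitted curve
C_yi = TT @ C_p @ TT.T           # covariance propagated to the curve
sig_yi = np.sqrt(np.diag(C_yi))  # standard deviation at each t

fig, ax = plt.subplots()
ax.fill_between(t, yi + sig_yi, yi - sig_yi, alpha=0.25)
ax.plot(t, yi, "-")
ax.plot(xs, ys, "ro")
plt.show()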
@@ -97,7 +96,6 @@ def get_figure(df, gender, n_fit=1):
     return fig
 
 
-
 # %%
 def predict_masked_tokens(
     model_name,
@@ -185,34 +183,33 @@ def predict_masked_tokens(
 
 truck_fn_example = [
     MODEL_NAMES[2],
-
-
-
-
-
+    "",
+    ", ".join(["truck", "pickup"]),
+    ", ".join(["car", "sedan"]),
+    ", ".join(["city", "neighborhood", "farm"]),
+    "PLACE",
     "True",
     1,
 ]
+
+
 def truck_1_fn():
-    return truck_fn_example + [
-
-    ]
+    return truck_fn_example + ["He loaded up his truck and drove to the PLACE."]
+
 
 def truck_2_fn():
     return truck_fn_example + [
-
+        "He loaded up the bed of his truck and drove to the PLACE."
     ]
 
 
 # # %%
 
 
-
 demo = gr.Blocks()
 with demo:
     gr.Markdown("# Spurious Correlation Evaluation for Pre-trained LLMs")
 
-
     gr.Markdown("## Instructions for this Demo")
     gr.Markdown(
         "1) Click on one of the examples below to pre-populate the input fields."
@@ -224,8 +221,8 @@ with demo:
         "3) Repeat steps (1) and (2) with more pre-populated inputs or with your own values in the input fields!"
     )
 
-
-
+    gr.Markdown(
+        """The pre-populated inputs below are for a demo example of a location-vs-vehicle-type spurious correlation.
 We can see this spurious correlation largely disappears in the well-specified example text.
 
 <p align="center">
@@ -236,18 +233,25 @@ with demo:
 <p align="center">
 <img src="file/well_spec.png" alt="results" width="300"/>
 </p>
-    """)
+    """
+    )
 
     gr.Markdown("## Example inputs")
     gr.Markdown(
         "Click a button below to pre-populate input fields with example values. Then scroll down to Hit Submit to generate predictions."
     )
     with gr.Row():
-        truck_1_gen = gr.Button(
-
+        truck_1_gen = gr.Button(
+            "Click for non-well-specified(?) vehicle-type example inputs"
+        )
+        gr.Markdown(
+            "<-- Multiple solutions with low training error. LLM sensitive to spurious(?) correlations."
+        )
 
         truck_2_gen = gr.Button("Click for well-specified vehicle-type example inputs")
-        gr.Markdown(
+        gr.Markdown(
+            "<-- Fewer solutions with low training error. LLM less sensitive to spurious(?) correlations."
+        )
 
     gr.Markdown("## Input fields")
     gr.Markdown(
@@ -343,11 +347,37 @@ with demo:
     )
 
     with gr.Row():
-        truck_1_gen.click(
-
+        truck_1_gen.click(
+            truck_1_fn,
+            inputs=[],
+            outputs=[
+                model_name,
+                own_model_name,
+                group_a_tokens,
+                group_b_tokens,
+                x_axis,
+                place_holder,
+                to_normalize,
+                n_fit,
+                input_text,
+            ],
+        )
 
-        truck_2_gen.click(
-
+        truck_2_gen.click(
+            truck_2_fn,
+            inputs=[],
+            outputs=[
+                model_name,
+                own_model_name,
+                group_a_tokens,
+                group_b_tokens,
+                x_axis,
+                place_holder,
+                to_normalize,
+                n_fit,
+                input_text,
+            ],
+        )
 
     btn.click(
         predict_masked_tokens,
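Editor's note (not part of the commit): the hunk above wires the example buttons to zero-argument functions. In Gradio, a click handler with inputs=[] calls the function with no arguments and writes its returned list, in order, into the components named in outputs. A standalone sketch of that pattern; the component names and example values below are placeholders, not the app's:

# Sketch only: a button that pre-populates input fields from a zero-argument function.
import gradio as gr

def example_fn():
    # Returned values map positionally onto the `outputs` components below.
    return ["bert-base-uncased", "She drove her truck to the PLACE."]

with gr.Blocks() as demo:
    model_name = gr.Textbox(label="Model name")
    input_text = gr.Textbox(label="Input text")
    fill_btn = gr.Button("Click for example inputs")
    fill_btn.click(example_fn, inputs=[], outputs=[model_name, input_text])

demo.launch()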
@@ -365,8 +395,6 @@ with demo:
         outputs=[sample_text, female_fig, male_fig, df],
     )
 
-demo.launch(debug=True, share=True)
+demo.launch(debug=True)
 
 # %%
-
-