Martijn van Beers committed
Commit ab7830f · 1 Parent(s): 733749d

Hack it up to do multiple explanations

Adds in explanations from captum's LayerIntegratedGradients

Files changed (1): app.py (+86 -40)
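
The change below wires Captum's LayerIntegratedGradients into the app: wrap a forward function that returns the logits, point the attributor at a layer, and it integrates gradients along the straight-line path from a baseline input to the real input, yielding one attribution vector per token at that layer. As context before reading the diff, here is a minimal sketch of that flow against the same textattack/roberta-base-SST-2 checkpoint. It is a hedged reading of the diff, not the committed code: the helper name forward_logits and the example sentence are invented here, and the sketch targets the embedding layer (the most common choice), whereas the commit targets encoder layer 8.

# Minimal sketch of the LayerIntegratedGradients flow this commit adds.
# Assumptions: same HF checkpoint as the app; embedding layer as the target.
import torch
from captum.attr import LayerIntegratedGradients
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
model = AutoModelForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2")
model.eval()

def forward_logits(input_ids, attention_mask=None):
    # Captum calls this repeatedly with inputs interpolated between the
    # baseline and the real input, and differentiates the returned logits.
    return model(input_ids, attention_mask=attention_mask, return_dict=True).logits

encoding = tokenizer(["a fine little movie"], return_tensors="pt")
lig = LayerIntegratedGradients(forward_logits, model.roberta.embeddings)
attributions = lig.attribute(
    inputs=encoding["input_ids"],
    baselines=tokenizer.unk_token_id,                 # scalar baseline token id
    additional_forward_args=(encoding["attention_mask"],),
    target=1,                                         # the "positive" class
    n_steps=20,
)
# Collapse (batch, seq_len, hidden) to one signed score per token.
scores = attributions.sum(dim=-1).squeeze(0)
scores = scores / torch.norm(scores)
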
app.py CHANGED
@@ -5,10 +5,11 @@ sys.path.append("BERT_explainability")
 
 import torch
 
+from transformers import AutoModelForSequenceClassification
 from BERT_explainability.ExplanationGenerator import Generator
 from BERT_explainability.roberta2 import RobertaForSequenceClassification
 from transformers import AutoTokenizer
-
+from captum.attr import LayerIntegratedGradients
 from captum.attr import visualization
 import torch
 
@@ -39,6 +40,7 @@ model = RobertaForSequenceClassification.from_pretrained(
     "textattack/roberta-base-SST-2"
 ).to(device)
 model.eval()
+model2 = AutoModelForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2")
 tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
 # initialize the explanations generator
 explanations = Generator(model, "roberta")
@@ -151,7 +153,7 @@ def visualize_text(datarecords, legend=True):
     return html
 
 
-def show_explanation(model, input_ids, attention_mask, index=None, start_layer=0):
+def show_explanation(model, input_ids, attention_mask, index=None, start_layer=8):
     # generate an explanation for the input
     output, expl = generate_relevance(
         model, input_ids, attention_mask, index=index, start_layer=start_layer
@@ -177,32 +179,87 @@ def show_explanation(model, input_ids, attention_mask, index=None, start_layer=0
         tokens = tokenizer.convert_ids_to_tokens(input_ids[record].flatten())[
             1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
         ]
-        vis_data_records.append(list(zip(tokens, nrm.tolist())))
+        # vis_data_records.append(list(zip(tokens, nrm.tolist())))
         #print([(tokens[i], nrm[i].item()) for i in range(len(tokens))])
-        # vis_data_records.append(
-        #     visualization.VisualizationDataRecord(
-        #         nrm,
-        #         output[record][classification],
-        #         classification,
-        #         classification,
-        #         index,
-        #         1,
-        #         tokens,
-        #         1,
-        #     )
-        # )
-    # return visualize_text(vis_data_records)
-    return vis_data_records
+        vis_data_records.append(
+            visualization.VisualizationDataRecord(
+                nrm,
+                output[record][classification],
+                classification,
+                classification,
+                index,
+                1,
+                tokens,
+                1,
+            )
+        )
+    return visualize_text(vis_data_records)
+    # return vis_data_records
+
+def custom_forward(inputs, attention_mask=None, pos=0):
+    # print("inputs", inputs.shape)
+    result = model2(inputs, attention_mask=attention_mask, return_dict=True)
+    preds = result.logits
+    # print("preds", preds.shape)
+    return preds
 
+def summarize_attributions(attributions):
+    attributions = attributions.sum(dim=-1).squeeze(0)
+    attributions = attributions / torch.norm(attributions)
+    return attributions
+
+
+def run_attribution_model(input_ids, attention_mask, ref_token_id=tokenizer.unk_token_id, layer=None, steps=20):
+    try:
+        output = model2(input_ids=input_ids, attention_mask=attention_mask)[0]
+        index = output.argmax(axis=-1).detach().cpu().numpy()
+
+        ablator = LayerIntegratedGradients(custom_forward, layer)
+        input_tensor = input_ids
+        attention_mask = attention_mask
+        attributions = ablator.attribute(
+            inputs=input_ids,
+            baselines=ref_token_id,
+            additional_forward_args=(attention_mask),
+            target=1,
+            n_steps=steps,
+        )
+        attributions = summarize_attributions(attributions).unsqueeze_(0)
+    finally:
+        pass
+    vis_data_records = []
+    print("IN", input_ids.size())
+    print("ATTR", attributions.shape)
+    for record in range(input_ids.size(0)):
+        classification = output[record].argmax(dim=-1).item()
+        class_name = classifications[classification]
+        attr = attributions[record]
+        tokens = tokenizer.convert_ids_to_tokens(input_ids[record].flatten())[
+            1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
+        ]
+        print("TOK", len(tokens), attr.shape)
+        vis_data_records.append(
+            visualization.VisualizationDataRecord(
+                attr,
+                output[record][classification],
+                classification,
+                classification,
+                index,
+                1,
+                tokens,
+                1,
+            )
+        )
+    return visualize_text(vis_data_records)
 
 def sentence_sentiment(input_text):
     text_batch = [input_text]
     encoding = tokenizer(text_batch, return_tensors="pt")
     input_ids = encoding["input_ids"].to(device)
     attention_mask = encoding["attention_mask"].to(device)
-    output = model(input_ids=input_ids, attention_mask=attention_mask)[0]
-    index = output.argmax(axis=-1).item()
-    return classifications[index]
+    layer = getattr(model2.roberta.encoder.layer, "8")
+    output = run_attribution_model(input_ids, attention_mask, layer=layer)
+    return output
 
 def sentiment_explanation_hila(input_text):
     text_batch = [input_text]
@@ -216,27 +273,19 @@ def sentiment_explanation_hila(input_text):
     return show_explanation(model, input_ids, attention_mask)
 
 hila = gradio.Interface(
-    fn=sentence_sentiment,
+    fn=sentiment_explanation_hila,
     inputs="text",
-    outputs="label",
-    title="RoBERTa Explanability",
-    description="Quick demo of a version of [Hila Chefer's](https://github.com/hila-chefer) [Transformer-Explanability](https://github.com/hila-chefer/Transformer-Explainability/) but without the layerwise relevance propagation (as in [Transformer-MM_explainability](https://github.com/hila-chefer/Transformer-MM-Explainability/)) for a RoBERTa model.",
-    examples=[
-        [
-            "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great"
-        ],
-        [
-            "I really didn't like this movie. Some of the actors were good, but overall the movie was boring"
-        ],
-    ],
-    interpretation=sentiment_explanation_hila
+    outputs="html",
 )
-shap = gradio.Interface(
+lig = gradio.Interface(
     fn=sentence_sentiment,
     inputs="text",
-    outputs="label",
-    title="RoBERTa Explanability",
-    description="gradio shap explanations",
+    outputs="html",
+)
+
+iface = gradio.Parallel(hila, lig,
+    title="RoBERTa Explanability",
+    description="Quick comparison demo of explainability for sentiment prediction with RoBERTa. The outputs are from:\n\n* a version of [Hila Chefer's](https://github.com/hila-chefer) [Transformer-Explanability](https://github.com/hila-chefer/Transformer-Explainability/) but without the layerwise relevance propagation (as in [Transformer-MM_explainability](https://github.com/hila-chefer/Transformer-MM-Explainability/)) for a RoBERTa model.\n* [captum](https://captum.ai/)'s LayerIntegratedGradients",
     examples=[
         [
             "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great"
@@ -245,8 +294,5 @@ shap = gradio.Interface(
             "I really didn't like this movie. Some of the actors were good, but overall the movie was boring"
         ],
     ],
-    interpretation="shap"
 )
-
-iface = gradio.Parallel(hila, shap)
 iface.launch()
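
A few notes on the new code. getattr(model2.roberta.encoder.layer, "8") is simply model2.roberta.encoder.layer[8] spelled out for an nn.ModuleList, and it pairs with the start_layer=8 default now used by show_explanation, so the two explanation methods are presumably anchored at the same encoder depth. Both interfaces also switch to outputs="html", since each now returns the HTML that visualize_text builds from visualization.VisualizationDataRecord entries, which is what lets gradio.Parallel render the two explanations side by side. Finally, summarize_attributions reduces Captum's per-layer output, one hidden-size vector per token, to a single signed relevance score per token; a small illustration on dummy data (the shapes here are assumptions for a RoBERTa-base run, not taken from the commit):

import torch

# LayerIntegratedGradients yields (batch, seq_len, hidden_size) attributions.
attributions = torch.randn(1, 7, 768)          # dummy values, RoBERTa-base sizes
scores = attributions.sum(dim=-1).squeeze(0)   # -> (seq_len,) signed token scores
scores = scores / torch.norm(scores)           # scale to unit L2 norm
print(scores.shape)                            # torch.Size([7])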