attention-rollout / lib /integrated_gradients.py
Martijn van Beers
Add 'classic' rollout
4f67e27
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization
from roberta2 import RobertaForSequenceClassification
from ExplanationGenerator import Generator
from util import visualize_text
# Label name for each class index; looked up with the argmax of the
# model output when building visualization records.
classifications = [
    "NEGATIVE",  # class index 0
    "POSITIVE",  # class index 1
]
class IntegratedGradientsExplainer:
    """Explain a sequence classifier with captum's Layer Integrated Gradients.

    Attributions are taken with respect to a selectable layer of
    ``model.roberta`` (the embeddings or one encoder layer) and rendered with
    ``visualize_text``.
    """

    def __init__(self, model, tokenizer):
        """Store the model, its device, the tokenizer and the named baselines.

        Args:
            model: classifier exposing ``.device`` and ``.roberta`` and whose
                forward returns an object with ``.logits``.
            tokenizer: tokenizer used to encode text and decode ids to tokens.
        """
        self.model = model
        self.device = model.device
        self.tokenizer = tokenizer
        # Baseline names accepted by run_attribution_model(baseline=...),
        # each mapped to the single token id used as the IG baseline.
        self.baseline_map = {
            'Unknown': self.tokenizer.unk_token_id,
            'Padding': self.tokenizer.pad_token_id,
        }

    def tokens_from_ids(self, ids):
        """Convert token ids to token strings, stripping a leading "Ġ".

        "Ġ" is presumably the byte-level BPE space marker used by RoBERTa
        tokenizers; it is removed purely for display.
        """
        tokens = self.tokenizer.convert_ids_to_tokens(ids)
        # startswith() instead of tok[0] so an empty token cannot raise IndexError.
        return [tok[1:] if tok.startswith("Ġ") else tok for tok in tokens]

    def custom_forward(self, inputs, attention_mask=None, pos=0):
        """Forward function handed to captum: return the model's logits.

        ``pos`` is accepted for backward compatibility and is not used.
        """
        result = self.model(inputs, attention_mask=attention_mask, return_dict=True)
        return result.logits

    @staticmethod
    def summarize_attributions(attributions):
        """Reduce attributions to one L2-normalised score per token.

        Sums over the last dimension, drops a leading dimension of size 1, and
        divides each remaining row by its own L2 norm. For a single record the
        result is therefore 1-D (one score per position). A row whose norm is
        zero is returned as zeros instead of NaN.
        """
        summed = attributions.sum(dim=-1).squeeze(0)
        norm = torch.norm(summed, p=2, dim=-1, keepdim=True)
        # Avoid 0/0 = NaN for an all-zero row; dividing by 1 leaves it at zero.
        norm = norm.masked_fill(norm == 0, 1.0)
        return summed / norm

    def run_attribution_model(self, input_ids, attention_mask, baseline=None, index=None, layer=None, steps=20):
        """Compute Layer Integrated Gradients attributions for ``input_ids``.

        Args:
            input_ids: token-id tensor fed to the model (batch first).
            attention_mask: mask tensor matching ``input_ids``.
            baseline: ``None`` for the unknown-token baseline, otherwise a key
                of ``self.baseline_map`` ("Unknown" or "Padding").
            index: output class to attribute. ``None`` keeps the historical
                default of class 1.
            layer: module whose activations are attributed.
            steps: number of integration steps (``n_steps``).

        Returns:
            ``(attributions, output, target)``: normalised per-token scores with
            a leading batch dimension, the model's first output (logits), and
            the class index the attributions were computed for.

        Raises:
            KeyError: if ``baseline`` is not a key of ``self.baseline_map``.
        """
        if baseline is None:
            baseline = self.tokenizer.unk_token_id
        else:
            baseline = self.baseline_map[baseline]
        # ``index`` used to be ignored (target was always 1); honour it now and
        # keep class 1 as the default so existing callers see no change.
        target = 1 if index is None else index

        # Plain prediction pass: gradients are not needed here.
        with torch.no_grad():
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]

        ablator = LayerIntegratedGradients(self.custom_forward, layer)
        attributions = ablator.attribute(
            inputs=input_ids,
            baselines=baseline,
            # A real 1-tuple; ``(attention_mask)`` is just the tensor itself.
            additional_forward_args=(attention_mask,),
            target=target,
            n_steps=steps,
        )
        summary = self.summarize_attributions(attributions)
        # summarize_attributions drops the batch axis for a single record;
        # restore it so summary[record] works for every batch size.
        if summary.dim() == 1:
            summary = summary.unsqueeze(0)
        return summary, output, target

    def build_visualization(self, input_ids, attention_mask, **kwargs):
        """Run attribution and render one visualization row per record.

        ``kwargs`` are forwarded to run_attribution_model (e.g. ``layer``,
        ``baseline``, ``index``, ``steps``).
        """
        attributions, output, index = self.run_attribution_model(input_ids, attention_mask, **kwargs)
        vis_data_records = []
        for record in range(input_ids.size(0)):
            classification = output[record].argmax(dim=-1).item()
            attr = attributions[record]
            # Drop the first token and the trailing (padding + 1) tokens,
            # presumably <s>, </s> and right padding. TODO confirm.
            # NOTE(review): ``attr`` is not sliced the same way, so it keeps one
            # score per original position. Confirm that visualize_text expects this.
            n_pad = (attention_mask[record] == 0).sum().item()
            tokens = self.tokens_from_ids(input_ids[record].flatten())[1:-(n_pad + 1)]
            vis_data_records.append(
                visualization.VisualizationDataRecord(
                    attr,                            # word_attributions
                    output[record][classification],  # pred_prob (a raw logit here)
                    classification,                  # pred_class
                    classification,                  # true_class
                    index,                           # attr_class
                    1,                               # attr_score
                    tokens,                          # raw_input_ids
                    1,                               # convergence_score
                )
            )
        return visualize_text(vis_data_records)

    def __call__(self, input_text, layer, baseline):
        """Explain a single string at the given layer and return the rendering.

        Args:
            input_text: text to classify and explain.
            layer: layer number (anything ``int()`` accepts). 0 selects
                ``model.roberta.embeddings``; ``n >= 1`` selects encoder
                layer ``n - 1``.
            baseline: baseline name forwarded to run_attribution_model.
        """
        encoding = self.tokenizer([input_text], return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        layer_num = int(layer)
        if layer_num == 0:
            target_layer = self.model.roberta.embeddings
        else:
            # ModuleList children are registered under their string index.
            target_layer = getattr(self.model.roberta.encoder.layer, str(layer_num - 1))
        return self.build_visualization(input_ids, attention_mask, layer=target_layer, baseline=baseline)