|
import torch |
|
from transformers import AutoTokenizer |
|
from captum.attr import visualization |
|
|
|
from roberta2 import RobertaForSequenceClassification |
|
from ExplanationGenerator import Generator |
|
from util import visualize_text, PyTMinMaxScalerVectorized |
|
|
|
# Class labels, indexed by the argmax of the softmaxed model output in
# RolloutExplainer.build_visualization (index 0 -> "NEGATIVE", 1 -> "POSITIVE").
classifications = [
    "NEGATIVE",
    "POSITIVE",
]
|
|
|
class RolloutExplainer(Generator):
    """Render attention-rollout relevance scores for a text classifier.

    Built on ``Generator`` (initialised with ``key="roberta.encoder.layer"``):
    ``__call__`` tokenizes a single string, and ``build_visualization`` runs
    ``generate_rollout``, min-max scales the resulting explanation, and feeds
    captum ``VisualizationDataRecord`` objects to ``visualize_text``.
    """

    def __init__(self, model, tokenizer):
        """Initialise the generator for ``model`` and keep ``tokenizer``.

        Args:
            model: classifier handed to ``Generator``; its ``.device`` is where
                encoded inputs are moved in ``__call__``.
            tokenizer: callable accepting a list of strings plus
                ``return_tensors="pt"`` and returning a mapping that contains
                ``"input_ids"`` and ``"attention_mask"``.
        """
        super().__init__(model, key="roberta.encoder.layer")
        self.tokenizer = tokenizer
        self.device = model.device

    def build_visualization(self, input_ids, attention_mask, start_layer=8):
        """Compute rollout relevance for a batch and render it.

        Args:
            input_ids: token ids; rows along dimension 0 (``input_ids.size(0)``)
                each become one visualization record.
            attention_mask: indexed per row; positions equal to 0 are counted
                as padding and trimmed from the displayed tokens.
            start_layer: forwarded unchanged to ``generate_rollout``.

        Returns:
            The value returned by ``visualize_text`` for the collected records.

        Each record uses the predicted class index for the predicted, true and
        attribution class fields, and fixes attribution/convergence scores at 1.
        """
        scores, explanation = self.generate_rollout(
            input_ids, attention_mask, start_layer=start_layer
        )
        relevance = PyTMinMaxScalerVectorized()(explanation)
        probabilities = torch.nn.functional.softmax(scores, dim=-1)

        def make_record(row):
            # Most probable class index for this row.
            predicted = probabilities[row].argmax(dim=-1).item()
            row_relevance = relevance[row]
            if classifications[predicted] == "NEGATIVE":
                # In-place negation; for a torch tensor this row is a view, so
                # `relevance` itself is updated as well.
                row_relevance *= -1
            all_tokens = self.tokens_from_ids(input_ids[row].flatten())
            pad_count = (attention_mask[row] == 0).sum().item()
            # Drop the first token and the last (pad_count + 1) tokens;
            # presumably the <s>/</s> specials plus right-side padding.
            tokens = all_tokens[1 : -(pad_count + 1)]
            return visualization.VisualizationDataRecord(
                row_relevance,
                probabilities[row][predicted],
                predicted,
                predicted,
                predicted,
                1,
                tokens,
                1,
            )

        records = [make_record(row) for row in range(input_ids.size(0))]
        return visualize_text(records)

    def __call__(self, input_text, start_layer=8):
        """Explain a single input string.

        Args:
            input_text: the text to classify and explain.
            start_layer: decremented by one when positive (callers appear to
                pass a 1-based layer number — confirm), then cast to ``int``
                and passed to ``build_visualization``.

        Returns:
            The result of ``build_visualization`` for the one-element batch.
        """
        layer = start_layer - 1 if start_layer > 0 else start_layer

        encoding = self.tokenizer([input_text], return_tensors="pt")
        input_ids, attention_mask = (
            encoding[key].to(self.device) for key in ("input_ids", "attention_mask")
        )

        return self.build_visualization(
            input_ids, attention_mask, start_layer=int(layer)
        )
|
|
|
|