import torch
from transformers import AutoTokenizer
from captum.attr import visualization

from roberta2 import RobertaForSequenceClassification
from ExplanationGenerator import Generator
from util import visualize_text, PyTMinMaxScalerVectorized

# Class labels, indexed by the position of the model's output logit.
classifications = ["NEGATIVE", "POSITIVE"]

class RolloutExplainer(Generator):
    """Render attention-rollout explanations for a RoBERTa sequence classifier."""

    def __init__(self, model, tokenizer):
        super().__init__(model, key="roberta.encoder.layer")
        self.device = model.device
        self.tokenizer = tokenizer

    def build_visualization(self, input_ids, attention_mask, start_layer=8):
        """Generate a rollout explanation for each input and render it as text."""
        vis_data_records = []

        output, expl = self.generate_rollout(
            input_ids, attention_mask, start_layer=start_layer
        )
        # normalize scores
        scaler = PyTMinMaxScalerVectorized()

        norm = scaler(expl)
        # get the model classification
        output = torch.nn.functional.softmax(output, dim=-1)

        for record in range(input_ids.size(0)):
            classification = output[record].argmax(dim=-1).item()
            class_name = classifications[classification]
            nrm = norm[record]

            # For a NEGATIVE prediction, high scores are evidence for the negative
            # class, so negate them to render with the negative color.
            if class_name == "NEGATIVE":
                nrm *= -1
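            # Drop the leading special token, the trailing special token, and
            # any padding (positions where attention_mask is 0).
            tokens = self.tokens_from_ids(input_ids[record].flatten())[
                1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
            ]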
            vis_data_records.append(
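                visualization.VisualizationDataRecord(
                    nrm,  # word_attributions
                    output[record][classification],  # pred_prob
                    classification,  # pred_class
                    classification,  # true_class (no label; prediction reused)
                    classification,  # attr_class
                    1,  # attr_score (placeholder)
                    tokens,  # raw_input_ids
                    1,  # convergence_score (placeholder)
                )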
            )
        return visualize_text(vis_data_records)

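    def __call__(self, input_text, start_layer=8):
        """Explain a single input string and return the rendered visualization."""
        # Shift a positive layer number down by one (this appears to convert a
        # 1-based layer number to a 0-based index); 0 is left unchanged.
        if start_layer > 0:
            start_layer -= 1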

        text_batch = [input_text]
        encoding = self.tokenizer(text_batch, return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        return self.build_visualization(input_ids, attention_mask, start_layer=int(start_layer))
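

# --- Illustrative sketch, not part of the original module ---
# The token slice in build_visualization is easy to misread, so the helper
# below reproduces its arithmetic on plain Python lists. The helper name is
# hypothetical, and it assumes right-padded input in which attention_mask is
# 1 for real tokens and 0 for padding, with one special token at each end of
# the real sequence (as with RoBERTa's <s> ... </s>).
def _trim_special_and_padding(tokens, attention_mask):
    # Count padded positions, mirroring (attention_mask == 0).sum().
    num_padding = sum(1 for m in attention_mask if m == 0)
    # Skip the first token; stop before the closing special token and padding.
    return tokens[1 : len(tokens) - (num_padding + 1)]


# Example (hypothetical tokens):
#   _trim_special_and_padding(
#       ["<s>", "good", "movie", "</s>", "<pad>", "<pad>"], [1, 1, 1, 1, 0, 0]
#   )
# keeps only the two content tokens between <s> and </s>.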