import torch
from torch import nn


class Tisa(nn.Module):
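    """Translation-invariant self-attention (TISA) positional module.

    Learns a set of radial-basis-function kernels over relative token
    distances, one set per attention head, and outputs the positional
    contribution to the self-attention matrix."""
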
    def __init__(self, num_attention_heads: int = 12, num_kernels: int = 5):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.num_kernels = num_kernels

        self.kernel_offsets = nn.Parameter(
            torch.Tensor(self.num_kernels, self.num_attention_heads)
        )
        self.kernel_amplitudes = nn.Parameter(
            torch.Tensor(self.num_kernels, self.num_attention_heads)
        )
        self.kernel_sharpness = nn.Parameter(
            torch.Tensor(self.num_kernels, self.num_attention_heads)
        )
        self._init_weights()

    def create_relative_offsets(self, seq_len: int):
        """Creates offsets for all the relative distances between
        -seq_len + 1 to seq_len - 1."""
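        # e.g. seq_len = 3 -> tensor([-2, -1, 0, 1, 2])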
        return torch.arange(-seq_len + 1, seq_len)

    def compute_positional_scores(self, relative_offsets):
        """Takes seq_len and outputs position scores for each relative distance.
        This implementation uses radial basis functions. Override this function to
        use other scoring functions than the example in the paper."""
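        # Per head, the score for relative distance d is
        #     f(d) = sum_k a_k * exp(-|s_k| * (c_k - d)^2),
        # where a_k, s_k, and c_k are the learned amplitude, sharpness, and
        # offset of kernel k.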
        rbf_scores = (
            self.kernel_amplitudes.unsqueeze(-1)
            * torch.exp(
                -torch.abs(self.kernel_sharpness.unsqueeze(-1))
                * ((self.kernel_offsets.unsqueeze(-1) - relative_offsets) ** 2)
            )
        ).sum(dim=0)
        return rbf_scores

    def scores_to_toeplitz_matrix(self, positional_scores, seq_len: int):
        """Converts the TISA positional scores into the final matrix for the
        self-attention equation. PRs with memory and/or speed optimizations are
        welcome."""
        deformed_toeplitz = (
            (
                (torch.arange(0, -(seq_len ** 2), step=-1) + (seq_len - 1)).view(
                    seq_len, seq_len
                )
                + (seq_len + 1) * torch.arange(seq_len).view(-1, 1)
            )
            .view(-1)
            .long()
            .to(device=positional_scores.device)
        )
        expanded_positional_scores = torch.take_along_dim(
            positional_scores, deformed_toeplitz.view(1, -1), 1
        ).view(self.num_attention_heads, seq_len, seq_len)
        return expanded_positional_scores

    def forward(self, seq_len: int):
        """Computes the translation-invariant positional contribution to the
        attention matrix in the self-attention module of transformer models."""
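        # Returns a (num_attention_heads, seq_len, seq_len) tensor.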
        if not self.num_kernels:
            return torch.zeros((self.num_attention_heads, seq_len, seq_len))
        positional_scores_vector = self.compute_positional_scores(
            self.create_relative_offsets(seq_len)
        )
        positional_scores_matrix = self.scores_to_toeplitz_matrix(
            positional_scores_vector, seq_len
        )
        return positional_scores_matrix

    def visualize(self, seq_len: int = 10, attention_heads=None):
        """Visualizes TISA interpretability by plotting position scores as a
        function of relative distance for each attention head."""
        import matplotlib.pyplot as plt

        if attention_heads is None:
            attention_heads = list(range(self.num_attention_heads))
        x = self.create_relative_offsets(seq_len).detach().numpy()
        y = (
            self.compute_positional_scores(self.create_relative_offsets(seq_len))
            .detach()
            .numpy()
        )
        for i in attention_heads:
            plt.plot(x, y[i])
        plt.savefig("./pic-tisa.png")
        plt.show()

    def _init_weights(self):
        """Initialize the weights"""
        ampl_init_mean = 0.1
        sharpness_init_mean = 0.1
        torch.nn.init.normal_(self.kernel_offsets, mean=0.0, std=5.0)
        torch.nn.init.normal_(
            self.kernel_amplitudes, mean=ampl_init_mean, std=0.1 * ampl_init_mean
        )
        torch.nn.init.normal_(
            self.kernel_sharpness,
            mean=sharpness_init_mean,
            std=0.1 * sharpness_init_mean,
        )


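# Illustrative usage sketch (not part of the original module): TISA's forward
# pass yields per-head positional scores meant to be added to the attention
# logits before the softmax. The helper name and the
# (batch, num_heads, seq_len, seq_len) logit layout below are assumptions.
def apply_tisa_to_attention_logits(attention_logits, tisa: Tisa):
    """Adds the TISA positional score matrix to content-based attention
    logits before the softmax (a minimal sketch, assuming a
    (batch, num_heads, seq_len, seq_len) logit tensor)."""
    seq_len = attention_logits.shape[-1]
    # tisa(seq_len) is (num_heads, seq_len, seq_len); unsqueeze adds the
    # batch dimension so it broadcasts across the batch.
    return attention_logits + tisa(seq_len).unsqueeze(0)

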
def main():
    tisa = Tisa()
    positional_scores = tisa(20)  # (num_attention_heads, 20, 20)
    tisa.visualize(seq_len=20)


if __name__ == "__main__":
    main()