rrivera1849 committed on
Commit 291e050 · 1 Parent(s): b1178a0

Upload LUAR

Files changed (4)
  1. config.json +13 -0
  2. config.py +12 -0
  3. model.py +85 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "architectures": [
+     "LUAR"
+   ],
+   "auto_map": {
+     "AutoConfig": "config.LUARConfig",
+     "AutoModel": "model.LUAR"
+   },
+   "embedding_size": 512,
+   "model_type": "LUAR",
+   "torch_dtype": "float32",
+   "transformers_version": "4.33.2"
+ }
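Because config.json registers the custom classes through auto_map, the standard Auto* loaders can resolve them straight from the Hub. A minimal loading sketch follows, assuming a hypothetical repository id "rrivera1849/LUAR" (the actual Hub id is not stated in this commit):

```python
# Minimal loading sketch. The repo id below is a placeholder; trust_remote_code=True
# is required so that transformers executes config.py and model.py from the repo.
from transformers import AutoConfig, AutoModel

repo_id = "rrivera1849/LUAR"  # hypothetical id, for illustration only

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # -> config.LUARConfig
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)    # -> model.LUAR
print(config.embedding_size)  # 512, per the config above
```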
config.py ADDED
@@ -0,0 +1,12 @@
+ 
+ from transformers import PretrainedConfig
+ 
+ class LUARConfig(PretrainedConfig):
+     model_type = "LUAR"
+ 
+     def __init__(self,
+                  embedding_size: int = 512,
+                  **kwargs,
+     ):
+         self.embedding_size = embedding_size
+         super().__init__(**kwargs)
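The config exposes a single model-specific field, embedding_size, which sets the dimensionality of the final linear projection in model.py. A short local sketch (it only assumes config.py above is on the import path):

```python
# Instantiating the config directly, e.g. to build a model with a different
# output dimensionality before training.
from config import LUARConfig

cfg = LUARConfig(embedding_size=256)
print(cfg.model_type)        # "LUAR"
print(cfg.embedding_size)    # 256
print(cfg.to_json_string())  # standard PretrainedConfig serialization
```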
model.py ADDED
@@ -0,0 +1,85 @@
+ 
+ import math
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange, reduce, repeat
+ from transformers import AutoModel, PreTrainedModel
+ 
+ from .config import LUARConfig
+ 
+ class SelfAttention(nn.Module):
+     """Implements Dot-Product Self-Attention as used in "Attention Is All You Need".
+     """
+     def __init__(self):
+         super(SelfAttention, self).__init__()
+ 
+     def forward(self, k, q, v):
+         d_k = q.size(-1)
+         scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
+         p_attn = F.softmax(scores, dim=-1)
+ 
+         return torch.matmul(p_attn, v)
+ 
+ class LUAR(PreTrainedModel):
+     """Defines the LUAR model.
+     """
+     config_class = LUARConfig
+ 
+     def __init__(self, config):
+         super().__init__(config)
+         self.create_transformer()
+         self.attn_fn = SelfAttention()
+         self.linear = nn.Linear(self.hidden_size, config.embedding_size)
+ 
+     def create_transformer(self):
+         """Creates the Transformer backbone.
+         """
+         self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1")
+         self.hidden_size = self.transformer.config.hidden_size
+         self.num_attention_heads = self.transformer.config.num_attention_heads
+         self.dim_head = self.hidden_size // self.num_attention_heads
+ 
+     def mean_pooling(self, token_embeddings, attention_mask):
+         """Mean Pooling as described in the SBERT paper.
+         """
+         input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
+         sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
+         sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
+         return sum_embeddings / sum_mask
+ 
+     def get_episode_embeddings(self, data):
+         """Computes the Author Embedding.
+         """
+         input, attention_mask = data[0], data[1]
+         B, N, E, _ = attention_mask.shape
+         attention_mask = rearrange(attention_mask, 'b n e l -> (b n e) l')
+ 
+         input = rearrange(input, 'b n e l -> (b n e) l')
+ 
+         outputs = self.transformer(
+             input_ids=input,
+             attention_mask=attention_mask,
+             return_dict=True,
+             output_hidden_states=True
+         )
+ 
+         # at this point, we're embedding individual "comments"
+         comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
+         comment_embeddings = rearrange(comment_embeddings, '(b n e) l -> (b n) e l', b=B, n=N, e=E)
+ 
+         # aggregate individual comment embeddings into episode embeddings
+         episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
+         episode_embeddings = reduce(episode_embeddings, 'b e l -> b l', 'max')
+ 
+         episode_embeddings = self.linear(episode_embeddings)
+ 
+         return episode_embeddings, comment_embeddings
+ 
+     def forward(self, data):
+         """Calculates a fixed-length feature vector for a batch of episode samples.
+         """
+         output = self.get_episode_embeddings(data)
+ 
+         return output
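The forward pass takes data as a pair (input_ids, attention_mask), each shaped (batch, N, episode_size, seq_len) as implied by the rearrange patterns above, and returns both the per-episode author embedding and the per-comment embeddings. An end-to-end sketch under the following assumptions: the hypothetical Hub id "rrivera1849/LUAR" from earlier, and the backbone's own tokenizer (sentence-transformers/paraphrase-distilroberta-base-v1):

```python
# Forward-pass sketch; repo id and shapes are illustrative assumptions.
import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "rrivera1849/LUAR"  # hypothetical id, for illustration only
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1")
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

texts = ["a first comment by the author", "a second comment by the same author"]
tok = tokenizer(texts, padding="max_length", truncation=True, max_length=32, return_tensors="pt")

# one author (B=1), one sample set (N=1), an episode of two comments (E=2)
input_ids = tok["input_ids"].reshape(1, 1, 2, 32)
attention_mask = tok["attention_mask"].reshape(1, 1, 2, 32)

with torch.no_grad():
    episode_embeddings, comment_embeddings = model([input_ids, attention_mask])

print(episode_embeddings.shape)  # torch.Size([1, 512]): one author embedding per episode
print(comment_embeddings.shape)  # torch.Size([1, 2, 768]): per-comment backbone embeddings
```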
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:76a4ca30b8ef7c5d3fb806e7030c8375cdc3f60e2e2a607a2156917ac78e74b4
+ size 330083185