rrivera1849
commited on
Commit
·
82ec677
1
Parent(s):
9031ba4
Upload LUAR
Browse files
model.py
CHANGED
@@ -146,7 +146,7 @@ class LUAR(PreTrainedModel):
|
|
146 |
config.k_bucket_size,
|
147 |
)
|
148 |
self.linear = nn.Linear(self.hidden_size, config.embedding_size)
|
149 |
-
|
150 |
def create_transformer(self):
|
151 |
"""Creates the Transformer backbone.
|
152 |
"""
|
@@ -163,7 +163,7 @@ class LUAR(PreTrainedModel):
|
|
163 |
sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
|
164 |
return sum_embeddings / sum_mask
|
165 |
|
166 |
-
def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False):
|
167 |
"""Computes the Author Embedding.
|
168 |
"""
|
169 |
B, E, _ = attention_mask.shape
|
@@ -171,14 +171,31 @@ class LUAR(PreTrainedModel):
|
|
171 |
input_ids = rearrange(input_ids, 'b e l -> (b e) l')
|
172 |
attention_mask = rearrange(attention_mask, 'b e l -> (b e) l')
|
173 |
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
# at this point, we're embedding individual "comments"
|
183 |
comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
|
184 |
comment_embeddings = rearrange(comment_embeddings, '(b e) l -> b e l', b=B, e=E)
|
@@ -194,9 +211,9 @@ class LUAR(PreTrainedModel):
|
|
194 |
|
195 |
return episode_embeddings
|
196 |
|
197 |
-
def forward(self, input_ids, attention_mask, output_attentions=False):
|
198 |
"""Calculates a fixed-length feature vector for a batch of episode samples.
|
199 |
"""
|
200 |
-
output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions)
|
201 |
|
202 |
return output
|
|
|
146 |
config.k_bucket_size,
|
147 |
)
|
148 |
self.linear = nn.Linear(self.hidden_size, config.embedding_size)
|
149 |
+
|
150 |
def create_transformer(self):
|
151 |
"""Creates the Transformer backbone.
|
152 |
"""
|
|
|
163 |
sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
|
164 |
return sum_embeddings / sum_mask
|
165 |
|
166 |
+
def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
|
167 |
"""Computes the Author Embedding.
|
168 |
"""
|
169 |
B, E, _ = attention_mask.shape
|
|
|
171 |
input_ids = rearrange(input_ids, 'b e l -> (b e) l')
|
172 |
attention_mask = rearrange(attention_mask, 'b e l -> (b e) l')
|
173 |
|
174 |
+
if document_batch_size > 0:
|
175 |
+
outputs = {"last_hidden_state": [], "attentions": []}
|
176 |
+
for i in range(0, len(input_ids), document_batch_size):
|
177 |
+
out = self.transformer(
|
178 |
+
input_ids=input_ids[i:i+document_batch_size],
|
179 |
+
attention_mask=attention_mask[i:i+document_batch_size],
|
180 |
+
return_dict=True,
|
181 |
+
output_hidden_states=False,
|
182 |
+
output_attentions=output_attentions,
|
183 |
+
)
|
184 |
+
outputs["last_hidden_state"].append(out["last_hidden_state"])
|
185 |
+
if output_attentions:
|
186 |
+
outputs["attentions"].append(out["attentions"])
|
187 |
+
outputs["last_hidden_state"] = torch.cat(outputs["last_hidden_state"], dim=0)
|
188 |
+
if output_attentions:
|
189 |
+
outputs["attentions"] = tuple([torch.cat([x[i] for x in outputs["attentions"]], dim=0) for i in range(len(outputs["attentions"][0]))])
|
190 |
+
else:
|
191 |
+
outputs = self.transformer(
|
192 |
+
input_ids=input_ids,
|
193 |
+
attention_mask=attention_mask,
|
194 |
+
return_dict=True,
|
195 |
+
output_hidden_states=False,
|
196 |
+
output_attentions=output_attentions,
|
197 |
+
)
|
198 |
+
|
199 |
# at this point, we're embedding individual "comments"
|
200 |
comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
|
201 |
comment_embeddings = rearrange(comment_embeddings, '(b e) l -> b e l', b=B, e=E)
|
|
|
211 |
|
212 |
return episode_embeddings
|
213 |
|
214 |
+
def forward(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
|
215 |
"""Calculates a fixed-length feature vector for a batch of episode samples.
|
216 |
"""
|
217 |
+
output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions, document_batch_size)
|
218 |
|
219 |
return output
|