fix-normalization-after-truncate
#53
by
jupyterjazz
- opened
- modeling_xlm_roberta.py +5 -7
modeling_xlm_roberta.py
CHANGED
@@ -588,12 +588,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
                 embeddings = self.mean_pooling(
                     token_embs, encoded_input["attention_mask"]
                 )
-
-                if normalize_embeddings:
-                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-                if convert_to_numpy:
-                    embeddings = embeddings.cpu()
+
                 all_embeddings.extend(embeddings)

             all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
@@ -601,11 +596,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             truncate_dim = truncate_dim or self.config.truncate_dim
             if truncate_dim:
                 all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
+
+            if normalize_embeddings:
+                all_embeddings = [torch.nn.functional.normalize(embedding, p=2, dim=0) for embedding in all_embeddings]

             if convert_to_tensor:
                 all_embeddings = torch.stack(all_embeddings)
             elif convert_to_numpy:
-                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+                all_embeddings = np.asarray([emb.cpu().numpy() for emb in all_embeddings])

             if input_was_string:
                 all_embeddings = all_embeddings[0]