Upload 7 files
Browse files- modeling_indictrans.py +13 -0
modeling_indictrans.py
CHANGED
@@ -606,6 +606,17 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
|
|
606 |
# Initialize weights and apply final processing
|
607 |
self.post_init()
|
608 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
609 |
def forward(
|
610 |
self,
|
611 |
input_ids: Optional[torch.Tensor] = None,
|
@@ -745,6 +756,8 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
|
|
745 |
if output_hidden_states:
|
746 |
encoder_states = encoder_states + (hidden_states,)
|
747 |
|
|
|
|
|
748 |
if not return_dict:
|
749 |
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
750 |
return BaseModelOutput(
|
|
|
606 |
# Initialize weights and apply final processing
|
607 |
self.post_init()
|
608 |
|
609 |
+
def get_pooled_representation(self, hidden_states, attention_mask):
    """Mean-pool ``hidden_states`` over dim 1, ignoring masked-out positions.

    Args:
        hidden_states: Tensor whose first two dims are indexed by
            ``attention_mask`` and whose dim 1 is reduced.
            Assumed to be (batch, seq_len, hidden_dim) — TODO confirm
            against the caller.
        attention_mask: Tensor matching the leading dims of
            ``hidden_states``; entries equal to 0 mark positions to
            exclude. If ``None``, every position is included (plain mean
            over dim 1).

    Returns:
        Tensor with dim 1 reduced away: for each row, the sum of the kept
        positions multiplied by ``1 / (kept_count + 1e-7)``. A row with no
        kept positions yields zeros.
    """
    if attention_mask is None:
        # No mask supplied: treat every position as a real token instead of
        # failing on `None == 0` / `None != 0`.
        attention_mask = hidden_states.new_ones(hidden_states.shape[:2])

    # Boolean mask of positions to drop, padded with trailing singleton
    # dims so it broadcasts over the remaining feature dims.
    drop = attention_mask == 0
    trailing = (1,) * (hidden_states.dim() - drop.dim())
    drop = drop.reshape(drop.shape + trailing)

    # Out-of-place fill: the caller's tensor is left untouched, and any
    # NaN/inf sitting in a dropped position is replaced by 0 (a plain
    # multiply-by-mask would propagate NaN).
    summed = hidden_states.masked_fill(drop, 0).sum(dim=1)

    # The epsilon keeps the reciprocal finite for rows with zero kept tokens.
    weights = 1.0 / ((attention_mask != 0).float().sum(dim=1) + 1e-7)

    # Scale each row i of `summed` by weights[i], whatever the trailing rank.
    return torch.einsum("i...,i->i...", summed, weights)
|
619 |
+
|
620 |
def forward(
|
621 |
self,
|
622 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
756 |
if output_hidden_states:
|
757 |
encoder_states = encoder_states + (hidden_states,)
|
758 |
|
759 |
+
hidden_states = self.get_pooled_representation(hidden_states, attention_mask)
|
760 |
+
|
761 |
if not return_dict:
|
762 |
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
763 |
return BaseModelOutput(
|