pranjalchitale
commited on
Update TieWeights
Browse files- modeling_indictrans.py +9 -9
modeling_indictrans.py
CHANGED
@@ -1644,7 +1644,7 @@ class IndicTransModel(IndicTransPreTrainedModel):
|
|
1644 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
|
1645 |
class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMixin):
|
1646 |
base_model_prefix = "model"
|
1647 |
-
_tied_weights_keys =
|
1648 |
_label_smoothing = 0.0
|
1649 |
|
1650 |
def __init__(self, config: IndicTransConfig):
|
@@ -1654,19 +1654,20 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMi
|
|
1654 |
config.decoder_embed_dim, config.decoder_vocab_size, bias=False
|
1655 |
)
|
1656 |
|
1657 |
-
if config.share_decoder_input_output_embed:
|
1658 |
-
self.lm_head.weight = self.model.decoder.embed_tokens.weight
|
1659 |
-
|
1660 |
self.post_init()
|
1661 |
|
1662 |
-
def
|
1663 |
-
|
|
|
1664 |
|
1665 |
def get_encoder(self):
|
1666 |
-
return self.model.
|
1667 |
|
1668 |
def get_decoder(self):
|
1669 |
-
return self.model.
|
|
|
|
|
|
|
1670 |
|
1671 |
def get_output_embeddings(self):
|
1672 |
return self.lm_head
|
@@ -1676,7 +1677,6 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMi
|
|
1676 |
|
1677 |
def set_label_smoothing(self, label_smoothing):
|
1678 |
self._label_smoothing = label_smoothing
|
1679 |
-
|
1680 |
def forward(
|
1681 |
self,
|
1682 |
input_ids: Optional[torch.LongTensor] = None,
|
|
|
1644 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
|
1645 |
class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMixin):
|
1646 |
base_model_prefix = "model"
|
1647 |
+
_tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
|
1648 |
_label_smoothing = 0.0
|
1649 |
|
1650 |
def __init__(self, config: IndicTransConfig):
|
|
|
1654 |
config.decoder_embed_dim, config.decoder_vocab_size, bias=False
|
1655 |
)
|
1656 |
|
|
|
|
|
|
|
1657 |
self.post_init()
|
1658 |
|
1659 |
+
def tie_weights(self):
|
1660 |
+
if self.config.share_decoder_input_output_embed:
|
1661 |
+
self._tie_or_clone_weights(self.decoder.embed_tokens, self.lm_head)
|
1662 |
|
1663 |
def get_encoder(self):
|
1664 |
+
return self.model.encoder
|
1665 |
|
1666 |
def get_decoder(self):
|
1667 |
+
return self.model.decoder
|
1668 |
+
|
1669 |
+
def get_input_embeddings(self):
|
1670 |
+
return self.model.encoder.embed_tokens
|
1671 |
|
1672 |
def get_output_embeddings(self):
|
1673 |
return self.lm_head
|
|
|
1677 |
|
1678 |
def set_label_smoothing(self, label_smoothing):
|
1679 |
self._label_smoothing = label_smoothing
|
|
|
1680 |
def forward(
|
1681 |
self,
|
1682 |
input_ids: Optional[torch.LongTensor] = None,
|