Update rita_modeling.py
Browse files- rita_modeling.py +3 -3
rita_modeling.py
CHANGED
|
@@ -217,15 +217,15 @@ class RITAModel(PreTrainedModel):
|
|
| 217 |
config
|
| 218 |
):
|
| 219 |
super().__init__(config)
|
| 220 |
-
self.embedding = nn.Embedding(config.
|
| 221 |
self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_layers)])
|
| 222 |
self.final_norm = nn.LayerNorm(config.d_model)
|
| 223 |
-
self.projector = nn.Linear(config.d_model, config.
|
| 224 |
|
| 225 |
def forward(self, ids, attn_mask=None, padding_mask=None, return_hidden=False) -> torch.FloatTensor:
|
| 226 |
x = self.embedding(ids) # N x L x D
|
| 227 |
if attn_mask == None:
|
| 228 |
-
attn_mask = (torch.triu(torch.ones(ids.size(1), ids.size(1))) == 0).transpose(0, 1).contiguous()
|
| 229 |
for layer in self.layers:
|
| 230 |
x = layer(x, attn_mask=attn_mask, padding_mask=padding_mask)
|
| 231 |
x = self.final_norm(x) # N x L x D
|
|
|
|
| 217 |
config
|
| 218 |
):
|
| 219 |
super().__init__(config)
|
| 220 |
+
self.embedding = nn.Embedding(config.vocab_size, config.d_model)
|
| 221 |
self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_layers)])
|
| 222 |
self.final_norm = nn.LayerNorm(config.d_model)
|
| 223 |
+
self.projector = nn.Linear(config.d_model, config.vocab_size, bias = False)
|
| 224 |
|
| 225 |
def forward(self, ids, attn_mask=None, padding_mask=None, return_hidden=False) -> torch.FloatTensor:
|
| 226 |
x = self.embedding(ids) # N x L x D
|
| 227 |
if attn_mask == None:
|
| 228 |
+
attn_mask = (torch.triu(torch.ones(ids.size(1), ids.size(1))) == 0).transpose(0, 1).contiguous().to(ids.device)
|
| 229 |
for layer in self.layers:
|
| 230 |
x = layer(x, attn_mask=attn_mask, padding_mask=padding_mask)
|
| 231 |
x = self.final_norm(x) # N x L x D
|