nz commited on
Commit
4df5cef
·
1 Parent(s): ac3299c

Update rita_modeling.py

Browse files
Files changed (1) hide show
  1. rita_modeling.py +3 -3
rita_modeling.py CHANGED
@@ -217,15 +217,15 @@ class RITAModel(PreTrainedModel):
217
  config
218
  ):
219
  super().__init__(config)
220
- self.embedding = nn.Embedding(config.in_vocab_size, config.d_model)
221
  self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_layers)])
222
  self.final_norm = nn.LayerNorm(config.d_model)
223
- self.projector = nn.Linear(config.d_model, config.out_vocab_size, bias = False)
224
 
225
  def forward(self, ids, attn_mask=None, padding_mask=None, return_hidden=False) -> torch.FloatTensor:
226
  x = self.embedding(ids) # N x L x D
227
  if attn_mask == None:
228
- attn_mask = (torch.triu(torch.ones(ids.size(1), ids.size(1))) == 0).transpose(0, 1).contiguous()
229
  for layer in self.layers:
230
  x = layer(x, attn_mask=attn_mask, padding_mask=padding_mask)
231
  x = self.final_norm(x) # N x L x D
 
217
  config
218
  ):
219
  super().__init__(config)
220
+ self.embedding = nn.Embedding(config.vocab_size, config.d_model)
221
  self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_layers)])
222
  self.final_norm = nn.LayerNorm(config.d_model)
223
+ self.projector = nn.Linear(config.d_model, config.vocab_size, bias = False)
224
 
225
  def forward(self, ids, attn_mask=None, padding_mask=None, return_hidden=False) -> torch.FloatTensor:
226
  x = self.embedding(ids) # N x L x D
227
  if attn_mask == None:
228
+ attn_mask = (torch.triu(torch.ones(ids.size(1), ids.size(1))) == 0).transpose(0, 1).contiguous().to(ids.device)
229
  for layer in self.layers:
230
  x = layer(x, attn_mask=attn_mask, padding_mask=padding_mask)
231
  x = self.final_norm(x) # N x L x D