Text Generation · Transformers · Safetensors · lola_v1 · custom_code
neo-nlp-dev committed · Commit 6eae452 · 1 Parent(s): 3a7a4b9

updating model class and logo


- fixing multi-device training for the model
- updating the logo to a sharpened version

Files changed (2)
  1. lola-logo.png +0 -0
  2. modeling_lola_gpt2.py +7 -1
lola-logo.png CHANGED
modeling_lola_gpt2.py CHANGED
```diff
@@ -204,7 +204,7 @@ class LOLAModel(GPT2PreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            # self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
             input_shape = input_ids.size()
             input_ids = input_ids.view(-1, input_shape[-1])
             batch_size = input_ids.shape[0]
```
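The first hunk re-enables `warn_if_padding_and_no_attention_mask`, a standard `PreTrainedModel` helper in `transformers` that logs a one-time warning when `input_ids` contain the pad token but no `attention_mask` is supplied. A minimal sketch of the situation it flags; the `dice-research/lola_v1` repo id and the pad-token setup are assumptions based on the tags above, not stated in this commit:

```python
# Sketch: a padded batch passed without an attention mask.
# Repo id and pad-token setup are assumptions, not from the commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("dice-research/lola_v1", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("dice-research/lola_v1", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # GPT2-style tokenizers ship without a pad token

batch = tokenizer(["short", "a much longer input sequence"], padding=True, return_tensors="pt")

# Pad positions are attended to here; with this commit the model logs a warning.
out_unmasked = model(input_ids=batch["input_ids"])

# Correct call: pass the mask so padding is ignored.
out_masked = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
```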
```diff
@@ -537,6 +537,12 @@ class LOLALMHeadModel(GPT2LMHeadModel):
             return_dict=True, # Ensure we get a MoeModelOutputWithPast
         )
         hidden_states = transformer_outputs.last_hidden_state
+
+        # Set device for model parallelism
+        if self.model_parallel:
+            torch.cuda.set_device(self.transformer.first_device)
+            hidden_states = hidden_states.to(self.lm_head.weight.device)
+
         lm_logits = self.lm_head(hidden_states)
 
         aux_loss = transformer_outputs.aux_loss if hasattr(transformer_outputs, 'aux_loss') else None
```
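The second hunk mirrors upstream `GPT2LMHeadModel.forward`: when the model has been split across GPUs with `parallelize()`, the final hidden states come back on the last device while `lm_head` may live on the first, so they must be moved onto `lm_head`'s device before the projection. A minimal two-GPU sketch of the failure the added lines prevent; the layer placement below is illustrative, not the commit's actual device map:

```python
# Sketch of the cross-device error the model-parallel branch avoids.
# Device and layer placement are illustrative assumptions.
import torch

if torch.cuda.device_count() >= 2:
    lm_head = torch.nn.Linear(768, 50257, bias=False).to("cuda:0")  # head on the first device
    hidden_states = torch.randn(1, 8, 768, device="cuda:1")         # last block's output

    # Calling lm_head(hidden_states) directly would raise a RuntimeError
    # (tensors on different devices); the commit's .to(...) makes it legal.
    hidden_states = hidden_states.to(lm_head.weight.device)
    logits = lm_head(hidden_states)
```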
 
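For context, `self.model_parallel` and `self.transformer.first_device` are set by the `parallelize()` helper that `LOLALMHeadModel` inherits from `GPT2LMHeadModel`, so the new branch only triggers after an explicit parallelize call. A hedged usage sketch, with the repo id assumed as above:

```python
# Sketch: activating the model-parallel path exercised by this commit.
# parallelize() is inherited from GPT2LMHeadModel; the repo id is an assumption.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("dice-research/lola_v1", trust_remote_code=True)
model.parallelize()    # default device map over all visible GPUs; sets model_parallel = True
# ... run training or generation across devices ...
model.deparallelize()  # move everything back to CPU and clear the flag
```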