Cyrile commited on
Commit
40807a1
1 Parent(s): 4a7f514

Update modeling_codegen.py

Browse files

Hello, this small change ensures that the labels are on the correct device.

Files changed (1) hide show
  1. modeling_codegen.py +1 -0
modeling_codegen.py CHANGED
@@ -713,6 +713,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
713
 
714
  loss = None
715
  if labels is not None:
 
716
  # Shift so that tokens < n predict n
717
  shift_logits = lm_logits[..., :-1, :].contiguous()
718
  shift_labels = labels[..., 1:].contiguous()
 
713
 
714
  loss = None
715
  if labels is not None:
716
+ labels = labels.to(lm_logits.device)
717
  # Shift so that tokens < n predict n
718
  shift_logits = lm_logits[..., :-1, :].contiguous()
719
  shift_labels = labels[..., 1:].contiguous()