JonasGeiping commited on
Commit
f39aa4c
·
verified ·
1 Parent(s): a8df662

Update raven_modeling_minimal.py

Browse files
Files changed (1) hide show
  1. raven_modeling_minimal.py +1 -1
raven_modeling_minimal.py CHANGED
@@ -445,7 +445,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
445
  input_ids = input_ids[:, cache_position] # type: ignore
446
  model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
447
 
448
- position_ids = torch.arange(current_input_length)[None, :]
449
  model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
450
  memory_format=torch.contiguous_format
451
  ) # positions_ids is a critical argument for the model to correctly apply rope!
 
445
  input_ids = input_ids[:, cache_position] # type: ignore
446
  model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
447
 
448
+ position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
449
  model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
450
  memory_format=torch.contiguous_format
451
  ) # positions_ids is a critical argument for the model to correctly apply rope!