Update raven_modeling_minimal.py
Browse files
raven_modeling_minimal.py
CHANGED
@@ -445,7 +445,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
|
|
445 |
input_ids = input_ids[:, cache_position] # type: ignore
|
446 |
model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
|
447 |
|
448 |
-
position_ids = torch.arange(current_input_length)[None, :]
|
449 |
model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
|
450 |
memory_format=torch.contiguous_format
|
451 |
) # positions_ids is a critical argument for the model to correctly apply rope!
|
|
|
445 |
input_ids = input_ids[:, cache_position] # type: ignore
|
446 |
model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
|
447 |
|
448 |
+
position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
|
449 |
model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
|
450 |
memory_format=torch.contiguous_format
|
451 |
) # positions_ids is a critical argument for the model to correctly apply rope!
|