JonasGeiping commited on
Commit
ae9243e
·
verified ·
1 Parent(s): 8b64e3a

Update raven_modeling_minimal.py

Browse files
Files changed (1) hide show
  1. raven_modeling_minimal.py +3 -2
raven_modeling_minimal.py CHANGED
@@ -19,7 +19,7 @@ import torch.nn.functional as F
19
  from transformers import GenerationConfig
20
 
21
 
22
- class RavenPreTrainedModel(PreTrainedModel):
23
  config_class = RavenConfig
24
  base_model_prefix = "model"
25
  supports_gradient_checkpointing = True
@@ -32,7 +32,8 @@ class RavenPreTrainedModel(PreTrainedModel):
32
  _supports_static_cache = False
33
 
34
  def _init_weights(self, module):
35
- print("Random Initialization not implemented.")
 
36
 
37
 
38
  @dataclass
 
19
  from transformers import GenerationConfig
20
 
21
 
22
+ class RavenPreTrainedModel(PreTrainedModel, GenerationMixin):
23
  config_class = RavenConfig
24
  base_model_prefix = "model"
25
  supports_gradient_checkpointing = True
 
32
  _supports_static_cache = False
33
 
34
  def _init_weights(self, module):
35
+ if not torch.rand((1,)).is_meta:
36
+ print("Random Initialization not implemented.")
37
 
38
 
39
  @dataclass