Update raven_modeling_minimal.py
Browse files
raven_modeling_minimal.py
CHANGED
@@ -19,7 +19,7 @@ import torch.nn.functional as F
|
|
19 |
from transformers import GenerationConfig
|
20 |
|
21 |
|
22 |
-
class RavenPreTrainedModel(PreTrainedModel):
|
23 |
config_class = RavenConfig
|
24 |
base_model_prefix = "model"
|
25 |
supports_gradient_checkpointing = True
|
@@ -32,7 +32,8 @@ class RavenPreTrainedModel(PreTrainedModel):
|
|
32 |
_supports_static_cache = False
|
33 |
|
34 |
def _init_weights(self, module):
|
35 |
-
|
|
|
36 |
|
37 |
|
38 |
@dataclass
|
|
|
19 |
from transformers import GenerationConfig
|
20 |
|
21 |
|
22 |
+
class RavenPreTrainedModel(PreTrainedModel, GenerationMixin):
|
23 |
config_class = RavenConfig
|
24 |
base_model_prefix = "model"
|
25 |
supports_gradient_checkpointing = True
|
|
|
32 |
_supports_static_cache = False
|
33 |
|
34 |
def _init_weights(self, module):
|
35 |
+
if not torch.rand((1,)).is_meta:
|
36 |
+
print("Random Initialization not implemented.")
|
37 |
|
38 |
|
39 |
@dataclass
|