Update modelling_landmark_llama.py
modelling_landmark_llama.py CHANGED
@@ -32,6 +32,8 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from transformers.models.llama.configuration_llama import LlamaConfig
 
+from configuration_landmark_llama import LlamaConfig as LandmarkLlamaConfig
+
 
 logger = logging.get_logger(__name__)
 
@@ -565,7 +567,7 @@ LLAMA_START_DOCSTRING = r"""
     LLAMA_START_DOCSTRING,
 )
 class LlamaPreTrainedModel(PreTrainedModel):
-    config_class = LlamaConfig
+    config_class = LandmarkLlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
@@ -873,7 +875,7 @@ class LlamaModel(LlamaPreTrainedModel):
 
 
 class LlamaForCausalLM(LlamaPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: LandmarkLlamaConfig):
         super().__init__(config)
         self.model = LlamaModel(config)
 
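The first hunk imports the config class from a sibling configuration_landmark_llama module that is not part of this diff. A minimal sketch of what that module could look like, assuming it follows the usual pattern of extending the upstream LlamaConfig; the class keeps the name LlamaConfig because the modelling file aliases it at import time, and the model_type value and mem_freq field below are hypothetical placeholders, not confirmed by this commit:

# configuration_landmark_llama.py -- illustrative sketch only, not the repo's
# actual file. The class keeps the upstream name ``LlamaConfig`` because the
# modelling file imports it as ``LlamaConfig as LandmarkLlamaConfig``.
from transformers.models.llama.configuration_llama import LlamaConfig as BaseLlamaConfig


class LlamaConfig(BaseLlamaConfig):
    # Hypothetical tag so a checkpoint's config.json is distinguishable
    # from vanilla Llama.
    model_type = "landmark_llama"

    def __init__(self, mem_freq=50, **kwargs):
        # Hypothetical landmark-specific hyperparameter: how many regular
        # tokens sit between consecutive landmark tokens.
        self.mem_freq = mem_freq
        super().__init__(**kwargs)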
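Pointing config_class at the landmark config is what makes PreTrainedModel.from_pretrained parse a checkpoint's config.json with the extended class instead of the stock LlamaConfig, so landmark-specific fields survive save/load. A usage sketch under the assumptions above (both modules on the Python path; mem_freq is the hypothetical field from the previous sketch; small model dimensions chosen only to keep the example cheap):

# Usage sketch -- assumes the file layout implied by the diff's imports.
from configuration_landmark_llama import LlamaConfig as LandmarkLlamaConfig
from modelling_landmark_llama import LlamaForCausalLM

# Tiny dimensions for a quick smoke test; mem_freq is hypothetical.
config = LandmarkLlamaConfig(
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    mem_freq=50,
)
model = LlamaForCausalLM(config)  # __init__ is now annotated with this config type

# For saved checkpoints, the ``config_class`` attribute set in this diff is
# what routes config.json through the landmark config:
# model = LlamaForCausalLM.from_pretrained("path/to/landmark-llama-checkpoint")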