d-Matrix
committed on
Commit
•
dfd01dd
1
Parent(s):
61e212c
Update modeling_gptj.py
Browse files- modeling_gptj.py +16 -2
modeling_gptj.py
CHANGED
@@ -40,6 +40,21 @@ from transformers.utils import (
|
|
40 |
)
|
41 |
from .configuration_gptj import GPTJConfig
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
logger = logging.get_logger(__name__)
|
45 |
|
@@ -53,7 +68,6 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|
53 |
# See all GPT-J models at https://huggingface.co/models?filter=gptj
|
54 |
]
|
55 |
|
56 |
-
|
57 |
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
|
58 |
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
|
59 |
sinusoid_inp = torch.einsum(
|
@@ -365,7 +379,7 @@ class GPTJBlock(nn.Module):
|
|
365 |
return outputs # hidden_states, present, (attentions)
|
366 |
|
367 |
|
368 |
-
class GPTJPreTrainedModel(PreTrainedModel):
|
369 |
"""
|
370 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
371 |
models.
|
|
|
40 |
)
|
41 |
from .configuration_gptj import GPTJConfig
|
42 |
|
43 |
+
from mltools.dmx import DmxModel
|
44 |
+
|
45 |
+
|
46 |
+
class DmxPreTrainedModel(PreTrainedModel):
    """`PreTrainedModel` subclass whose `from_pretrained` returns a `DmxModel` conversion.

    Model classes in this file derive from this class instead of
    `PreTrainedModel`, so every model loaded through `from_pretrained` is
    passed through `DmxModel.from_torch` before being handed to the caller.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a pretrained model and convert it with `DmxModel.from_torch`.

        Args:
            *args: Positional arguments forwarded unchanged to
                `PreTrainedModel.from_pretrained`.
            **kwargs: Keyword arguments forwarded unchanged to
                `PreTrainedModel.from_pretrained`.

        Returns:
            The object returned by `DmxModel.from_torch`. If the parent loader
            returned a tuple (as it does for `output_loading_info=True`), a
            tuple of the same length is returned with the converted model in
            first position and the remaining items untouched.
        """
        loaded = super().from_pretrained(*args, **kwargs)

        # The parent returns `(model, loading_info)` instead of a bare model
        # when called with `output_loading_info=True`; only the model itself
        # must be converted.
        if isinstance(loaded, tuple):
            _model, *extras = loaded
        else:
            _model, extras = loaded, None

        # Use the model's declared main input name rather than a hard-coded
        # literal; fall back to "input_ids" (the previous fixed value) if the
        # attribute is absent.
        main_input = getattr(_model, "main_input_name", "input_ids")

        _model = DmxModel.from_torch(
            _model,
            hf=True,
            input_names=[main_input],
            concrete_args=None,
        )

        if extras is None:
            return _model
        return (_model, *extras)
|
57 |
+
|
58 |
|
59 |
logger = logging.get_logger(__name__)
|
60 |
|
|
|
68 |
# See all GPT-J models at https://huggingface.co/models?filter=gptj
|
69 |
]
|
70 |
|
|
|
71 |
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
|
72 |
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
|
73 |
sinusoid_inp = torch.einsum(
|
|
|
379 |
return outputs # hidden_states, present, (attentions)
|
380 |
|
381 |
|
382 |
+
class GPTJPreTrainedModel(DmxPreTrainedModel):
|
383 |
"""
|
384 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
385 |
models.
|