Upload 2 files
Browse files- modeling.py +4 -4
modeling.py
CHANGED
|
@@ -47,7 +47,7 @@ from transformers.utils import (
|
|
| 47 |
replace_return_docstrings,
|
| 48 |
)
|
| 49 |
|
| 50 |
-
from .configuration import ExtendedLlamaConfig
|
| 51 |
|
| 52 |
logger = logging.get_logger(__name__)
|
| 53 |
|
|
@@ -1144,7 +1144,7 @@ class ExtendedLlamaForCausalLM(LlamaPreTrainedModel):
|
|
| 1144 |
|
| 1145 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1146 |
|
| 1147 |
-
def __init__(self, config, external_memories=None):
|
| 1148 |
super().__init__(config)
|
| 1149 |
self.model = ExtendedLlamaModel(config)
|
| 1150 |
self.vocab_size = config.vocab_size
|
|
@@ -1242,9 +1242,9 @@ class ExtendedLlamaForCausalLM(LlamaPreTrainedModel):
|
|
| 1242 |
if (
|
| 1243 |
self.memory_ids is not None and self.memories is None
|
| 1244 |
):
|
|
|
|
| 1245 |
self.memories = self.generate_cache(
|
| 1246 |
-
|
| 1247 |
-
cache_type=self.memory_type,
|
| 1248 |
)
|
| 1249 |
# EM: Remove special tokens from memory cache
|
| 1250 |
if self.remove_special_ids:
|
|
|
|
| 47 |
replace_return_docstrings,
|
| 48 |
)
|
| 49 |
|
| 50 |
+
from emts_clean.src.llama.configuration import ExtendedLlamaConfig
|
| 51 |
|
| 52 |
logger = logging.get_logger(__name__)
|
| 53 |
|
|
|
|
| 1144 |
|
| 1145 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1146 |
|
| 1147 |
+
def __init__(self, config, external_memories:list=None):
|
| 1148 |
super().__init__(config)
|
| 1149 |
self.model = ExtendedLlamaModel(config)
|
| 1150 |
self.vocab_size = config.vocab_size
|
|
|
|
| 1242 |
if (
|
| 1243 |
self.memory_ids is not None and self.memories is None
|
| 1244 |
):
|
| 1245 |
+
self.memory_ids = torch.tensor([self.memory_ids], device=self.device) if type(self.memory_ids)==list else self.memory_ids
|
| 1246 |
self.memories = self.generate_cache(
|
| 1247 |
+
self.memory_ids, cache_type=self.memory_type,
|
|
|
|
| 1248 |
)
|
| 1249 |
# EM: Remove special tokens from memory cache
|
| 1250 |
if self.remove_special_ids:
|