Chengxu Zhuang
commited on
Commit
•
ee546e1
1
Parent(s):
4d7c080
minor fix for config class
Browse files- modeling_flamingo.py +2 -0
modeling_flamingo.py
CHANGED
@@ -16,6 +16,7 @@ from transformers.models.opt.modeling_opt\
|
|
16 |
from transformers import ViTModel
|
17 |
from .utils import exists, freeze_all_layers_, unfreeze_all_layers_
|
18 |
from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
|
|
|
19 |
|
20 |
|
21 |
class OPTLearnedPositionalEmbedding(nn.Embedding):
|
@@ -357,6 +358,7 @@ class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
|
|
357 |
_keys_to_ignore_on_load_missing = [
|
358 |
r"lm_head.weight",
|
359 |
]
|
|
|
360 |
|
361 |
def __init__(self, config):
|
362 |
OPTPreTrainedModel.__init__(self, config)
|
|
|
16 |
from transformers import ViTModel
|
17 |
from .utils import exists, freeze_all_layers_, unfreeze_all_layers_
|
18 |
from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
|
19 |
+
from .configuration_flamingo import FlamingoConfig
|
20 |
|
21 |
|
22 |
class OPTLearnedPositionalEmbedding(nn.Embedding):
|
|
|
358 |
_keys_to_ignore_on_load_missing = [
|
359 |
r"lm_head.weight",
|
360 |
]
|
361 |
+
config_class = FlamingoConfig
|
362 |
|
363 |
def __init__(self, config):
|
364 |
OPTPreTrainedModel.__init__(self, config)
|