Chengxu Zhuang commited on
Commit
ee546e1
1 Parent(s): 4d7c080

minor fix for config class

Browse files
Files changed (1) hide show
  1. modeling_flamingo.py +2 -0
modeling_flamingo.py CHANGED
@@ -16,6 +16,7 @@ from transformers.models.opt.modeling_opt\
16
  from transformers import ViTModel
17
  from .utils import exists, freeze_all_layers_, unfreeze_all_layers_
18
  from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
 
19
 
20
 
21
  class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -357,6 +358,7 @@ class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
357
  _keys_to_ignore_on_load_missing = [
358
  r"lm_head.weight",
359
  ]
 
360
 
361
  def __init__(self, config):
362
  OPTPreTrainedModel.__init__(self, config)
 
16
  from transformers import ViTModel
17
  from .utils import exists, freeze_all_layers_, unfreeze_all_layers_
18
  from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
19
+ from .configuration_flamingo import FlamingoConfig
20
 
21
 
22
  class OPTLearnedPositionalEmbedding(nn.Embedding):
 
358
  _keys_to_ignore_on_load_missing = [
359
  r"lm_head.weight",
360
  ]
361
+ config_class = FlamingoConfig
362
 
363
  def __init__(self, config):
364
  OPTPreTrainedModel.__init__(self, config)