visheratin commited on
Commit
aa00ba6
1 Parent(s): 89d694e

Update nllb_mrl.py

Browse files
Files changed (1) hide show
  1. nllb_mrl.py +3 -1
nllb_mrl.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from open_clip import create_model, get_tokenizer
 
7
  from open_clip.transform import PreprocessCfg, image_transform_v2
8
  from PIL import Image
9
  from transformers import PretrainedConfig, PreTrainedModel
@@ -53,7 +54,8 @@ class MatryoshkaNllbClip(PreTrainedModel):
53
  self.model = create_model(
54
  config.clip_model_name, output_dict=True
55
  )
56
- pp_cfg = PreprocessCfg(**self.model.visual.preprocess_cfg)
 
57
  self.transform = image_transform_v2(
58
  pp_cfg,
59
  is_train=False,
 
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from open_clip import create_model, get_tokenizer
7
+ from open_clip.pretrained import get_pretrained_cfg
8
  from open_clip.transform import PreprocessCfg, image_transform_v2
9
  from PIL import Image
10
  from transformers import PretrainedConfig, PreTrainedModel
 
54
  self.model = create_model(
55
  config.clip_model_name, output_dict=True
56
  )
57
+ pretrained_config = get_pretrained_cfg(config.clip_model_name, "v1")
58
+ pp_cfg = PreprocessCfg(pretrained_config.preprocess_cfg)
59
  self.transform = image_transform_v2(
60
  pp_cfg,
61
  is_train=False,