visheratin
commited on
Commit
•
aa00ba6
1
Parent(s):
89d694e
Update nllb_mrl.py
Browse files- nllb_mrl.py +3 -1
nllb_mrl.py
CHANGED
@@ -4,6 +4,7 @@ import torch
|
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
from open_clip import create_model, get_tokenizer
|
|
|
7 |
from open_clip.transform import PreprocessCfg, image_transform_v2
|
8 |
from PIL import Image
|
9 |
from transformers import PretrainedConfig, PreTrainedModel
|
@@ -53,7 +54,8 @@ class MatryoshkaNllbClip(PreTrainedModel):
|
|
53 |
self.model = create_model(
|
54 |
config.clip_model_name, output_dict=True
|
55 |
)
|
56 |
-
|
|
|
57 |
self.transform = image_transform_v2(
|
58 |
pp_cfg,
|
59 |
is_train=False,
|
|
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
from open_clip import create_model, get_tokenizer
|
7 |
+
from open_clip.pretrained import get_pretrained_cfg
|
8 |
from open_clip.transform import PreprocessCfg, image_transform_v2
|
9 |
from PIL import Image
|
10 |
from transformers import PretrainedConfig, PreTrainedModel
|
|
|
54 |
self.model = create_model(
|
55 |
config.clip_model_name, output_dict=True
|
56 |
)
|
57 |
+
pretrained_config = get_pretrained_cfg(config.clip_model_name, "v1")
|
58 |
+
pp_cfg = PreprocessCfg(pretrained_config.preprocess_cfg)
|
59 |
self.transform = image_transform_v2(
|
60 |
pp_cfg,
|
61 |
is_train=False,
|