SunderAli17 commited on
Commit
5148630
·
verified ·
1 Parent(s): b880666

Create factory.py

Browse files
Files changed (1) hide show
  1. eva_clip/factory.py +517 -0
eva_clip/factory.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union, Dict, Any
9
+ import torch
10
+
11
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
12
+ from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
13
+ get_cast_dtype
14
+ from .openai import load_openai_model
15
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
16
+ from .transform import image_transform
17
+ from .tokenizer import HFTokenizer, tokenize
18
+ from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
19
+
20
+
21
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
22
+ _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
23
+
24
+
25
+ def _natural_key(string_):
26
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
27
+
28
+
29
+ def _rescan_model_configs():
30
+ global _MODEL_CONFIGS
31
+
32
+ config_ext = ('.json',)
33
+ config_files = []
34
+ for config_path in _MODEL_CONFIG_PATHS:
35
+ if config_path.is_file() and config_path.suffix in config_ext:
36
+ config_files.append(config_path)
37
+ elif config_path.is_dir():
38
+ for ext in config_ext:
39
+ config_files.extend(config_path.glob(f'*{ext}'))
40
+
41
+ for cf in config_files:
42
+ with open(cf, "r", encoding="utf8") as f:
43
+ model_cfg = json.load(f)
44
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
45
+ _MODEL_CONFIGS[cf.stem] = model_cfg
46
+
47
+ _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
48
+
49
+
50
+ _rescan_model_configs() # initial populate of model config registry
51
+
52
+
53
+ def list_models():
54
+ """ enumerate available model architectures based on config files """
55
+ return list(_MODEL_CONFIGS.keys())
56
+
57
+
58
+ def add_model_config(path):
59
+ """ add model config path or file and update registry """
60
+ if not isinstance(path, Path):
61
+ path = Path(path)
62
+ _MODEL_CONFIG_PATHS.append(path)
63
+ _rescan_model_configs()
64
+
65
+
66
+ def get_model_config(model_name):
67
+ if model_name in _MODEL_CONFIGS:
68
+ return deepcopy(_MODEL_CONFIGS[model_name])
69
+ else:
70
+ return None
71
+
72
+
73
+ def get_tokenizer(model_name):
74
+ config = get_model_config(model_name)
75
+ tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
76
+ return tokenizer
77
+
78
+
79
+ # loading openai CLIP weights when is_openai=True for training
80
+ def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):
81
+ if is_openai:
82
+ model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
83
+ state_dict = model.state_dict()
84
+ for key in ["input_resolution", "context_length", "vocab_size"]:
85
+ state_dict.pop(key, None)
86
+ else:
87
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
88
+ for mk in model_key.split('|'):
89
+ if isinstance(checkpoint, dict) and mk in checkpoint:
90
+ state_dict = checkpoint[mk]
91
+ break
92
+ else:
93
+ state_dict = checkpoint
94
+ if next(iter(state_dict.items()))[0].startswith('module'):
95
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
96
+
97
+ for k in skip_list:
98
+ if k in list(state_dict.keys()):
99
+ logging.info(f"Removing key {k} from pretrained checkpoint")
100
+ del state_dict[k]
101
+
102
+ if os.getenv('RoPE') == '1':
103
+ for k in list(state_dict.keys()):
104
+ if 'freqs_cos' in k or 'freqs_sin' in k:
105
+ del state_dict[k]
106
+ return state_dict
107
+
108
+
109
+
110
+ def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
111
+ state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
112
+ # detect old format and make compatible with new format
113
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
114
+ state_dict = convert_to_custom_text_state_dict(state_dict)
115
+ if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):
116
+ state_dict['logit_scale'] = state_dict['text.logit_scale']
117
+ del state_dict['text.logit_scale']
118
+
119
+ # resize_clip_pos_embed for CLIP and open CLIP
120
+ if 'visual.positional_embedding' in state_dict:
121
+ resize_clip_pos_embed(state_dict, model)
122
+ # specified to eva_vit_model
123
+ elif 'visual.pos_embed' in state_dict:
124
+ resize_evaclip_pos_embed(state_dict, model)
125
+
126
+ # resize_clip_pos_embed(state_dict, model)
127
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
128
+ logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
129
+ return incompatible_keys
130
+
131
+ def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
132
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
133
+
134
+ for k in list(state_dict.keys()):
135
+ if not k.startswith('visual.'):
136
+ del state_dict[k]
137
+ for k in list(state_dict.keys()):
138
+ if k.startswith('visual.'):
139
+ new_k = k[7:]
140
+ state_dict[new_k] = state_dict[k]
141
+ del state_dict[k]
142
+ return state_dict
143
+
144
+ def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
145
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
146
+
147
+ for k in list(state_dict.keys()):
148
+ if k.startswith('visual.'):
149
+ del state_dict[k]
150
+ return state_dict
151
+
152
+ def get_pretrained_tag(pretrained_model):
153
+ pretrained_model = pretrained_model.lower()
154
+ if "laion" in pretrained_model or "open_clip" in pretrained_model:
155
+ return "open_clip"
156
+ elif "openai" in pretrained_model:
157
+ return "clip"
158
+ elif "eva" in pretrained_model and "clip" in pretrained_model:
159
+ return "eva_clip"
160
+ else:
161
+ return "other"
162
+
163
+ def load_pretrained_checkpoint(
164
+ model,
165
+ visual_checkpoint_path,
166
+ text_checkpoint_path,
167
+ strict=True,
168
+ visual_model=None,
169
+ text_model=None,
170
+ model_key="model|module|state_dict",
171
+ skip_list=[]):
172
+ visual_tag = get_pretrained_tag(visual_model)
173
+ text_tag = get_pretrained_tag(text_model)
174
+
175
+ logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
176
+ visual_incompatible_keys, text_incompatible_keys = None, None
177
+ if visual_checkpoint_path:
178
+ if visual_tag == "eva_clip" or visual_tag == "open_clip":
179
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
180
+ elif visual_tag == "clip":
181
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
182
+ else:
183
+ visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
184
+
185
+ # resize_clip_pos_embed for CLIP and open CLIP
186
+ if 'positional_embedding' in visual_state_dict:
187
+ resize_visual_pos_embed(visual_state_dict, model)
188
+ # specified to EVA model
189
+ elif 'pos_embed' in visual_state_dict:
190
+ resize_eva_pos_embed(visual_state_dict, model)
191
+
192
+ visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
193
+ logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
194
+ logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
195
+
196
+ if text_checkpoint_path:
197
+ if text_tag == "eva_clip" or text_tag == "open_clip":
198
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
199
+ elif text_tag == "clip":
200
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
201
+ else:
202
+ text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
203
+
204
+ text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
205
+
206
+ logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
207
+ logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
208
+
209
+ return visual_incompatible_keys, text_incompatible_keys
210
+
211
+ def create_model(
212
+ model_name: str,
213
+ pretrained: Optional[str] = None,
214
+ precision: str = 'fp32',
215
+ device: Union[str, torch.device] = 'cpu',
216
+ jit: bool = False,
217
+ force_quick_gelu: bool = False,
218
+ force_custom_clip: bool = False,
219
+ force_patch_dropout: Optional[float] = None,
220
+ pretrained_image: str = '',
221
+ pretrained_text: str = '',
222
+ pretrained_hf: bool = True,
223
+ pretrained_visual_model: str = None,
224
+ pretrained_text_model: str = None,
225
+ cache_dir: Optional[str] = None,
226
+ skip_list: list = [],
227
+ ):
228
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
229
+ if isinstance(device, str):
230
+ device = torch.device(device)
231
+
232
+ if pretrained and pretrained.lower() == 'openai':
233
+ logging.info(f'Loading pretrained {model_name} from OpenAI.')
234
+ model = load_openai_model(
235
+ model_name,
236
+ precision=precision,
237
+ device=device,
238
+ jit=jit,
239
+ cache_dir=cache_dir,
240
+ )
241
+ else:
242
+ model_cfg = get_model_config(model_name)
243
+ if model_cfg is not None:
244
+ logging.info(f'Loaded {model_name} model config.')
245
+ else:
246
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
247
+ raise RuntimeError(f'Model config for {model_name} not found.')
248
+
249
+ if 'rope' in model_cfg.get('vision_cfg', {}):
250
+ if model_cfg['vision_cfg']['rope']:
251
+ os.environ['RoPE'] = "1"
252
+ else:
253
+ os.environ['RoPE'] = "0"
254
+
255
+ if force_quick_gelu:
256
+ # override for use of QuickGELU on non-OpenAI transformer models
257
+ model_cfg["quick_gelu"] = True
258
+
259
+ if force_patch_dropout is not None:
260
+ # override the default patch dropout value
261
+ model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout
262
+
263
+ cast_dtype = get_cast_dtype(precision)
264
+ custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])
265
+
266
+
267
+ if custom_clip:
268
+ if 'hf_model_name' in model_cfg.get('text_cfg', {}):
269
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
270
+ model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
271
+ else:
272
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
273
+
274
+ pretrained_cfg = {}
275
+ if pretrained:
276
+ checkpoint_path = ''
277
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
278
+ if pretrained_cfg:
279
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
280
+ elif os.path.exists(pretrained):
281
+ checkpoint_path = pretrained
282
+
283
+ if checkpoint_path:
284
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
285
+ load_checkpoint(model,
286
+ checkpoint_path,
287
+ model_key="model|module|state_dict",
288
+ strict=False
289
+ )
290
+ else:
291
+ error_str = (
292
+ f'Pretrained weights ({pretrained}) not found for model {model_name}.'
293
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
294
+ logging.warning(error_str)
295
+ raise RuntimeError(error_str)
296
+ else:
297
+ visual_checkpoint_path = ''
298
+ text_checkpoint_path = ''
299
+
300
+ if pretrained_image:
301
+ pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names
302
+ pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
303
+ if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
304
+ # pretrained weight loading for timm models set via vision_cfg
305
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
306
+ elif pretrained_image_cfg:
307
+ visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
308
+ elif os.path.exists(pretrained_image):
309
+ visual_checkpoint_path = pretrained_image
310
+ else:
311
+ logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
312
+ raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
313
+
314
+ if pretrained_text:
315
+ pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names
316
+ pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
317
+ if pretrained_image_cfg:
318
+ text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
319
+ elif os.path.exists(pretrained_text):
320
+ text_checkpoint_path = pretrained_text
321
+ else:
322
+ logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
323
+ raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
324
+
325
+ if visual_checkpoint_path:
326
+ logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
327
+ if text_checkpoint_path:
328
+ logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')
329
+
330
+ if visual_checkpoint_path or text_checkpoint_path:
331
+ load_pretrained_checkpoint(
332
+ model,
333
+ visual_checkpoint_path,
334
+ text_checkpoint_path,
335
+ strict=False,
336
+ visual_model=pretrained_visual_model,
337
+ text_model=pretrained_text_model,
338
+ model_key="model|module|state_dict",
339
+ skip_list=skip_list
340
+ )
341
+
342
+ if "fp16" in precision or "bf16" in precision:
343
+ logging.info(f'convert precision to {precision}')
344
+ model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)
345
+
346
+ model.to(device=device)
347
+
348
+ # set image / mean metadata from pretrained_cfg if available, or use default
349
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
350
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
351
+
352
+ if jit:
353
+ model = torch.jit.script(model)
354
+
355
+ return model
356
+
357
+
358
+ def create_model_and_transforms(
359
+ model_name: str,
360
+ pretrained: Optional[str] = None,
361
+ precision: str = 'fp32',
362
+ device: Union[str, torch.device] = 'cpu',
363
+ jit: bool = False,
364
+ force_quick_gelu: bool = False,
365
+ force_custom_clip: bool = False,
366
+ force_patch_dropout: Optional[float] = None,
367
+ pretrained_image: str = '',
368
+ pretrained_text: str = '',
369
+ pretrained_hf: bool = True,
370
+ pretrained_visual_model: str = None,
371
+ pretrained_text_model: str = None,
372
+ image_mean: Optional[Tuple[float, ...]] = None,
373
+ image_std: Optional[Tuple[float, ...]] = None,
374
+ cache_dir: Optional[str] = None,
375
+ skip_list: list = [],
376
+ ):
377
+ model = create_model(
378
+ model_name,
379
+ pretrained,
380
+ precision=precision,
381
+ device=device,
382
+ jit=jit,
383
+ force_quick_gelu=force_quick_gelu,
384
+ force_custom_clip=force_custom_clip,
385
+ force_patch_dropout=force_patch_dropout,
386
+ pretrained_image=pretrained_image,
387
+ pretrained_text=pretrained_text,
388
+ pretrained_hf=pretrained_hf,
389
+ pretrained_visual_model=pretrained_visual_model,
390
+ pretrained_text_model=pretrained_text_model,
391
+ cache_dir=cache_dir,
392
+ skip_list=skip_list,
393
+ )
394
+
395
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
396
+ image_std = image_std or getattr(model.visual, 'image_std', None)
397
+ preprocess_train = image_transform(
398
+ model.visual.image_size,
399
+ is_train=True,
400
+ mean=image_mean,
401
+ std=image_std
402
+ )
403
+ preprocess_val = image_transform(
404
+ model.visual.image_size,
405
+ is_train=False,
406
+ mean=image_mean,
407
+ std=image_std
408
+ )
409
+
410
+ return model, preprocess_train, preprocess_val
411
+
412
+
413
+ def create_transforms(
414
+ model_name: str,
415
+ pretrained: Optional[str] = None,
416
+ precision: str = 'fp32',
417
+ device: Union[str, torch.device] = 'cpu',
418
+ jit: bool = False,
419
+ force_quick_gelu: bool = False,
420
+ force_custom_clip: bool = False,
421
+ force_patch_dropout: Optional[float] = None,
422
+ pretrained_image: str = '',
423
+ pretrained_text: str = '',
424
+ pretrained_hf: bool = True,
425
+ pretrained_visual_model: str = None,
426
+ pretrained_text_model: str = None,
427
+ image_mean: Optional[Tuple[float, ...]] = None,
428
+ image_std: Optional[Tuple[float, ...]] = None,
429
+ cache_dir: Optional[str] = None,
430
+ skip_list: list = [],
431
+ ):
432
+ model = create_model(
433
+ model_name,
434
+ pretrained,
435
+ precision=precision,
436
+ device=device,
437
+ jit=jit,
438
+ force_quick_gelu=force_quick_gelu,
439
+ force_custom_clip=force_custom_clip,
440
+ force_patch_dropout=force_patch_dropout,
441
+ pretrained_image=pretrained_image,
442
+ pretrained_text=pretrained_text,
443
+ pretrained_hf=pretrained_hf,
444
+ pretrained_visual_model=pretrained_visual_model,
445
+ pretrained_text_model=pretrained_text_model,
446
+ cache_dir=cache_dir,
447
+ skip_list=skip_list,
448
+ )
449
+
450
+
451
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
452
+ image_std = image_std or getattr(model.visual, 'image_std', None)
453
+ preprocess_train = image_transform(
454
+ model.visual.image_size,
455
+ is_train=True,
456
+ mean=image_mean,
457
+ std=image_std
458
+ )
459
+ preprocess_val = image_transform(
460
+ model.visual.image_size,
461
+ is_train=False,
462
+ mean=image_mean,
463
+ std=image_std
464
+ )
465
+ del model
466
+
467
+ return preprocess_train, preprocess_val
468
+
469
+ def create_model_from_pretrained(
470
+ model_name: str,
471
+ pretrained: str,
472
+ precision: str = 'fp32',
473
+ device: Union[str, torch.device] = 'cpu',
474
+ jit: bool = False,
475
+ force_quick_gelu: bool = False,
476
+ force_custom_clip: bool = False,
477
+ force_patch_dropout: Optional[float] = None,
478
+ return_transform: bool = True,
479
+ image_mean: Optional[Tuple[float, ...]] = None,
480
+ image_std: Optional[Tuple[float, ...]] = None,
481
+ cache_dir: Optional[str] = None,
482
+ is_frozen: bool = False,
483
+ ):
484
+ if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
485
+ raise RuntimeError(
486
+ f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
487
+ f' Use open_clip.list_pretrained() to find one.')
488
+
489
+ model = create_model(
490
+ model_name,
491
+ pretrained,
492
+ precision=precision,
493
+ device=device,
494
+ jit=jit,
495
+ force_quick_gelu=force_quick_gelu,
496
+ force_custom_clip=force_custom_clip,
497
+ force_patch_dropout=force_patch_dropout,
498
+ cache_dir=cache_dir,
499
+ )
500
+
501
+ if is_frozen:
502
+ for param in model.parameters():
503
+ param.requires_grad = False
504
+
505
+ if not return_transform:
506
+ return model
507
+
508
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
509
+ image_std = image_std or getattr(model.visual, 'image_std', None)
510
+ preprocess = image_transform(
511
+ model.visual.image_size,
512
+ is_train=False,
513
+ mean=image_mean,
514
+ std=image_std
515
+ )
516
+
517
+ return model, preprocess