|
import sys |
|
import open_clip |
|
import torch |
|
|
|
|
|
def load_open_clip(model_name: str = "ViT-B-32-quickgelu", pretrained: str = "laion400m_e32", cache_dir: str = None, device="cpu"):
    """Build an open_clip model together with its transform and tokenizer.

    Two loading strategies are used:

    * Vision-encoder-only: when ``pretrained`` is a path to an existing
      checkpoint file, or is not a string at all (it is then used directly
      as the checkpoint object), the model is first created with the
      ``'openai'`` weights and the checkpoint is loaded into
      ``model.visual`` only.  If the checkpoint contains the key
      ``'vision_encoder_state_dict'`` that entry is loaded, otherwise the
      checkpoint itself is treated as the state dict.
    * Whole-model: when ``pretrained`` is a string that is not an existing
      file (e.g. a pretrained tag such as the default), or when the
      vision-only attempt raises, ``pretrained`` is handed straight to
      ``open_clip.create_model_and_transforms``.

    Args:
        model_name: Architecture name passed to open_clip.
        pretrained: Checkpoint file path, pretrained identifier understood
            by open_clip, or an already loaded checkpoint object.
        cache_dir: Forwarded to ``open_clip.create_model_and_transforms``;
            may be ``None``.
        device: Device the finished model is moved to.  The model is always
            built on CPU first.

    Returns:
        A ``(model, transform, tokenizer)`` tuple, where ``transform`` is the
        third value returned by ``open_clip.create_model_and_transforms`` and
        ``tokenizer`` comes from ``open_clip.get_tokenizer(model_name)``.

    Raises:
        Exception: Whatever the whole-model load raises; if it was reached as
            a retry, the original vision-only failure is chained as context.

    Note:
        ``torch.load`` unpickles the checkpoint file, so only pass paths to
        files from a trusted source.
    """
    import os  # local import: only needed for the checkpoint-file check below

    def _create(weights):
        # Always build on CPU; the model is moved to `device` once, at the end.
        return open_clip.create_model_and_transforms(
            model_name, pretrained=weights, cache_dir=cache_dir, device='cpu'
        )

    if isinstance(pretrained, str) and not os.path.isfile(pretrained):
        # Not a checkpoint file, so torch.load() could only fail.  Skip the
        # vision-only attempt instead of building (and possibly downloading)
        # an OpenAI-pretrained model that would be thrown away immediately.
        model, _, transform = _create(pretrained)
    else:
        model = checkpoint = None
        try:
            model, _, transform = _create('openai')
            if isinstance(pretrained, str):
                # SECURITY: torch.load unpickles arbitrary objects; the path
                # must point to a trusted checkpoint.
                checkpoint = torch.load(pretrained, map_location=torch.device('cpu'))
            else:
                checkpoint = pretrained
            if 'vision_encoder_state_dict' in checkpoint:
                model.visual.load_state_dict(checkpoint['vision_encoder_state_dict'])
            else:
                model.visual.load_state_dict(checkpoint)
        except Exception as e:
            # Deliberate best-effort fallback: any failure of the vision-only
            # path leads to a whole-model load.  The retry stays inside the
            # handler so a second failure is chained to this one.
            print(f'error: {e}', file=sys.stderr)
            print('retrying by loading whole model..', file=sys.stderr)
            # Release the partially initialised model and the checkpoint
            # before building the replacement to keep peak memory down.
            model = checkpoint = None
            torch.cuda.empty_cache()
            model, _, transform = _create(pretrained)

    model = model.to(device)
    tokenizer = open_clip.get_tokenizer(model_name)
    return model, transform, tokenizer
|
|