xmutly's picture
Upload 294 files
e1aaaac verified
raw
history blame
1.23 kB
import sys
from typing import Optional, Union

import open_clip
import torch
def load_open_clip(
    model_name: str = "ViT-B-32-quickgelu",
    pretrained: Union[str, dict] = "laion400m_e32",
    cache_dir: Optional[str] = None,
    device: Union[str, torch.device] = "cpu",
):
    """Build an open_clip model, its transform and tokenizer, with optional custom vision weights.

    First attempt: create ``model_name`` with the ``'openai'`` weights, then
    overwrite only ``model.visual`` with the weights given by ``pretrained``.
    ``pretrained`` is either a path loaded with ``torch.load`` or an
    already-loaded checkpoint mapping. A checkpoint that stores the vision
    weights under ``'vision_encoder_state_dict'`` (TeCoA-style) is unwrapped
    first.

    Fallback: if the first attempt raises and ``pretrained`` is a string, it
    is passed directly as the ``pretrained`` argument of
    ``open_clip.create_model_and_transforms``. That argument is presumably an
    open_clip pretrained tag (such as the default ``"laion400m_e32"``) or a
    whole-model checkpoint path; confirm against the installed open_clip.

    Args:
        model_name: open_clip architecture name. It also selects the tokenizer.
        pretrained: Vision-encoder checkpoint path or mapping, or a string
            open_clip accepts as whole-model ``pretrained``.
        cache_dir: Cache directory forwarded to open_clip.
        device: Device the returned model is moved to. The model is always
            built on CPU first.

    Returns:
        ``(model, transform, tokenizer)``. ``transform`` is the third value
        returned by ``create_model_and_transforms``, presumably the
        evaluation preprocessing.

    Raises:
        Exception: The original loading error when ``pretrained`` is not a
            string. A non-string checkpoint has no open_clip fallback.
    """
    try:
        # Everything outside model.visual keeps these OpenAI weights; only the
        # vision tower is replaced below.
        model, _, transform = open_clip.create_model_and_transforms(
            model_name, pretrained="openai", cache_dir=cache_dir, device="cpu"
        )
        if isinstance(pretrained, str):
            # NOTE(review): torch.load unpickles the file, so only pass trusted
            # checkpoint paths here.
            checkpoint = torch.load(pretrained, map_location=torch.device("cpu"))
        else:
            checkpoint = pretrained
        if "vision_encoder_state_dict" in checkpoint:  # TeCoA checkpoint
            model.visual.load_state_dict(checkpoint["vision_encoder_state_dict"])
        else:
            model.visual.load_state_dict(checkpoint)
    except Exception as e:
        if not isinstance(pretrained, str):
            # Handing a non-string checkpoint to open_clip as `pretrained`
            # cannot succeed; retrying would only hide the real error.
            raise
        # Try loading the whole model instead of just the vision tower.
        print(f"error: {e}", file=sys.stderr)
        print("retrying by loading whole model..", file=sys.stderr)
        torch.cuda.empty_cache()
        model, _, transform = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained, cache_dir=cache_dir, device="cpu"
        )
    model = model.to(device)
    tokenizer = open_clip.get_tokenizer(model_name)
    return model, transform, tokenizer