import sys
import open_clip
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL

def print_statistics(arr):
    # make sure the input is 1-D
    assert len(arr.shape) == 1
    print(f'[mean] {arr.mean():.4f} [median] {np.median(arr):.4f} [min] {arr.min():.4f} [max] '
          f'{arr.max():.4f} [std] {arr.std():.4f} [n] {len(arr)}\n')

def interpolate_state_dict(m1, beta=0.):
    """Linearly interpolate between the given state dict m1 and the clean
    CLIP ViT-L visual weights m2: (1 - beta) * m1 + beta * m2."""
    m = {}
    try:
        m2 = torch.load("/mnt/nsingh/project_multimodal/models/clip-vit-l-visual.pt", map_location='cpu')
    except FileNotFoundError:
        m2 = torch.load("/data/naman_deep_singh/project_multimodal/clip-vit-l-visual.pt", map_location='cpu')
    for k in m1.keys():
        m[k] = (1 - beta) * m1[k] + beta * m2[k]
    return m
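
# A minimal sanity check of the interpolation rule above on toy tensors
# (hypothetical, not part of the original pipeline): with beta=0 the result
# equals m1, with beta=1 it equals the clean weights m2.
def _interpolation_sanity_check():
    m1 = {'w': torch.zeros(2, 2)}
    m2 = {'w': torch.ones(2, 2)}
    beta = 0.25
    mixed = {k: (1 - beta) * m1[k] + beta * m2[k] for k in m1}
    assert torch.allclose(mixed['w'], torch.full((2, 2), 0.25))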

def load_clip_model(clip_model_name, pretrained, beta=0.):
    try:  # try loading only the visual model
        model, _, image_processor = open_clip.create_model_and_transforms(
            clip_model_name, pretrained='openai', device='cpu'
        )
        if pretrained != 'openai':
            if isinstance(pretrained, str):
                checkpoint = torch.load(pretrained, map_location=torch.device('cpu'))
            else:
                checkpoint = pretrained
            # if beta is non-zero, interpolate between the clean and the pretrained checkpoint
            if beta != 0.:
                print("beta", beta)
                checkpoint = interpolate_state_dict(checkpoint, beta)
            if 'vision_encoder_state_dict' in checkpoint.keys():  # TeCoA checkpoint
                model.visual.load_state_dict(checkpoint['vision_encoder_state_dict'])
            else:
                model.visual.load_state_dict(checkpoint)
    except RuntimeError as e:  # fall back to loading the whole model
        print(f'error: {e}', file=sys.stderr)
        print('retrying by loading whole model..', file=sys.stderr)
        torch.cuda.empty_cache()
        model, _, image_processor = open_clip.create_model_and_transforms(
            clip_model_name, pretrained=pretrained, force_quick_gelu=True, device='cpu'
        )
    model.eval()
    # remove the final Normalize transform by creating a new Compose object;
    # the normalizer is returned separately so it can be applied later
    preprocessor_no_norm = transforms.Compose(image_processor.transforms[:-1])
    normalizer = image_processor.transforms[-1]
    return model, preprocessor_no_norm, normalizer
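
# Sketch of how load_clip_model is meant to be used (the checkpoint path and
# beta value are hypothetical): the returned preprocessor omits normalization,
# so inputs can be perturbed in pixel space and normalized afterwards.
def _load_clip_example():
    model, preprocess_no_norm, normalizer = load_clip_model(
        'ViT-L-14', pretrained='/path/to/robust-vit-l-visual.pt', beta=0.5
    )
    return model, preprocess_no_norm, normalizer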

@torch.no_grad()
def get_text_embeddings(model, dataset, texts):
    # exactly one of `dataset` and `texts` may be given
    assert not (dataset and texts)
    if dataset:
        assert dataset == 'imagenet'
    if dataset == 'imagenet':
        template = 'This is a photo of a {}'
        texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
        text_tokens = open_clip.tokenize(texts)
    elif texts:
        text_tokens = open_clip.tokenize(texts)
    embedding_text_labels_norm = []
    chunk_size = 500
    # encode the prompts in chunks to keep GPU memory bounded;
    # `model` is expected to expose the underlying CLIP model as `model.model`
    for i in range(0, len(text_tokens), chunk_size):
        el = text_tokens[i:i+chunk_size]
        embedding_text_labels_norm.append(
            model.model.encode_text(el.cuda(), normalize=True).detach().cpu()
        )
    # transpose to an (embed_dim, n_classes) matrix
    embedding_text_labels_norm = torch.cat(embedding_text_labels_norm).T
    if dataset == 'imagenet':
        assert (embedding_text_labels_norm.shape == (512, 1000)
                or embedding_text_labels_norm.shape == (768, 1000)), embedding_text_labels_norm.shape
    return embedding_text_labels_norm
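
# Sketch of zero-shot classification with the (embed_dim, n_classes) matrix
# returned above (assumes, as in get_text_embeddings, that `model` wraps the
# CLIP model in `.model`; `images` is a preprocessed, normalized batch):
def _zero_shot_predict(model, images, embedding_text_labels_norm):
    image_features = model.model.encode_image(images.cuda())
    image_features = F.normalize(image_features, dim=-1)
    # cosine similarity: (B, embed_dim) @ (embed_dim, n_classes) -> (B, n_classes)
    logits = image_features @ embedding_text_labels_norm.cuda()
    return logits.argmax(dim=-1)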

@torch.inference_mode()
def compute_accuracy_no_dataloader(model, data, targets, device, batch_size=1000):
    # data, targets: tensors
    # (in parts copied from AutoAttack)
    train_flag = model.training
    model.eval()
    n_batches = int(np.ceil(data.shape[0] / batch_size))
    n_total = 0
    n_correct = 0
    for batch_idx in range(n_batches):
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, data.shape[0])
        data_batch = data[start_idx:end_idx, :].clone().to(device)
        targets_batch = targets[start_idx:end_idx].clone().to(device)
        logits = model(data_batch)
        confs, preds = F.softmax(logits, dim=1).max(dim=1)
        n_total += targets_batch.size(0)
        n_correct += (preds.eq(targets_batch).sum()).item()
    acc = n_correct / n_total
    # print(f'{n_total=}')
    # print(f'{n_correct=}')
    # restore the original train/eval mode
    if train_flag:
        model.train()
    return acc
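
# Toy self-test for compute_accuracy_no_dataloader with a random linear model
# (all names and shapes here are illustrative):
def _accuracy_self_test():
    torch.manual_seed(0)
    toy_model = torch.nn.Linear(8, 3)
    data = torch.randn(32, 8)
    targets = torch.randint(0, 3, (32,))
    acc = compute_accuracy_no_dataloader(toy_model, data, targets, device='cpu', batch_size=10)
    assert 0. <= acc <= 1.
    print(f'toy accuracy: {acc:.4f}')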