import logging |
from contextlib import suppress |
import torch |
import torch.nn.functional as F |
from tqdm import tqdm |
from open_clip import tokenize |
from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template |
def zero_shot_classifier(model, classnames, templates, args):
    """Build a zero-shot classification weight matrix from text prompts.

    For every class name, each callable in ``templates`` is applied to the
    name to produce a prompt. The prompts are tokenized, encoded with the
    model's text encoder, L2-normalized, averaged over the templates and
    re-normalized to give one embedding per class.

    Args:
        model: Object exposing ``encode_text``; when ``args.distributed`` is
            truthy and ``args.horovod`` is falsy the encoder is reached
            through ``model.module`` instead.
        classnames: Iterable of class names.
        templates: Iterable of callables mapping a class name to a prompt.
        args: Namespace providing ``device``, ``distributed`` and ``horovod``.

    Returns:
        Tensor on ``args.device`` with the per-class embeddings stacked along
        dim 1 (presumably ``(embed_dim, num_classes)`` -- assumes
        ``encode_text`` returns ``(num_templates, embed_dim)``; TODO confirm).

    NOTE(review): unlike ``run`` there is no dedicated ``fsdp`` branch here;
    confirm that ``model.module.encode_text`` is valid for that engine.
    """
    per_class = []
    with torch.no_grad():
        for name in tqdm(classnames):
            prompts = [make_prompt(name) for make_prompt in templates]
            tokens = tokenize(prompts).to(args.device)

            # Wrapped (non-horovod distributed) models keep the real network
            # under ``.module``; otherwise call the model directly.
            if not args.distributed or args.horovod:
                text_encoder = model
            else:
                text_encoder = model.module
            embeddings = text_encoder.encode_text(tokens)

            # Average the unit-length prompt embeddings, then bring the mean
            # back to unit length.
            class_vector = F.normalize(embeddings, dim=-1).mean(dim=0)
            class_vector /= class_vector.norm()
            per_class.append(class_vector)
        weights = torch.stack(per_class, dim=1).to(args.device)
    return weights
def accuracy(output, target, topk=(1,)):
    """Count the samples whose target is among the top-k predictions.

    Args:
        output: 2-D score tensor with one row per sample; the top-k indices
            are taken along dim 1.
        target: Tensor holding one ground-truth class index per row of
            ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list of Python floats, one per entry of ``topk``, each holding the
        *number* (not the fraction) of samples whose target index appears in
        the ``k`` highest-scoring columns.
    """
    # Clamp k to the number of columns so that asking for e.g. top-5 on a
    # problem with fewer than five classes does not make ``topk`` raise; in
    # that case every sample trivially counts as correct for that k.
    maxk = min(max(topk), output.size(1))
    # pred has shape (maxk, batch): row i holds the (i+1)-th best class index.
    pred = output.topk(maxk, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # ``.item()`` yields a Python float directly, avoiding float() on a
    # 1-element NumPy array (deprecated scalar conversion in NumPy >= 1.25).
    return [correct[:k].reshape(-1).float().sum().item() for k in topk]
def run(model, classifier, dataloader, args):
    """Measure zero-shot top-1 and top-5 accuracy over a dataloader.

    Each batch of images is encoded, L2-normalized and multiplied with
    ``classifier`` (scaled by 100) to obtain logits, which are scored with
    ``accuracy`` for k=1 and k=5.

    Args:
        model: Image encoder. How it is invoked depends on ``args``:
            ``model.encode_image`` directly, ``model.module.encode_image``
            for non-horovod distributed runs, or a full forward call
            ``model(image=..., text=None)`` when ``args.distributed_engine``
            is ``'fsdp'``.
        classifier: Weight matrix right-multiplied onto the image features.
        dataloader: Iterable yielding ``(images, target)`` batches.
        args: Namespace providing ``precision``, ``batch_size``, ``device``,
            ``distributed``, ``horovod`` and (for non-horovod distributed
            runs) ``distributed_engine``.

    Returns:
        Tuple ``(top1, top5)``: correct counts divided by the number of
        images seen.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    # ``suppress()`` with no exception types serves as a no-op context here.
    amp_context = torch.cuda.amp.autocast if args.precision == 'amp' else suppress

    hits_top1 = 0.
    hits_top5 = 0.
    seen = 0.
    with torch.no_grad():
        for images, target in tqdm(dataloader, unit_scale=args.batch_size):
            images = images.to(args.device)
            target = target.to(args.device)

            with amp_context():
                if not args.distributed or args.horovod:
                    features = model.encode_image(images)
                elif args.distributed_engine == 'fsdp':
                    # The forward call returns a 3-tuple; only the first
                    # element (image features) is needed.
                    features, _, _ = model(image=images, text=None)
                else:
                    features = model.module.encode_image(images)
                features = F.normalize(features, dim=-1)
                logits = 100. * features @ classifier

            batch_top1, batch_top5 = accuracy(logits, target, topk=(1, 5))
            hits_top1 += batch_top1
            hits_top5 += batch_top5
            seen += images.size(0)

    return hits_top1 / seen, hits_top5 / seen
def zero_shot_eval(model, data, epoch, args):
    """Run zero-shot ImageNet evaluation when it is due for this epoch.

    Evaluation is skipped (an empty dict is returned) when neither
    ``'imagenet-val'`` nor ``'imagenet-v2'`` is present in ``data``, when
    ``args.zeroshot_frequency`` is 0, or when ``epoch`` is neither a multiple
    of that frequency nor equal to ``args.epochs``.

    Args:
        model: Model passed on to ``zero_shot_classifier`` and ``run``.
        data: Mapping from dataset name to an object with a ``dataloader``
            attribute.
        epoch: Current epoch number.
        args: Namespace providing ``zeroshot_frequency`` and ``epochs``, plus
            whatever ``zero_shot_classifier`` and ``run`` read from it.

    Returns:
        Dict of metric name to value for each evaluated dataset, or ``{}``
        when evaluation is skipped.
    """
    # (dataset key in ``data``, top-1 metric name, top-5 metric name)
    eval_sets = (
        ('imagenet-val', 'imagenet-zeroshot-val-top1', 'imagenet-zeroshot-val-top5'),
        ('imagenet-v2', 'imagenetv2-zeroshot-val-top1', 'imagenetv2-zeroshot-val-top5'),
    )

    if not any(dataset_key in data for dataset_key, _, _ in eval_sets):
        return {}

    frequency = args.zeroshot_frequency
    if frequency == 0:
        return {}
    # Off-schedule epochs are skipped, except the final one, which always runs.
    if epoch % frequency != 0 and epoch != args.epochs:
        return {}

    logging.info('Starting zero-shot imagenet.')

    logging.info('Building zero-shot classifier')
    classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)

    logging.info('Using classifier')
    results = {}
    for dataset_key, top1_name, top5_name in eval_sets:
        if dataset_key not in data:
            continue
        top1, top5 = run(model, classifier, data[dataset_key].dataloader, args)
        results[top1_name] = top1
        results[top5_name] = top5

    logging.info('Finished zero-shot imagenet.')

    return results