import logging

import torch
import torch.nn.functional as F
from tqdm import tqdm

from open_clip import get_cast_dtype, get_tokenizer, build_zero_shot_classifier, \
    IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
from .precision import get_autocast


def accuracy(output, target, topk=(1,)):
    """Count how many samples have their target among the top-k predictions.

    Args:
        output: 2-D score tensor; top-k is taken along dim 1, so rows are
            samples and columns are classes.
        target: tensor of target class indices, one per row of ``output``.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list of Python floats, one per entry of ``topk``, each holding the
        number (not the fraction) of samples whose target is in the top ``k``.
    """
    # Clamp so that asking for more predictions than there are classes does
    # not make ``topk`` raise; in that case every target is trivially covered.
    max_k = min(max(topk), output.size(1))
    # (max_k, batch): row i holds the i-th best class index for every sample.
    pred = output.topk(max_k, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # ``.item()`` yields a Python float directly, avoiding the deprecated
    # conversion of a 1-element NumPy array to a scalar.
    return [correct[:k].float().sum().item() for k in topk]


def run(model, classifier, dataloader, args):
    """Evaluate top-1 / top-5 accuracy of ``model`` over ``dataloader``.

    Args:
        model: object exposing ``encode_image(images)``.
        classifier: tensor right-multiplied onto the normalized image
            features (presumably ``(embed_dim, num_classes)`` -- confirm
            against ``build_zero_shot_classifier``).
        dataloader: iterable yielding ``(images, target)`` batches.
        args: namespace providing ``precision``, ``device`` and ``batch_size``.

    Returns:
        ``(top1, top5)`` as fractions of the total number of images seen.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples.
    """
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    top1, top5, n = 0., 0., 0.
    with torch.no_grad():
        for images, target in tqdm(dataloader, unit_scale=args.batch_size):
            images = images.to(args.device)
            if cast_dtype is not None:
                images = images.to(dtype=cast_dtype)
            target = target.to(args.device)

            with autocast():
                # predict
                image_features = model.encode_image(images)
                image_features = F.normalize(image_features, dim=-1)
                logits = 100. * image_features @ classifier

            # measure accuracy
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

    top1 = (top1 / n)
    top5 = (top5 / n)
    return top1, top5


def zero_shot_eval(model, data, epoch, args):
    """Run zero-shot ImageNet evaluation when it is due for this epoch.

    Evaluation is skipped (an empty dict is returned) when neither
    ``'imagenet-val'`` nor ``'imagenet-v2'`` is present in ``data``, when
    ``args.zeroshot_frequency`` is 0, or when ``epoch`` is neither a multiple
    of ``args.zeroshot_frequency`` nor equal to ``args.epochs``.

    Args:
        model: the model to evaluate; when ``args.distributed`` is set and
            ``args.horovod`` is not, ``model.module`` is used instead.
        data: mapping from dataset name to an object with a ``dataloader``
            attribute.
        epoch: current epoch number.
        args: namespace providing ``zeroshot_frequency``, ``epochs``,
            ``distributed``, ``horovod``, ``precision``, ``model``,
            ``device`` and ``batch_size``.

    Returns:
        Dict mapping metric names to top-1 / top-5 accuracies for each
        evaluated dataset, or ``{}`` when evaluation is skipped.
    """
    if 'imagenet-val' not in data and 'imagenet-v2' not in data:
        return {}
    if args.zeroshot_frequency == 0:
        return {}
    if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
        return {}
    if args.distributed and not args.horovod:
        # Use the wrapped module (presumably a DistributedDataParallel
        # wrapper -- confirm against the training setup).
        model = model.module

    logging.info('Starting zero-shot imagenet.')

    logging.info('Building zero-shot classifier')
    autocast = get_autocast(args.precision)
    with autocast():
        tokenizer = get_tokenizer(args.model)
        classifier = build_zero_shot_classifier(
            model,
            tokenizer=tokenizer,
            classnames=IMAGENET_CLASSNAMES,
            templates=OPENAI_IMAGENET_TEMPLATES,
            num_classes_per_batch=10,
            device=args.device,
            use_tqdm=True,
        )

    logging.info('Using classifier')
    results = {}
    if 'imagenet-val' in data:
        top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
        results['imagenet-zeroshot-val-top1'] = top1
        results['imagenet-zeroshot-val-top5'] = top5
    if 'imagenet-v2' in data:
        top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args)
        results['imagenetv2-zeroshot-val-top1'] = top1
        results['imagenetv2-zeroshot-val-top5'] = top5

    logging.info('Finished zero-shot imagenet.')

    return results