# Copyright (c) Meta Platforms, Inc. and affiliates
import json
import logging
import math
import os
import sys
import time
from contextlib import suppress

import numpy as np
import torch
import torch.nn.functional as F

try:
    import wandb
except ImportError:
    wandb = None

from open_clip import ClipLoss, get_mean_std
from .distributed import is_master, world_info_from_env
from .zero_shot import zero_shot_eval


def save_checkpoint(model, optimizer, scaler, epoch, i, args):
    checkpoint_dict = {
        "epoch": epoch,
        "epoch_step": i,  # the inner loop saves the step; args.resume handling in main.py decides whether a checkpoint came from the inner (step) loop or the epoch loop.
        "name": args.name,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    if scaler is not None:
        checkpoint_dict["scaler"] = scaler.state_dict()

    # Saving checkpoints; use eval_steps to trigger a checkpoint save.
    if args.save_logs:  # master only.
        # Per-epoch saving is removed; only `epoch_latest.pt` is kept.
        if args.save_most_recent:
            torch.save(
                checkpoint_dict,
                os.path.join(args.checkpoint_path, "epoch_latest.pt"),
            )


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def unwrap_model(model):
    if hasattr(model, 'module'):
        return model.module
    else:
        return model


def to_device(batch, device, args):
    images, texts = batch
    images = images.to(device=device, non_blocking=True)
    if hasattr(args, "inmem") and args.inmem:
        # uint8 images of shape (b, 3, 224, 224): scale to [0, 1], then normalize in place.
        images = images.to(torch.float32).div_(255.)
        mean, std = get_mean_std(args)
        mean = torch.as_tensor(mean, device=images.device)[None, :, None, None]
        std = torch.as_tensor(std, device=images.device)[None, :, None, None]
        images.sub_(mean).div_(std)
    texts = texts.to(device=device, non_blocking=True)
    return images, texts
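# Usage sketch for `to_device` with in-memory uint8 batches (hypothetical shapes
# and `args`; assumes `get_mean_std(args)` returns per-channel (mean, std) tuples
# as in open_clip):
#
#     images = torch.randint(0, 256, (8, 3, 224, 224), dtype=torch.uint8)
#     texts = torch.randint(0, 49408, (8, 77))
#     images, texts = to_device((images, texts), torch.device(args.device), args)
#     # with args.inmem=True the images come back float32 and mean/std-normalized.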
def train_one_epoch_ex(model, data, epoch, epoch_step, optimizer, scaler, scheduler, args, tb_writer=None):
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress

    model.train()
    from open_clip import loss as loss_module
    if hasattr(args, "loss"):
        loss_cls = getattr(loss_module, args.loss)
    else:
        loss_cls = getattr(loss_module, "ClipLoss")
    loss = loss_cls(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod)

    data['train'].set_epoch(epoch)  # set epoch in a process-safe manner via sampler or shared_epoch
    dataloader = data['train'].dataloader
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    if hasattr(args, "one_iter") and args.one_iter is True:
        # hack for large datasets: keep a single iterator alive across epochs (e.g. one 400M-example pass split into multiple logical epochs).
        if not hasattr(data['train'], "dataloader_iter"):
            print(f"running dataloader across epochs ({args.train_num_samples} examples per epoch).")
            data['train'].dataloader_iter = iter(dataloader)
        batch_iter = data['train'].dataloader_iter
    else:
        batch_iter = iter(dataloader)

    for i in range(num_batches_per_epoch):
        if i < epoch_step:  # skip ahead to the right i when resuming.
            continue
        batch = next(batch_iter)
        step = num_batches_per_epoch * epoch + i
        scheduler(step)

        images, texts = to_device(batch, device, args)
        data_time_m.update(time.time() - end)

        optimizer.zero_grad()

        with autocast():
            image_features, text_features, logit_scale = model(images, texts)
            total_loss = loss(image_features, text_features, logit_scale)

        if torch.isfinite(total_loss).all():
            if scaler is not None:
                scaler.scale(total_loss).backward()
                if args.horovod:
                    optimizer.synchronize()
                    scaler.unscale_(optimizer)
                    if args.norm_gradient_clip is not None:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
                    with optimizer.skip_synchronize():
                        scaler.step(optimizer)
                else:
                    if args.norm_gradient_clip is not None:
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
                    scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                if args.norm_gradient_clip is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
                optimizer.step()

            # Note: we clamp to 4.6052 = ln(100), as in the original paper.
            with torch.no_grad():
                unwrap_model(model).logit_scale.clamp_(0, math.log(100))
        else:
            logging.warning(f"Loss is {total_loss}, skipping backprop.")
            sys.exit(1)  # exit to protect the checkpoint for debugging.

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1
        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            batch_size = len(images)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE: loss is coarsely sampled: master node only, once per log update.
            loss_m.update(total_loss.item(), batch_size)
            logit_scale_scalar = logit_scale.item()
            logging.info(
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f}, {args.batch_size * args.world_size / batch_time_m.val:#g}/s "
                f"LR: {optimizer.param_groups[0]['lr']:5f} "
                f"Logit Scale: {logit_scale_scalar:.3f}"
            )

            # Save train loss etc. Use raw (non-averaged) meter values since loggers do their own smoothing.
            log_data = {
                "loss": loss_m.val,
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "samples_per_second": args.batch_size * args.world_size / batch_time_m.val,
                "scale": logit_scale_scalar,
                "lr": optimizer.param_groups[0]["lr"],
            }
            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, 'Please install wandb.'
                    wandb.log({name: val, 'step': step})

            # Reset batch / data time meters per log window.
            batch_time_m.reset()
            data_time_m.reset()

        if hasattr(args, "save_steps") and (step + 1) % args.save_steps == 0:
            save_checkpoint(model, optimizer, scaler, epoch, i, args)  # TODO: copied from main.py; wrap as a function call.

        if hasattr(args, "eval_steps") and (step + 1) % args.eval_steps == 0:  # TODO (huxu): put eval on master only?
            if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
                evaluate_ex(model, data, step, args, tb_writer)  # vs. main.py's evaluate() call: completed_epoch -> step, writer -> tb_writer.
                save_checkpoint(model, optimizer, scaler, epoch, i, args)
                model.train()  # evaluation does not switch the model back to train mode.
    # end for
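# Resume sketch for `train_one_epoch_ex` (hypothetical values; the real wiring
# lives in main.py): a checkpoint saved mid-epoch stores "epoch" and "epoch_step",
# and passing that step back in skips the batches already consumed (main.py
# decides the exact offset):
#
#     ckpt = torch.load(os.path.join(args.checkpoint_path, "epoch_latest.pt"), map_location="cpu")
#     train_one_epoch_ex(model, data, ckpt["epoch"], ckpt["epoch_step"],
#                        optimizer, scaler, scheduler, args, tb_writer)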
def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress

    model.train()
    loss = ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod)

    data['train'].set_epoch(epoch)  # set epoch in a process-safe manner via sampler or shared_epoch
    dataloader = data['train'].dataloader
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()
    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i
        scheduler(step)

        images, texts = to_device(batch, device, args)
        data_time_m.update(time.time() - end)

        optimizer.zero_grad()

        with autocast():
            image_features, text_features, logit_scale = model(images, texts)
            total_loss = loss(image_features, text_features, logit_scale)

        if scaler is not None:
            scaler.scale(total_loss).backward()
            if args.horovod:
                optimizer.synchronize()
                scaler.unscale_(optimizer)
                if args.norm_gradient_clip is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
                with optimizer.skip_synchronize():
                    scaler.step(optimizer)
            else:
                if args.norm_gradient_clip is not None:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
                scaler.step(optimizer)
            scaler.update()
        else:
            total_loss.backward()
            if args.norm_gradient_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
            optimizer.step()

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            unwrap_model(model).logit_scale.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1
        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            batch_size = len(images)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE: loss is coarsely sampled: master node only, once per log update.
            loss_m.update(total_loss.item(), batch_size)
            logit_scale_scalar = logit_scale.item()
            logging.info(
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f}, {args.batch_size * args.world_size / batch_time_m.val:#g}/s "
                f"LR: {optimizer.param_groups[0]['lr']:5f} "
                f"Logit Scale: {logit_scale_scalar:.3f}"
            )

            # Save train loss etc. Use raw (non-averaged) meter values since loggers do their own smoothing.
            log_data = {
                "loss": loss_m.val,
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "samples_per_second": args.batch_size * args.world_size / batch_time_m.val,
                "scale": logit_scale_scalar,
                "lr": optimizer.param_groups[0]["lr"],
            }
            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, 'Please install wandb.'
                    wandb.log({name: val, 'step': step})

            # Reset batch / data time meters per log window.
            batch_time_m.reset()
            data_time_m.reset()
    # end for
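# Typical driver sketch (hypothetical; the actual loop lives in main.py): one call
# to train_one_epoch per epoch, followed by an epoch-level evaluate():
#
#     for epoch in range(start_epoch, args.epochs):
#         train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer)
#         completed_epoch = epoch + 1
#         if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
#             evaluate(model, data, completed_epoch, args, tb_writer)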
# huxu: called from inside train_one_epoch_ex for step-based evaluation.
def evaluate_ex(model, data, step, args, tb_writer=None):
    metrics = {}
    if not is_master(args):
        return metrics
    device = torch.device(args.device)
    model.eval()

    zero_shot_metrics = zero_shot_eval(model, data, 0, args)  # huxu: epoch=0 as a trick to bypass the epoch-frequency check.
    metrics.update(zero_shot_metrics)

    autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress
    if 'val' in data:  # and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):  # huxu: run val whenever called.
        dataloader = data['val'].dataloader
        num_samples = 0
        samples_per_val = dataloader.num_samples

        # FIXME this does not scale past small eval datasets
        # all_image_features @ all_text_features will blow up memory and compute very quickly
        cumulative_loss = 0.0
        all_image_features, all_text_features = [], []
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                images, texts = to_device(batch, device, args)

                with autocast():
                    image_features, text_features, logit_scale = model(images, texts)
                    # features are accumulated in CPU tensors, otherwise GPU memory is exhausted quickly
                    # however, system RAM is easily exceeded and compute time becomes problematic
                    all_image_features.append(image_features.cpu())
                    all_text_features.append(text_features.cpu())
                    logit_scale = logit_scale.mean()
                    logits_per_image = logit_scale * image_features @ text_features.t()
                    logits_per_text = logits_per_image.t()

                    batch_size = images.shape[0]
                    labels = torch.arange(batch_size, device=device).long()
                    total_loss = (
                        F.cross_entropy(logits_per_image, labels) +
                        F.cross_entropy(logits_per_text, labels)
                    ) / 2

                cumulative_loss += total_loss * batch_size
                num_samples += batch_size
                if is_master(args) and (i % 100) == 0:
                    logging.info(
                        f"Eval Step: {step} [{num_samples} / {samples_per_val}]\t"
                        f"Loss: {cumulative_loss / num_samples:.6f}\t")

            val_metrics = get_metrics(
                image_features=torch.cat(all_image_features),
                text_features=torch.cat(all_text_features),
                logit_scale=logit_scale.cpu(),
            )
            loss = cumulative_loss / num_samples
            metrics.update(
                {**val_metrics, "val_loss": loss.item(), "step": step, "num_samples": num_samples}
            )

    if not metrics:
        return metrics

    logging.info(
        f"Eval Step: {step} "
        + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
    )

    if args.save_logs:
        for name, val in metrics.items():
            if tb_writer is not None:
                tb_writer.add_scalar(f"val_step/{name}", val, step)

        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        for name, val in metrics.items():
            wandb.log({f"val_step/{name}": val, 'step': step})

    return metrics
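# Note on the validation loss above (and in `evaluate` below): it is the symmetric
# in-batch contrastive objective. With B matched (image, text) pairs and
# logits = logit_scale * image_features @ text_features.t(), the target for row k
# is k in both directions, i.e.
#     loss = (CE(logits, arange(B)) + CE(logits.t(), arange(B))) / 2
# Unlike the training ClipLoss, no features are gathered across workers here.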
def evaluate(model, data, epoch, args, tb_writer=None):
    metrics = {}
    if not is_master(args):
        return metrics
    device = torch.device(args.device)
    model.eval()

    zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
    metrics.update(zero_shot_metrics)

    autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress
    if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
        dataloader = data['val'].dataloader
        num_samples = 0
        samples_per_val = dataloader.num_samples

        # FIXME this does not scale past small eval datasets
        # all_image_features @ all_text_features will blow up memory and compute very quickly
        cumulative_loss = 0.0
        all_image_features, all_text_features = [], []
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                images, texts = to_device(batch, device, args)

                with autocast():
                    image_features, text_features, logit_scale = model(images, texts)
                    # features are accumulated in CPU tensors, otherwise GPU memory is exhausted quickly
                    # however, system RAM is easily exceeded and compute time becomes problematic
                    all_image_features.append(image_features.cpu())
                    all_text_features.append(text_features.cpu())
                    logit_scale = logit_scale.mean()
                    logits_per_image = logit_scale * image_features @ text_features.t()
                    logits_per_text = logits_per_image.t()

                    batch_size = images.shape[0]
                    labels = torch.arange(batch_size, device=device).long()
                    total_loss = (
                        F.cross_entropy(logits_per_image, labels) +
                        F.cross_entropy(logits_per_text, labels)
                    ) / 2

                cumulative_loss += total_loss * batch_size
                num_samples += batch_size
                if is_master(args) and (i % 100) == 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
                        f"Loss: {cumulative_loss / num_samples:.6f}\t")

            val_metrics = get_metrics(
                image_features=torch.cat(all_image_features),
                text_features=torch.cat(all_text_features),
                logit_scale=logit_scale.cpu(),
            )
            loss = cumulative_loss / num_samples
            metrics.update(
                {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples}
            )

    if not metrics:
        return metrics

    logging.info(
        f"Eval Epoch: {epoch} "
        + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
    )

    if args.save_logs:
        for name, val in metrics.items():
            if tb_writer is not None:
                tb_writer.add_scalar(f"val/{name}", val, epoch)

        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        for name, val in metrics.items():
            wandb.log({f"val/{name}": val, 'epoch': epoch})

    return metrics


def get_metrics(image_features, text_features, logit_scale):
    metrics = {}
    logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
    logits_per_text = logits_per_image.t().detach().cpu()
    logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
    ground_truth = torch.arange(len(text_features)).view(-1, 1)

    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.detach().cpu().numpy()
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)

    return metrics
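if __name__ == "__main__":
    # Hypothetical smoke test for `get_metrics` (not part of the training entry
    # point): 8 random unit-normalized feature pairs, logit scale fixed at 100.
    # Expect mean/median rank and R@{1,5,10} keys for both retrieval directions.
    torch.manual_seed(0)
    image_features = F.normalize(torch.randn(8, 512), dim=-1)
    text_features = F.normalize(torch.randn(8, 512), dim=-1)
    print(get_metrics(image_features, text_features, logit_scale=torch.tensor(100.0)))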