|
|
|
|
|
import json |
|
import logging |
|
import math |
|
import os |
|
import time |
|
from contextlib import suppress |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
import collections |
|
from collections import defaultdict |
|
|
|
try: |
|
import wandb |
|
except ImportError: |
|
wandb = None |
|
|
|
from open_clip import ClipLoss, get_mean_std |
|
from .distributed import is_master, world_info_from_env |
|
from .zero_shot import zero_shot_eval |
|
|
|
|
|
def save_checkpoint(model, optimizer, scaler, epoch, i, args):
    """Write the "latest" training checkpoint into ``args.checkpoint_path``.

    Nothing is written unless both ``args.save_logs`` and
    ``args.save_most_recent`` are truthy.

    Args:
        model: Object exposing ``state_dict()``; stored under ``"state_dict"``.
        optimizer: Object exposing ``state_dict()``; stored under ``"optimizer"``.
        scaler: Optional gradient scaler. When not ``None`` its ``state_dict()``
            is stored under ``"scaler"``.
        epoch: Value stored under ``"epoch"``.
        i: Value stored under ``"epoch_step"`` (callers in this file pass the
            batch index within the epoch).
        args: Namespace providing ``name``, ``save_logs``, ``save_most_recent``
            and ``checkpoint_path``.
    """
    # Bail out before collecting any state when nothing would be written.
    if not (args.save_logs and args.save_most_recent):
        return

    checkpoint_dict = {
        "epoch": epoch,
        "epoch_step": i,
        "name": args.name,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    if scaler is not None:
        checkpoint_dict["scaler"] = scaler.state_dict()

    latest_path = os.path.join(args.checkpoint_path, "epoch_latest.pt")
    tmp_path = latest_path + ".tmp"
    # Serialize to a sibling temp file and swap it in afterwards, so that an
    # interruption during the write never truncates the previous checkpoint.
    torch.save(checkpoint_dict, tmp_path)
    os.replace(tmp_path, latest_path)
|
|
|
|
|
class AverageMeter(object):
    """Keeps the most recent value together with a running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, mean, weighted sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean.

        Args:
            val: The new observation; becomes ``self.val``.
            n: Weight of the observation (added to ``self.count``).
        """
        weighted = val * n
        self.val = val
        self.count += n
        self.sum += weighted
        self.avg = self.sum / self.count
|
|
|
|
|
def unwrap_model(model):
    """Return ``model.module`` when that attribute exists, otherwise ``model``.

    Lets callers reach the underlying network regardless of whether it has
    been placed inside a wrapper that exposes it as ``.module``.
    """
    return getattr(model, 'module', model)
|
|
|
|
|
def to_device(batch, device, args):
    """Move an ``(images, texts)`` batch to ``device``.

    When ``args.inmem`` exists and is truthy the images are additionally
    converted to float32, divided by 255 and standardized with the mean / std
    returned by ``get_mean_std(args)``.

    Args:
        batch: Two-element iterable ``(images, texts)`` of tensors.
        device: Target device for both tensors.
        args: Namespace; only the optional ``inmem`` flag is read here (and
            the namespace is forwarded to ``get_mean_std``).

    Returns:
        Tuple ``(images, texts)`` located on ``device``.
    """
    pixel_batch, token_batch = batch
    pixel_batch = pixel_batch.to(device=device, non_blocking=True)
    token_batch = token_batch.to(device=device, non_blocking=True)

    if not getattr(args, "inmem", False):
        return pixel_batch, token_batch

    # Presumably the in-memory pipeline yields raw 0..255 pixel values -- TODO confirm.
    pixel_batch = pixel_batch.to(torch.float32).div_(255.)
    mean, std = get_mean_std(args)

    def per_channel(stats):
        # Shape (1, C, 1, 1) so the statistics broadcast along dim 1 of the images.
        return torch.as_tensor(stats, device=pixel_batch.device)[None, :, None, None]

    pixel_batch.sub_(per_channel(mean)).div_(per_channel(std))
    return pixel_batch, token_batch
|
|
|
|
|
def train_one_epoch_ex(model, data, epoch, epoch_step, optimizer, scaler, scheduler, args, tb_writer=None):
    """Train for one epoch with mid-epoch resume, step-based checkpoints and evaluation.

    Compared with ``train_one_epoch`` this variant
      * picks the loss class by name from ``open_clip.loss`` (``args.loss``,
        falling back to ``"ClipLoss"`` when absent or empty),
      * starts at batch index ``epoch_step`` instead of 0,
      * can keep one dataloader iterator alive across epochs (``args.one_iter``),
      * terminates the process when the loss is not finite,
      * saves / evaluates every ``args.save_steps`` / ``args.eval_steps``
        global steps when those options are set to a positive value.

    Args:
        model: Callable returning ``(image_features, text_features, logit_scale)``.
        data: Mapping with a ``'train'`` entry exposing ``set_epoch`` and
            ``dataloader`` (which provides ``num_batches`` and ``num_samples``).
        epoch: Index of the current epoch.
        epoch_step: First batch index to run; lower indices are skipped
            without pulling a batch from the iterator.
        optimizer: Optimizer (Horovod-wrapped when ``args.horovod`` is set).
        scaler: Optional AMP gradient scaler, or ``None``.
        scheduler: Callable invoked with the global step before each batch.
        args: Run configuration namespace.
        tb_writer: Optional writer exposing ``add_scalar``.
    """
    import sys

    # Aliased so the module object is not shadowed by the loss instance created below.
    from open_clip import loss as loss_module

    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress

    model.train()

    loss_name = getattr(args, "loss", None) or "ClipLoss"
    loss_fn = getattr(loss_module, loss_name)(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod)

    data['train'].set_epoch(epoch)
    dataloader = data['train'].dataloader
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    # A missing, None or 0 interval disables the corresponding feature.
    save_steps = getattr(args, "save_steps", None)
    eval_steps = getattr(args, "eval_steps", None)

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    if getattr(args, "one_iter", False) is True:
        # Reuse a single iterator across epochs; it is cached on the train data object.
        # NOTE(review): next() below raises StopIteration if that iterator is finite
        # and runs dry -- confirm the dataloader repeats indefinitely in this mode.
        if not hasattr(data['train'], "dataloader_iter"):
            print(f"running dataloader across epochs ({args.train_num_samples} examples per epoch).")
            data['train'].dataloader_iter = iter(dataloader)
        batch_iter = data['train'].dataloader_iter
    else:
        batch_iter = iter(dataloader)

    # Resuming: indices below epoch_step are skipped without consuming batches.
    for i in range(max(epoch_step, 0), num_batches_per_epoch):
        batch = next(batch_iter)
        step = num_batches_per_epoch * epoch + i
        scheduler(step)

        images, texts = to_device(batch, device, args)

        data_time_m.update(time.time() - end)
        optimizer.zero_grad()

        with autocast():
            image_features, text_features, logit_scale = model(images, texts)
            total_loss = loss_fn(image_features, text_features, logit_scale)

        if not torch.isfinite(total_loss).all():
            # A NaN/Inf loss would poison the weights; stop the run instead of stepping.
            logging.error(f"Loss is {total_loss}, aborting training.")
            sys.exit(1)

        if scaler is not None:
            scaler.scale(total_loss).backward()
            if args.horovod:
                optimizer.synchronize()
                scaler.unscale_(optimizer)
                if args.norm_gradient_clip is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
                with optimizer.skip_synchronize():
                    scaler.step(optimizer)
            else:
                if args.norm_gradient_clip is not None:
                    # Gradients must be unscaled before their norm is clipped.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
                scaler.step(optimizer)
            scaler.update()
        else:
            total_loss.backward()
            if args.norm_gradient_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
            optimizer.step()

        # Keep the learned logit scale within [0, ln(100)].
        with torch.no_grad():
            unwrap_model(model).logit_scale.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1
        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            batch_size = len(images)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # The loss meter is only fed on logging steps, on the master process.
            loss_m.update(total_loss.item(), batch_size)
            logit_scale_scalar = logit_scale.item()
            samples_per_second = args.batch_size*args.world_size / batch_time_m.val
            logging.info(
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s "
                f"LR: {optimizer.param_groups[0]['lr']:5f} "
                f"Logit Scale: {logit_scale_scalar:.3f}"
            )

            # Latest (non-averaged) values are sent to the external loggers.
            log_data = {
                "loss": loss_m.val,
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "samples_per_scond": samples_per_second,
                "scale": logit_scale_scalar,
                "lr": optimizer.param_groups[0]["lr"]
            }
            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, 'Please install wandb.'
                    wandb.log({name: val, 'step': step})

            # Timing meters restart for every logging window.
            batch_time_m.reset()
            data_time_m.reset()

        checkpoint_saved = False
        if save_steps and (step + 1) % save_steps == 0:
            save_checkpoint(model, optimizer, scaler, epoch, i, args)
            checkpoint_saved = True

        if eval_steps and (step + 1) % eval_steps == 0:
            if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
                evaluate_ex(model, data, step, args, tb_writer)
                # Skip the second write when this step was already checkpointed above.
                if not checkpoint_saved:
                    save_checkpoint(model, optimizer, scaler, epoch, i, args)
                # evaluate_ex switches the model to eval mode; restore training mode.
                model.train()
|
|
|
|
|
|
|
def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
    """Run one full pass over ``data['train']`` optimizing the model with ``ClipLoss``.

    Args:
        model: Callable returning ``(image_features, text_features, logit_scale)``.
        data: Mapping with a ``'train'`` entry exposing ``set_epoch`` and
            ``dataloader`` (which provides ``num_batches`` and ``num_samples``).
        epoch: Index of the current epoch.
        optimizer: Optimizer (Horovod-wrapped when ``args.horovod`` is set).
        scaler: Optional AMP gradient scaler, or ``None``.
        scheduler: Callable invoked with the global step before each batch.
        args: Run configuration namespace.
        tb_writer: Optional writer exposing ``add_scalar``.
    """
    device = torch.device(args.device)
    use_amp = args.precision == 'amp'
    autocast = torch.cuda.amp.autocast if use_amp else suppress

    model.train()
    criterion = ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
    )

    train_data = data['train']
    train_data.set_epoch(epoch)
    dataloader = train_data.dataloader
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    loss_m, batch_time_m, data_time_m = AverageMeter(), AverageMeter(), AverageMeter()
    tick = time.time()
    for batch_idx, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + batch_idx
        scheduler(step)

        images, texts = to_device(batch, device, args)

        data_time_m.update(time.time() - tick)
        optimizer.zero_grad()

        with autocast():
            image_features, text_features, logit_scale = model(images, texts)
            total_loss = criterion(image_features, text_features, logit_scale)

        if scaler is None:
            # Plain full-precision path.
            total_loss.backward()
            if args.norm_gradient_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
            optimizer.step()
        else:
            scaler.scale(total_loss).backward()
            if args.horovod:
                optimizer.synchronize()
                scaler.unscale_(optimizer)
                if args.norm_gradient_clip is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
                with optimizer.skip_synchronize():
                    scaler.step(optimizer)
            else:
                if args.norm_gradient_clip is not None:
                    # Clipping operates on unscaled gradients.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
                scaler.step(optimizer)
            scaler.update()

        # Keep the learned logit scale within [0, ln(100)].
        with torch.no_grad():
            unwrap_model(model).logit_scale.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - tick)
        tick = time.time()
        batch_count = batch_idx + 1

        if not is_master(args):
            continue
        if batch_idx % 100 != 0 and batch_count != num_batches_per_epoch:
            continue

        batch_size = len(images)
        num_samples = batch_count * batch_size * args.world_size
        samples_per_epoch = dataloader.num_samples
        percent_complete = 100.0 * batch_count / num_batches_per_epoch

        # The loss meter is only fed on logging steps, on the master process.
        loss_m.update(total_loss.item(), batch_size)
        logit_scale_scalar = logit_scale.item()
        throughput = args.batch_size*args.world_size / batch_time_m.val
        current_lr = optimizer.param_groups[0]['lr']
        logging.info(
            f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
            f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
            f"Data (t): {data_time_m.avg:.3f} "
            f"Batch (t): {batch_time_m.avg:.3f}, {throughput:#g}/s "
            f"LR: {current_lr:5f} "
            f"Logit Scale: {logit_scale_scalar:.3f}"
        )

        # Latest (non-averaged) values are sent to the external loggers.
        scalars = {
            "train/loss": loss_m.val,
            "train/data_time": data_time_m.val,
            "train/batch_time": batch_time_m.val,
            "train/samples_per_scond": throughput,
            "train/scale": logit_scale_scalar,
            "train/lr": current_lr,
        }
        for tag, value in scalars.items():
            if tb_writer is not None:
                tb_writer.add_scalar(tag, value, step)
            if args.wandb:
                assert wandb is not None, 'Please install wandb.'
                wandb.log({tag: value, 'step': step})

        # Timing meters restart for every logging window.
        batch_time_m.reset()
        data_time_m.reset()
|
|
|
|
|
|
|
|
|
def evaluate_ex(model, data, step, args, tb_writer=None):
    """Evaluate the model in the middle of training, keyed by global ``step``.

    Only the master process evaluates; every other process gets ``{}`` back
    immediately. Zero-shot metrics are collected first, then (when ``'val'``
    is in ``data``) the contrastive validation loss and retrieval metrics.

    Args:
        model: Callable returning ``(image_features, text_features, logit_scale)``.
        data: Mapping of dataset name to an object exposing ``dataloader``.
        step: Global training step used to tag the logged metrics.
        args: Run configuration namespace.
        tb_writer: Optional writer exposing ``add_scalar``.

    Returns:
        Dict of metric name to value; empty when nothing was evaluated.
    """
    metrics = {}
    if not is_master(args):
        return metrics
    device = torch.device(args.device)
    model.eval()

    # The epoch argument is fixed at 0 here.
    # NOTE(review): presumably so zero_shot_eval's own frequency check always passes -- confirm.
    zero_shot_metrics = zero_shot_eval(model, data, 0, args)
    metrics.update(zero_shot_metrics)

    autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress
    if 'val' in data:
        dataloader = data['val'].dataloader
        num_samples = 0
        samples_per_val = dataloader.num_samples

        # Every feature is kept (on CPU) so retrieval metrics can be computed over
        # the whole validation set; memory grows with the size of that set.
        cumulative_loss = 0.0
        all_image_features, all_text_features = [], []
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                images, texts = to_device(batch, device, args)

                with autocast():
                    image_features, text_features, logit_scale = model(images, texts)
                    all_image_features.append(image_features.cpu())
                    all_text_features.append(text_features.cpu())
                    logit_scale = logit_scale.mean()
                    logits_per_image = logit_scale * image_features @ text_features.t()
                    logits_per_text = logits_per_image.t()

                    batch_size = images.shape[0]
                    labels = torch.arange(batch_size, device=device).long()
                    total_loss = (
                        F.cross_entropy(logits_per_image, labels) +
                        F.cross_entropy(logits_per_text, labels)
                    ) / 2

                cumulative_loss += total_loss * batch_size
                num_samples += batch_size
                if is_master(args) and (i % 100) == 0:
                    logging.info(
                        f"Eval Step: {step} [{num_samples} / {samples_per_val}]\t"
                        f"Loss: {cumulative_loss / num_samples:.6f}\t")

            if num_samples == 0:
                # Without any batch there is no logit_scale and nothing to concatenate.
                logging.warning("Validation dataloader produced no samples; skipping validation metrics.")
            else:
                val_metrics = get_metrics(
                    image_features=torch.cat(all_image_features),
                    text_features=torch.cat(all_text_features),
                    logit_scale=logit_scale.cpu(),
                )
                loss = cumulative_loss / num_samples
                metrics.update(
                    {**val_metrics, "val_loss": loss.item(), "step": step, "num_samples": num_samples}
                )

    if not metrics:
        return metrics

    logging.info(
        f"Eval Step: {step} "
        + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
    )

    if args.save_logs:
        for name, val in metrics.items():
            if tb_writer is not None:
                tb_writer.add_scalar(f"val_step/{name}", val, step)

        # One JSON object per evaluation is appended to results.jsonl.
        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        for name, val in metrics.items():
            wandb.log({f"val_step/{name}": val, 'step': step})

    return metrics
|
|
|
|
|
def evaluate(model, data, epoch, args, tb_writer=None):
    """Evaluate the model at the end of an epoch.

    Only the master process evaluates; every other process gets ``{}`` back
    immediately. Zero-shot metrics are collected first. The validation pass
    runs when ``'val'`` is in ``data``, ``args.val_frequency`` is truthy and
    ``epoch`` is either a multiple of it or equal to ``args.epochs``.

    Args:
        model: Callable returning ``(image_features, text_features, logit_scale)``.
        data: Mapping of dataset name to an object exposing ``dataloader``.
        epoch: Epoch number used for scheduling and for tagging the metrics.
        args: Run configuration namespace.
        tb_writer: Optional writer exposing ``add_scalar``.

    Returns:
        Dict of metric name to value; empty when nothing was evaluated.
    """
    metrics = {}
    if not is_master(args):
        return metrics
    device = torch.device(args.device)
    model.eval()

    zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
    metrics.update(zero_shot_metrics)

    autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress
    if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
        dataloader = data['val'].dataloader
        num_samples = 0
        samples_per_val = dataloader.num_samples

        # Every feature is kept (on CPU) so retrieval metrics can be computed over
        # the whole validation set; memory grows with the size of that set.
        cumulative_loss = 0.0
        all_image_features, all_text_features = [], []
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                images, texts = to_device(batch, device, args)

                with autocast():
                    image_features, text_features, logit_scale = model(images, texts)
                    all_image_features.append(image_features.cpu())
                    all_text_features.append(text_features.cpu())
                    logit_scale = logit_scale.mean()
                    logits_per_image = logit_scale * image_features @ text_features.t()
                    logits_per_text = logits_per_image.t()

                    batch_size = images.shape[0]
                    labels = torch.arange(batch_size, device=device).long()
                    total_loss = (
                        F.cross_entropy(logits_per_image, labels) +
                        F.cross_entropy(logits_per_text, labels)
                    ) / 2

                cumulative_loss += total_loss * batch_size
                num_samples += batch_size
                if is_master(args) and (i % 100) == 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
                        f"Loss: {cumulative_loss / num_samples:.6f}\t")

            if num_samples == 0:
                # Without any batch there is no logit_scale and nothing to concatenate.
                logging.warning("Validation dataloader produced no samples; skipping validation metrics.")
            else:
                val_metrics = get_metrics(
                    image_features=torch.cat(all_image_features),
                    text_features=torch.cat(all_text_features),
                    logit_scale=logit_scale.cpu(),
                )
                loss = cumulative_loss / num_samples
                metrics.update(
                    {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples}
                )

    if not metrics:
        return metrics

    logging.info(
        f"Eval Epoch: {epoch} "
        + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
    )

    if args.save_logs:
        for name, val in metrics.items():
            if tb_writer is not None:
                tb_writer.add_scalar(f"val/{name}", val, epoch)

        # One JSON object per evaluation is appended to results.jsonl.
        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        for name, val in metrics.items():
            wandb.log({f"val/{name}": val, 'epoch': epoch})

    return metrics
|
|
|
|
|
def get_metrics(image_features, text_features, logit_scale):
    """Compute retrieval metrics for paired image / text features.

    Row ``i`` of ``image_features`` is treated as the match of row ``i`` of
    ``text_features``. For both directions (``image_to_text`` and
    ``text_to_image``) the position of the true match in the
    similarity-sorted candidates is used to derive the 1-based mean rank,
    1-based median rank and recall at 1, 5 and 10.

    Args:
        image_features: 2-D tensor with one row per image.
        text_features: 2-D tensor with one row per text, aligned with the images.
        logit_scale: Scalar (or broadcastable tensor) multiplied into the
            image features before the similarity product.

    Returns:
        Dict mapping ``"<direction>_<metric>"`` to a numeric value.
    """
    similarity = ((logit_scale * image_features) @ text_features.t()).detach().cpu()
    directions = (
        ("image_to_text", similarity),
        ("text_to_image", similarity.t().detach().cpu()),
    )
    # Column vector of true indices; broadcasts against each row of the ranking.
    target = torch.arange(len(text_features)).view(-1, 1)

    metrics = {}
    for direction, scores in directions:
        order = torch.argsort(scores, descending=True)
        # Column index at which each row's true match appears (0 = best).
        positions = torch.where(order == target)[1].detach().cpu().numpy()
        metrics[f"{direction}_mean_rank"] = positions.mean() + 1
        metrics[f"{direction}_median_rank"] = np.floor(np.median(positions)) + 1
        for k in (1, 5, 10):
            metrics[f"{direction}_R@{k}"] = np.mean(positions < k)
    return metrics
|
|