import sys

sys.path.append("src")

import logging
import os
import random
from datetime import datetime
from functools import partial

import numpy as np
import torch
from torch import optim
from torch.cuda.amp import GradScaler

try:
    import wandb
except ImportError:
    wandb = None

try:
    import torch.utils.tensorboard as tensorboard
except ImportError:
    tensorboard = None

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

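# NOTE: wandb, tensorboard and horovod are optional dependencies. Each is set
# to None when its import fails; the code below asserts on the ones a given
# run actually needs (e.g. when args.report_to includes 'wandb').
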
from open_clip import create_model_and_transforms, trace_model, get_mean_std
from training.data import get_data
from training.distributed import is_master, init_distributed_device, world_info_from_env
from training.logger import setup_logging
from training.params import parse_args
from training.scheduler import cosine_lr
from training.train import train_one_epoch, evaluate
from training import train


def save_checkpoint(model, optimizer, scaler, completed_epoch, args):
    checkpoint_dict = {
        "epoch": completed_epoch,
        "name": args.name,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    if scaler is not None:
        checkpoint_dict["scaler"] = scaler.state_dict()

    if args.save_logs:
        if completed_epoch == args.epochs or (
            args.save_frequency > 0 and completed_epoch % args.save_frequency == 0
        ):
            torch.save(
                checkpoint_dict,
                os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
            )
        if args.save_most_recent:
            torch.save(
                checkpoint_dict,
                os.path.join(args.checkpoint_path, "epoch_latest.pt"),
            )


def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)


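# main() drives one training run end to end: experiment naming and logging
# setup, distributed initialization, model/optimizer construction, optional
# resume from a checkpoint, then the epoch loop with periodic evaluation and
# checkpointing.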
def main(args=None):
    if args is None:
        args = parse_args()

    # sanitize model name for filesystem / URI use in log and checkpoint paths
    args.model = args.model.replace('/', '-')

    if args.name is None:
        args.name = '-'.join([
            datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
            f"model_{args.model}",
            f"lr_{args.lr}",
            f"b_{args.batch_size}",
            f"j_{args.workers}",
            f"p_{args.precision}",
        ])

    args.distributed = False
    args.local_rank, args.rank, args.world_size = world_info_from_env()

    args.log_path = None
    if is_master(args, local=args.log_local):
        log_base_path = os.path.join(args.logs, args.name)
        os.makedirs(log_base_path, exist_ok=True)
        log_filename = f'out-{args.rank}' if args.log_local else 'out.log'
        args.log_path = os.path.join(log_base_path, log_filename)
        if os.path.exists(args.log_path) and args.resume is None and not hasattr(args, "eval"):
            print(
                f"Error. Experiment {args.name} already exists. Use --name to specify a new experiment."
            )
            return -1

    args.log_level = logging.DEBUG if args.debug else logging.INFO
    setup_logging(args.log_path, args.log_level)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    device = init_distributed_device(args)

    args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
    args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
    if is_master(args):
        args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else ''
        args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
        for dirname in [args.tensorboard_path, args.checkpoint_path]:
            if dirname:
                os.makedirs(dirname, exist_ok=True)
    else:
        args.tensorboard_path = ''
        args.checkpoint_path = ''

    if args.copy_codebase:
        copy_codebase(args)

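    # Precision handling: in this codebase 'amp' is expected to keep fp32
    # weights and rely on the GradScaler created further down, while 'fp16'
    # casts the model weights themselves (hence the warning below).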
    assert args.precision in ['amp', 'fp16', 'fp32']
    if args.precision == 'fp16':
        logging.warning(
            'It is recommended to use AMP mixed-precision instead of FP16. '
            'FP16 support needs further verification and tuning, especially for training.')

    if args.horovod:
        logging.info(
            f'Running in horovod mode with multiple processes / nodes. Device: {args.device}. '
            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
    elif args.distributed:
        logging.info(
            f'Running in distributed mode with multiple processes. Device: {args.device}. '
            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
    else:
        logging.info(f'Running with a single process. Device {args.device}.')

    random_seed(args.seed, 0)
    mean, std = get_mean_std(args)
    model, preprocess_train, preprocess_val = create_model_and_transforms(
        args.model,
        args.pretrained,
        precision=args.precision,
        device=device,
        jit=args.torchscript,
        force_quick_gelu=args.force_quick_gelu,
        pretrained_image=args.pretrained_image,
        mean=mean, std=std,
        inmem=hasattr(args, "inmem"),
        clip_model=args.clip_model,
        text_encoder_name=args.text_encoder_model_name,
    )
    random_seed(args.seed, args.rank)

    if args.trace:
        model = trace_model(model, batch_size=args.batch_size, device=device)

    if args.lock_image:
        model.lock_image_tower(
            unlocked_groups=args.lock_image_unlocked_groups,
            freeze_bn_stats=args.lock_image_freeze_bn_stats)

    if args.grad_checkpointing:
        model.set_grad_checkpointing()

    if is_master(args):
        logging.info("Model:")
        logging.info(f"{str(model)}")
        logging.info("Params:")
        params_file = os.path.join(args.logs, args.name, "params.txt")
        with open(params_file, "w") as f:
            for name in sorted(vars(args)):
                val = getattr(args, name)
                logging.info(f"  {name}: {val}")
                f.write(f"{name}: {val}\n")

    if args.distributed and not args.horovod:
        if args.use_bn_sync:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        if args.distributed_engine == 'ddp':
            ddp_args = {}
            if args.ddp_static_graph:
                # only pass static_graph when requested; older PyTorch versions
                # do not accept the argument
                ddp_args['static_graph'] = True
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)
        else:
            print("--distributed_engine must be 'ddp'.")
            sys.exit(1)

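    # Optimizer grouping (see the lambdas below): parameters with ndim < 2 or
    # whose names contain "bn", "ln", "bias" or "logit_scale" are excluded from
    # weight decay; all remaining parameters decay with args.wd.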
    optimizer = None
    scaler = None
    if args.train_data:
        assert not args.trace, 'Cannot train with traced model'

        exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
        include = lambda n, p: not exclude(n, p)

        named_parameters = list(model.named_parameters())
        gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
        rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

        optimizer = optim.AdamW(
            [
                {"params": gain_or_bias_params, "weight_decay": 0.},
                {"params": rest_params, "weight_decay": args.wd},
            ],
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=args.eps,
        )
        if args.horovod:
            optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
            hvd.broadcast_parameters(model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        if args.precision == "amp":
            scaler = GradScaler()
        else:
            scaler = None

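    # Optionally resume. A full training checkpoint carries 'epoch',
    # 'state_dict', 'optimizer' and possibly 'scaler' (see save_checkpoint
    # above); mid-epoch checkpoints written elsewhere may also carry
    # 'epoch_step'. A bare state_dict is loaded directly into the model.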
    start_epoch = 0
    start_epoch_step = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location='cpu')
            if 'epoch' in checkpoint:
                # resuming a training checkpoint with epoch and optimizer state
                start_epoch = checkpoint["epoch"]
                sd = checkpoint["state_dict"]
                if next(iter(sd.items()))[0].startswith('_orig_mod'):
                    # strip the prefix added by torch.compile
                    sd = {k[len('_orig_mod.'):]: v for k, v in sd.items()}
                if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
                    # strip the DDP 'module.' prefix when loading into a bare model
                    sd = {k[len('module.'):]: v for k, v in sd.items()}
                model.load_state_dict(sd)
                if optimizer is not None:
                    optimizer.load_state_dict(checkpoint["optimizer"])
                if scaler is not None and 'scaler' in checkpoint:
                    scaler.load_state_dict(checkpoint['scaler'])
                if 'epoch_step' in checkpoint:
                    start_epoch_step = checkpoint["epoch_step"] + 1
                    logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch}, step {start_epoch_step})")
                else:
                    start_epoch_step = 0
                    logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
            else:
                # loading a bare (model-only) state_dict
                model.load_state_dict(checkpoint)
                logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
        else:
            logging.info("=> no checkpoint found at '{}'".format(args.resume))

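    # Build the dataloaders, then optionally wrap the model with torch.compile
    # when args.torchcompile is set; compilation requires PyTorch 2.0 or newer.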
    data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch)

    if hasattr(args, "torchcompile") and args.torchcompile:
        logging.info('Compiling model...')
        try:
            model = torch.compile(model)
        except Exception:
            logging.warning("torch.compile failed; PyTorch >= 2.0 is required to compile the model.")

    scheduler = None
    if 'train' in data and optimizer is not None:
        total_steps = data["train"].dataloader.num_batches * args.epochs
        scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)

    args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args)
    writer = None
    if args.save_logs and args.tensorboard:
        assert tensorboard is not None, "Please install tensorboard."
        writer = tensorboard.SummaryWriter(args.tensorboard_path)

    if args.wandb and is_master(args):
        assert wandb is not None, 'Please install wandb.'
        logging.debug('Starting wandb.')
        args.train_sz = data["train"].dataloader.num_samples
        if args.val_data is not None:
            args.val_sz = data["val"].dataloader.num_samples

        wandb.init(
            project="open-clip",
            notes=args.wandb_notes,
            tags=[],
            config=vars(args),
        )
        wandb.define_metric("epoch")
        wandb.define_metric("val/*", step_metric="epoch")
        if args.debug:
            wandb.watch(model, log='all')
        wandb.save(params_file)
        logging.debug('Finished loading wandb.')

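    # Evaluation-only path: with no training data (or args.eval set), run the
    # slip_evaluate benchmarks plus the standard evaluate() pass, then return.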
    if 'train' not in data or (hasattr(args, "eval") and args.eval):
        from training.slip_evaluate import slip_evaluate
        from open_clip import HFTokenizer

        context_length = args.tokenizer_context_length
        tokenizer_kwargs = {}
        tokenize = HFTokenizer(
            args.text_encoder_model_name,
            context_length=context_length,
            **tokenizer_kwargs,
        )

        os.makedirs(args.output_dir, exist_ok=True)
        slip_evaluate(args, model, preprocess_val, tokenize)
        evaluate(model, data, start_epoch, args, writer)
        return

    epoch_step = start_epoch_step

    from training.slip_evaluate import slip_evaluate
    from open_clip import HFTokenizer

    context_length = args.tokenizer_context_length
    tokenizer_kwargs = {}
    tokenize = HFTokenizer(
        args.text_encoder_model_name,
        context_length=context_length,
        **tokenizer_kwargs,
    )

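    # Main training loop. If args.engine names a function in training.train,
    # it is used in place of train_one_epoch and additionally receives the
    # starting step for mid-epoch resumption.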
    for epoch in range(start_epoch, args.epochs):
        if is_master(args):
            logging.info(f'Start epoch {epoch}')
        if hasattr(args, "engine"):
            engine = args.engine
            module = train
            engine_cls = getattr(module, engine)
            engine_cls(model, data, epoch, epoch_step, optimizer, scaler, scheduler, args, writer)
        else:
            train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)

        # only the resumed epoch starts from a non-zero step
        epoch_step = 0

        completed_epoch = epoch + 1

        if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
            evaluate(model, data, completed_epoch, args, writer)

        if (completed_epoch % args.eval_freq) == 0:
            slip_evaluate(args, model, preprocess_val, tokenize, epoch)
        save_checkpoint(model, optimizer, scaler, completed_epoch, args)

    if hasattr(args, "eval") and args.eval and any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
        from training.slip_evaluate import slip_evaluate

        slip_evaluate(args, model, preprocess_val, tokenize)

    if args.wandb and is_master(args):
        wandb.finish()


def copy_codebase(args):
    from shutil import copytree, ignore_patterns
    new_code_path = os.path.join(args.logs, args.name, "code")
    if os.path.exists(new_code_path):
        print(
            f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
        )
        return -1
    print(f"Copying codebase to {new_code_path}")
    current_code_path = os.path.realpath(__file__)
    for _ in range(3):
        current_code_path = os.path.dirname(current_code_path)
    copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))
    print("Done copying code.")
    return 1


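# Command-line entry point. Invocation sketch (argument roles are inferred from
# the code below, not from documented usage):
#   python <this script> <config_name> <exp_name> [<pretrained_path>]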
if __name__ == "__main__":
    sys.path.append("./")
    from configs import search_config

    config = search_config(sys.argv[1])
    exp_name = sys.argv[2]
    if len(sys.argv) == 3:
        # no pretrained path given: resume from a checkpoint under the config's output_dir
        config.resume = os.path.join(config.output_dir, "checkpoints", sys.argv[2])
    else:
        config.pretrained = sys.argv[3]
    config.logs = exp_name
    config.output_dir = os.path.join(config.logs, config.name)
    main(config)