"""Train a slot-attention model (DINOSAURpp) on frozen CLIP patch features.

The script loads an OpenCLIP vision tower and a DINOv2 backbone as frozen
feature extractors, trains ``DINOSAURpp`` to reconstruct the CLIP patch tokens,
and adds a small cosine-similarity term that pulls the slot model's second
output towards the DINOv2 patch tokens.
"""
import sys

from train.datasets import COCOFlickrDataset, ImageNetDataset
from CLIP_eval.eval_utils import load_clip_model

# NOTE: the path must be extended before the imports below are resolved.
sys.path.append("open_flamingo")
import os
import shutil
import time
import string
import random

import numpy as np
import open_clip
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from training.scheduler import cosine_lr
from torchvision import transforms
from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
from train.pgd_train import pgd
from train.apgd_train import apgd_train as apgd
import wandb
from train.utils import init_wandb, AverageMeter
from train.sam_data import SamData
from open_flamingo.eval.models.utils import unwrap_model
from train.utils import str2bool
from torch.hub import load

RANDOM_SEED = 42  # any random number


def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # CPU
    torch.cuda.manual_seed(seed)  # GPU
    torch.cuda.manual_seed_all(seed)  # all GPUs
    # Disable hash randomization.
    # NOTE(review): setting this at runtime only affects interpreters started
    # afterwards (e.g. worker subprocesses), not the running process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Make cuDNN return the same convolution algorithm on every call.
    torch.backends.cudnn.deterministic = True
    # True lets cuDNN auto-tune for the fastest algorithm for the current
    # configuration; False keeps experiment results reproducible.
    torch.backends.cudnn.benchmark = False


set_seed(RANDOM_SEED)

from slots.DINOSAUR import DINOSAURpp
import matplotlib.pyplot as plt
from einops import rearrange, repeat
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--clip_model_name', type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32')
parser.add_argument('--pretrained', type=str, default='openai')
parser.add_argument('--dataset', type=str, default='imagenet')
parser.add_argument('--template', type=str, default='std')
parser.add_argument('--imagenet_root', type=str, default='/mnt/datasets/imagenet',
                    help='Imagenet dataset root directory')
parser.add_argument('--output_normalize', type=str2bool, default=False, help='Whether the embedding is normalized')
parser.add_argument('--start_step', type=int, default=0, help='Start step for training')
parser.add_argument('--optimizer_state', type=str, default='', help='Optimizer state file path')
parser.add_argument('--steps', type=int, default=20000, help='Number of training steps')
parser.add_argument('--warmup', type=int, default=14000, help='Warmup steps')
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--loss', type=str, default='l2', help='ce, l2')
parser.add_argument('--loss_clean', type=str, default='none', help='ce, l2')
parser.add_argument('--clean_weight', type=float, default=0., help='Weight for clean loss')
parser.add_argument('--trades', type=str2bool, default=False, help='Use TRADES')
parser.add_argument('--opt', type=str, default='adamw', help='Optimizer type; sgd, adamw')
parser.add_argument('--momentum_sgd', type=float, default=0.9, help='Momentum for SGD optimizer')
parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
parser.add_argument('--wd', type=float, default=1e-4, help='Weight decay')
parser.add_argument('--attack', type=str, default='apgd', help='Adversarial attack type')
parser.add_argument('--inner_loss', type=str, default='l2', help='Inner loss function for adversarial training')
parser.add_argument('--norm', type=str, default='linf', help='Norm for adversarial perturbation')
parser.add_argument('--eps', type=float, default=4, help='Epsilon for adversarial perturbation')
parser.add_argument('--iterations_adv', type=int, default=10, help='Iterations for adversarial attack')
parser.add_argument('--stepsize_adv', type=float, default=1.,
                    help='Step size for adversarial attack (no effect for apgd)')
parser.add_argument('--wandb', type=str2bool, default=True, help='Use Weights & Biases for logging')
parser.add_argument('--experiment_name', type=str, default='')
parser.add_argument('--overwrite', type=str2bool, default=False, help='Overwrite existing directory')
parser.add_argument('--log_freq', type=int, default=1, help='Logging frequency')
parser.add_argument('--eval_freq', type=int, default=50, help='Evaluation frequency')
parser.add_argument('--output_dir', type=str, default='', help='Output directory')
parser.add_argument('--save_checkpoints', type=str2bool, default=True, help='Save 10 training checkpoints')
parser.add_argument('--devices', type=str, default='', help='Device IDs for CUDA')


def main(args):
    """Build models, data and optimizer, then run the slot training loop.

    Args:
        args: Parsed command-line namespace. Besides the parser's options it
            must carry ``finetuned_model_name`` and the final ``output_dir``
            (both are set in the ``__main__`` section).

    Raises:
        ValueError: For an unknown dataset, prompt template, CLIP model name
            or optimizer type, or when the COCO dataset cannot be located.
        FileExistsError: If the checkpoint directory already exists and
            ``--overwrite`` is not set.
    """
    # Computed locally so that main() does not depend on a global that only
    # exists when this file is executed as a script.
    num_gpus = torch.cuda.device_count()

    # setup wandb
    if args.wandb:
        init_wandb(
            project_name='clip-finetune',
            model_name=args.finetuned_model_name,
            config=vars(args)
        )
    else:
        wandb.init(mode='disabled')

    # print args
    print(f"Arguments:\n{'-' * 20}")
    for arg, value in vars(args).items():
        print(f"{arg}: {value}")
    print(f"{'-' * 20}")

    # setup dirs
    if args.overwrite:
        shutil.rmtree(args.output_dir, ignore_errors=True)
    # exist_ok=False: refuse to silently reuse an existing run directory.
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=False)

    # write args to file
    with open(os.path.join(args.output_dir, 'args.txt'), 'w') as f:
        f.write(str(args))

    main_device = 0
    # get models
    from open_clip.model import CLIPVisionCfg
    # Class-level switch so the vision tower also returns patch tokens.
    CLIPVisionCfg.output_tokens = True
    # Optionally pass output_tokens=True here to return token + patches.
    model_orig, _, image_processor = open_clip.create_model_and_transforms(
        args.clip_model_name, pretrained='openai'
    )

    # Remove the Normalize transform by creating a new Compose object
    preprocessor_without_normalize = transforms.Compose(image_processor.transforms[:-1])
    normalize = image_processor.transforms[-1]
    del image_processor
    print(f'[preprocessor_without_normalize] {preprocessor_without_normalize}')

    # ---- slot-attention model ----
    cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False,
                'slot_att_iter': 3, 'query_opt': False}
    model_slots = DINOSAURpp(cfg_dict)

    # get data
    if args.dataset == 'imagenet':
        dataset = ImageNetDataset(
            root=args.imagenet_root + '/train',
            transform=preprocessor_without_normalize,
        )
    elif args.dataset == 'segment_anything':
        dataset = SamData('/data/naman_deep_singh/datasets/newSAM', transform=preprocessor_without_normalize)
        print(dataset.__len__())
    elif args.dataset == 'coco':
        if os.path.exists('/mnt/datasets/coco'):
            image_dir_path = '/mnt/datasets/coco/train2017'
            annotations_path = '/mnt/datasets/coco/annotations/captions_train2017.json'
        elif os.path.exists('/mnt/lustre'):
            image_dir_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/train2017'
            annotations_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/annotations/captions_train2017.json'
        else:
            raise ValueError('COCO dataset not found')
        dataset = COCOFlickrDataset(
            image_dir_path=image_dir_path,
            annotations_path=annotations_path,
            transform=preprocessor_without_normalize
        )
    else:
        # Fail early instead of hitting a NameError on `dataset` below.
        raise ValueError(f'Unknown dataset: {args.dataset}')

    dataset_eval = ImageNetDataset(
        root=args.imagenet_root + '/val',
        transform=preprocessor_without_normalize,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)
    dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=True, num_workers=8,
                                 drop_last=True)

    # Get text label embeddings of all ImageNet classes
    if args.template == 'std':
        template = 'This is a photo of a {}'
    elif args.template == 'blurry':
        template = 'This is a blurry photo of a {}'
    else:
        raise ValueError(f'Unknown template: {args.template}')
    print(f'template: {template}')
    texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
    text_tokens = open_clip.tokenize(texts)
    model_orig.to(main_device)
    with torch.no_grad():
        embedding_text_labels_norm = []
        for el in (text_tokens[:500], text_tokens[500:]):
            # we need to split the text tokens into two batches because otherwise we run out of memory
            # note that we are accessing the model directly here, not the CustomModel wrapper
            # thus its always normalizing the text embeddings
            embedding_text_labels_norm.append(
                model_orig.encode_text(el.to(main_device), normalize=True).detach().cpu()
            )
        embedding_text_labels_norm = torch.cat(embedding_text_labels_norm).T.to(main_device)
        assert torch.allclose(
            F.normalize(embedding_text_labels_norm, dim=0),
            embedding_text_labels_norm
        )
        if args.clip_model_name == 'ViT-B-32':
            assert embedding_text_labels_norm.shape == (512, 1000), embedding_text_labels_norm.shape
        elif args.clip_model_name in ('ViT-L-14', 'ViT-L-14-336'):
            assert embedding_text_labels_norm.shape == (768, 1000), embedding_text_labels_norm.shape
        else:
            raise ValueError(f'Unknown model: {args.clip_model_name}')

    model_orig.cpu()
    # Only the vision tower is kept from here on.
    model_orig = ClipVisionModel(model=model_orig.visual, args=args, normalize=normalize)
    if num_gpus > 1:
        model_orig = torch.nn.DataParallel(model_orig)
    model_orig.cuda()

    if num_gpus > 1:
        model_slots = torch.nn.DataParallel(model_slots)
    model_slots.cuda()

    # ---- DINOv2 as target model ----
    target_backbone_name = 'dinov2_vitb14'
    # Loaded from a local torch.hub checkout; the online equivalent would be
    # load('facebookresearch/dinov2', target_backbone_name).
    model_dinov2 = load('/home/tly/.cache/torch/hub/facebookresearch_dinov2_main', target_backbone_name,
                        source='local')
    proj_head = torch.nn.Linear(1024, 768)
    if num_gpus > 1:
        model_dinov2 = torch.nn.DataParallel(model_dinov2)
        proj_head = torch.nn.DataParallel(proj_head)
    model_dinov2.cuda()
    proj_head.cuda()

    # set optimizer (all params have requires_grad=True)
    params_slots = unwrap_model(model_slots).parameters()
    params_head = unwrap_model(proj_head).parameters()

    if args.opt == 'adamw':
        optimizer = torch.optim.AdamW(
            [{'params': params_slots}, {'params': params_head}], lr=args.lr, weight_decay=args.wd)
    elif args.opt == 'sgd':
        optimizer = torch.optim.SGD(
            [{'params': params_slots}, {'params': params_head}],
            lr=args.lr,
            momentum=args.momentum_sgd,
            weight_decay=args.wd
        )
    else:
        # The CLI flag is `--opt`; `args.optimizer` does not exist.
        raise ValueError(f'Optimizer {args.opt} not supported.')
    if args.optimizer_state != '':
        optimizer.load_state_dict(torch.load(args.optimizer_state))

    # set scheduler
    scheduler = cosine_lr(optimizer, args.lr, args.warmup, args.steps)

    # compute amount of epochs
    total_epochs = args.steps / len(dataloader)
    print(f'train for {total_epochs} epochs')
    args.total_epochs = total_epochs

    # finetune
    step_total = args.start_step
    epoch = 0
    while step_total < args.steps:
        step_total = train_one_epoch_slots(
            step_total,
            model_slots=model_slots,
            model_orig=model_orig,
            model_dinov2=model_dinov2,
            proj_head=proj_head,
            dataloader=dataloader,
            dataloader_eval=dataloader_eval,
            optimizer=optimizer,
            scheduler=scheduler,
            embedding_text_labels_norm=embedding_text_labels_norm,
            normalize=normalize,
            args=args,
            epoch=epoch
        )
        print(f'Epoch {epoch} done.')
        epoch += 1

    # save final model
    torch.save(unwrap_model(model_slots).state_dict(), f'{args.output_dir}/checkpoints/final.pt')
    # Also keep the projection head, matching the periodic checkpoints.
    torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/final_proj.pt')
    torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/final_opt.pt')

    if args.output_dir.endswith('_temp'):
        # rename temp dir to final dir
        os.rename(args.output_dir, args.output_dir[:-5])


class ClipVisionModel(torch.nn.Module):
    """Wrap a CLIP vision tower so that input normalization is part of forward.

    Attributes are set from the constructor arguments: ``model`` (the vision
    tower), ``args`` (run configuration) and ``normalize`` (the transform
    applied to the input before the tower).
    """

    def __init__(self, model, args, normalize):
        super().__init__()
        self.model = model
        self.args = args
        self.normalize = normalize

    def forward(self, vision, output_normalize, return_all_blocks=False):
        """Normalize ``vision``, run the tower, optionally L2-normalize the embedding.

        Args:
            vision: Image batch that has not been normalized yet.
            output_normalize: If true, L2-normalize the embedding on its last dim.
            return_all_blocks: Forwarded unchanged to the wrapped model.

        Returns:
            Tuple ``(embedding, patches)`` as produced by the wrapped model.
        """
        vision = self.normalize(vision)
        embedding, patches = self.model(vision, return_all_blocks=return_all_blocks)
        if output_normalize:
            embedding = F.normalize(embedding, dim=-1)
        return embedding, patches


class ComputeLossWrapper:
    """Callable that binds the fixed arguments of :func:`compute_loss`."""

    def __init__(self, embedding_orig, embedding_text_labels_norm, reduction='mean', loss=None,
                 logit_scale=100.):
        self.embedding_orig = embedding_orig
        self.embedding_text_labels_norm = embedding_text_labels_norm
        self.reduction = reduction
        self.loss_str = loss
        self.logit_scale = logit_scale

    def __call__(self, embedding, targets):
        """Return ``compute_loss`` for ``embedding``/``targets`` with the bound settings."""
        return compute_loss(
            loss_str=self.loss_str,
            embedding=embedding,
            targets=targets,
            embedding_orig=self.embedding_orig,
            logit_scale=self.logit_scale,
            embedding_text_labels_norm=self.embedding_text_labels_norm,
            reduction=self.reduction
        )


def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    return torch.mean(x, dim=list(range(1, len(x.size()))))


def _save_training_state(model_slots, proj_head, optimizer, ckpt_dir, prefix):
    """Write slot model, projection head and optimizer state as ``<prefix>*.pt``."""
    torch.save(unwrap_model(model_slots).state_dict(), f'{ckpt_dir}/{prefix}.pt')
    torch.save(unwrap_model(proj_head).state_dict(), f'{ckpt_dir}/{prefix}_proj.pt')
    torch.save(optimizer.state_dict(), f'{ckpt_dir}/{prefix}_opt.pt')


def _save_debug_images(args, step_total, data, reconstruction, patches_orig, x_dinov2, feat_dinov2_tokens, h, w):
    """Dump first-sample / first-channel visualizations for the current step."""
    save_pics_path = os.path.join(args.output_dir, 'slots_recons')
    recon_pic_save_path = os.path.join(save_pics_path, args.dataset)
    os.makedirs(recon_pic_save_path, exist_ok=True)

    # (file name, 2-D or HWC array) pairs, saved in the original order.
    pictures = [
        ('recon_pic_steps{}.png', reconstruction[0, 0]),
        ('recon_pic_steps{}_feat.png', patches_orig[0, 0]),
        ('recon_pic_steps{}_dinov2.png',
         rearrange(x_dinov2, 'b (h w) c_dinov2 -> b c_dinov2 h w', h=h, w=w)[0, 0]),
        ('recon_pic_steps{}_feat_dinov2.png',
         rearrange(feat_dinov2_tokens, 'b (h w) c_dinov2 -> b c_dinov2 h w', h=h, w=w)[0, 0]),
        ('recon_pic_steps{}_ori.png', data[0].permute(1, 2, 0)),
    ]
    for name_template, tensor in pictures:
        plt.imshow(tensor.detach().cpu().numpy())
        plt.savefig(os.path.join(recon_pic_save_path, name_template.format(step_total)))
    plt.close('all')


def train_one_epoch_slots(
        step_total, model_slots, model_orig, model_dinov2, proj_head, dataloader, optimizer, scheduler, normalize,
        embedding_text_labels_norm, args, epoch, dataloader_eval=None
):
    """Run one pass over ``dataloader``, stopping early once ``args.steps`` is reached.

    Args:
        step_total: Global step counter at the start of the epoch.
        model_slots: Trainable slot model; called on the CLIP patch features.
        model_orig: CLIP vision wrapper, used under ``no_grad`` only.
        model_dinov2: DINOv2 backbone, used under ``no_grad`` only.
        proj_head: Projection head; only checkpointed here, not called.
        dataloader: Yields ``(data, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        scheduler: Callable invoked with the new global step after each update.
        normalize: Unused in this function.
        embedding_text_labels_norm: Unused in this function.
        args: Run configuration (``steps``, ``log_freq``, ``output_dir``, ...).
        epoch: Index of the current epoch (for time estimates/logging).
        dataloader_eval: Unused in this function.

    Returns:
        The updated global step counter.
    """
    model_orig.eval()
    # Frozen feature target; keep it in inference mode like model_orig.
    model_dinov2.eval()
    model_slots.train()

    MSEFunc = torch.nn.MSELoss()
    loss_meter = AverageMeter('loss')
    ckpt_dir = f'{args.output_dir}/checkpoints'
    # Guard against a zero modulus when args.steps < 10.
    checkpoint_interval = max(args.steps // 10, 1)

    epoch_start_time = time.time()
    for i, (data, targets) in enumerate(dataloader):
        is_classification = isinstance(targets, torch.Tensor)
        data = data.cuda()
        n_samples = data.shape[0]
        if is_classification:
            targets = targets.cuda()

        with torch.no_grad():
            embedding_orig, patches_orig = model_orig(vision=data, output_normalize=args.output_normalize,
                                                      return_all_blocks=True)
            feat_dinov2_dict = model_dinov2(data, is_training=True)
            # unwrap_model handles the DataParallel case, so no global GPU
            # count is needed here.
            # NOTE(review): placed under no_grad; the CLIP tower is not part of
            # the optimizer, so no gradient is needed through ln_pre.
            patches_orig = unwrap_model(model_orig).model.ln_pre(patches_orig)
        feat_dinov2_tokens = feat_dinov2_dict['x_norm_patchtokens']

        reconstruction, slots, masks, x_dinov2 = model_slots(patches_orig)  # (B, token, 768)
        x_dinov2 = torch.sum(x_dinov2, dim=1)
        # proj_head is currently not applied to the reconstruction; the DINOv2
        # target is matched through x_dinov2 instead.

        _, hw, _ = patches_orig.shape
        h, w = int(np.sqrt(hw)), int(np.sqrt(hw))
        k = slots.size(1)
        reconstruction = rearrange(reconstruction, 'b (h w) c -> b c h w', h=h, w=w)
        # Not used further; the rearrange still validates the mask layout.
        masks = rearrange(masks, 'b k (h w) -> b k h w', h=h, w=w, k=k)
        patches_orig = rearrange(patches_orig, 'b (h w) c -> b c h w', h=h, w=w)

        # reconstruction loss on the CLIP patch features
        loss = MSEFunc(reconstruction, patches_orig)

        # negative cosine similarity to the DINOv2 patch tokens
        x_dinov2 = torch.nn.functional.normalize(x_dinov2, dim=-1)
        feat_dinov2 = torch.nn.functional.normalize(feat_dinov2_tokens, dim=-1)
        loss_dinov2 = mean_flat(-(x_dinov2 * feat_dinov2).sum(dim=-1)).mean()
        loss_total = loss + 0.0001*loss_dinov2

        loss_total.backward()
        optimizer.step()
        optimizer.zero_grad()
        step_total += 1
        scheduler(step_total)
        # Keep the running average current so 'avg/loss' is meaningful.
        # NOTE(review): assumes AverageMeter.update(val, n) — confirm in train.utils.
        loss_meter.update(loss.item(), n_samples)

        lr_ = optimizer.param_groups[0].get('lr')
        if (step_total-1) % args.log_freq == 0:
            log_str = f'[step] {step_total} [lr] {lr_:.6f} [loss] {loss.item():.6f} [loss_dinov2] {loss_dinov2.item():.6f}'
            print(log_str)
            log_data = {
                'step': step_total,
                'lr': lr_,
                'loss': loss.item(),
                'loss_dinov2': loss_dinov2.item(),
                'loss-total': loss_total.item(),
                'avg/loss': loss_meter.avg,
            }
            if (step_total-1) % (args.log_freq * 10) == 0:
                # compute expected average epoch time in hours
                batch_average_time = (time.time() - epoch_start_time) / (i + 1) / (60**2)
                epoch_average_time = batch_average_time * len(dataloader)
                this_epoch_remaining = epoch_average_time - \
                                       (time.time() - epoch_start_time) / 60**2
                total_remaining = epoch_average_time * (args.total_epochs - epoch - i / len(dataloader))
                print(f'[epoch average time] {epoch_average_time:.2f} [this epoch remaining] '
                      f'{this_epoch_remaining:.2f} [total remaining] {total_remaining:.2f}')
                log_data.update({
                    'time/total-remaining': total_remaining,
                    'time/this-epoch-remaining': this_epoch_remaining,
                    'time/epoch-average-time': epoch_average_time,
                    'time/batch-average-time': batch_average_time,
                    'other/epoch': epoch + i / len(dataloader),
                })
            wandb.log(log_data)

        # save 10 models over the course of training
        if args.save_checkpoints and (step_total % checkpoint_interval == 0):
            # save model and optimizer state_dict
            _save_training_state(model_slots, proj_head, optimizer, ckpt_dir, f'step_{step_total}')

        # every 200 steps, save a fallback model, which gets overwritten
        if step_total % 200 == 0:
            fallback_prefix = f'fallback_{step_total}'
            _save_training_state(model_slots, proj_head, optimizer, ckpt_dir, fallback_prefix)
            # remove old fallback models; compare exact names so that e.g.
            # step 200 does not keep a stale 'fallback_1200.pt'.
            current_files = {
                f'{fallback_prefix}.pt',
                f'{fallback_prefix}_proj.pt',
                f'{fallback_prefix}_opt.pt',
            }
            for file in os.listdir(ckpt_dir):
                if file.startswith('fallback') and file not in current_files:
                    os.remove(f'{ckpt_dir}/{file}')

        # Save original image and reconstructions.
        # NOTE(review): runs every step (the former `epoch % 5` guard is
        # disabled in the source) — confirm this frequency is intended.
        _save_debug_images(args, step_total, data, reconstruction, patches_orig, x_dinov2,
                           feat_dinov2_tokens, h, w)

        if step_total >= args.steps:
            break

    return step_total


@torch.no_grad()
def compute_acc(logits, targets):
    """Return top-1 accuracy in percent for ``logits`` against ``targets``."""
    preds_clean = logits.max(dim=1)[1].detach()
    acc = (preds_clean.eq(targets).sum() / targets.shape[0]).item() * 100
    return acc


def compute_loss(loss_str, embedding, targets, embedding_orig, logit_scale,
                 embedding_text_labels_norm=None, reduction='mean'):
    """Dispatch to the squared-L2 or cross-entropy loss.

    Args:
        loss_str: ``'l2'`` (embedding vs. ``embedding_orig``) or ``'ce'``
            (logits ``embedding @ (logit_scale * embedding_text_labels_norm)``
            vs. ``targets``).
        embedding: Embeddings to score.
        targets: Class targets, used by ``'ce'`` only.
        embedding_orig: Reference embeddings, used by ``'l2'`` only.
        logit_scale: Multiplier applied to the text embeddings for ``'ce'``.
        embedding_text_labels_norm: Text label embeddings for ``'ce'``.
        reduction: Reduction passed on to the selected loss.

    Raises:
        ValueError: If ``loss_str`` is neither ``'l2'`` nor ``'ce'``.
    """
    if loss_str == 'l2':
        loss = l2(out=embedding, targets=embedding_orig, reduction=reduction)
    elif loss_str == 'ce':
        loss = ce(
            out=embedding @ (logit_scale * embedding_text_labels_norm),
            targets=targets,
            reduction=reduction
        )
    else:
        raise ValueError(f'loss {loss_str} not supported')
    return loss


def l2(out, targets, reduction='none'):
    """Squared L2 distance summed over dim 1 (not divided by the latent size).

    Args:
        out: Predictions; must have the same shape as ``targets`` and a batch
            size greater than one.
        targets: Reference values.
        reduction: ``'mean'`` averages the per-sample sums to a scalar; any
            other value returns the per-sample sums.
    """
    # squared l2 - it does not divide by the latent dimension
    # should have shape (batch_size, embedding_size)
    assert out.shape == targets.shape, f'{out.shape} != {targets.shape}'
    assert out.shape[0] > 1
    # Compute the element-wise squared error
    squared_error_batch = F.mse_loss(out, targets, reduction='none')
    if reduction == 'mean':
        squared_error_batch = torch.mean(squared_error_batch.sum(dim=1))
    else:
        squared_error_batch = squared_error_batch.sum(dim=1)
        # The per-sample shape check only makes sense for the unreduced result.
        assert squared_error_batch.shape == (out.shape[0],), f'{squared_error_batch.shape} != {(out.shape[0],)}'
    return squared_error_batch


def ce(out, targets, reduction='mean'):
    """Cross-entropy of logits ``out`` against ``targets`` (batch size must exceed one)."""
    # out = logits
    assert out.shape[0] == targets.shape[0], (out.shape, targets.shape)
    assert out.shape[0] > 1

    return F.cross_entropy(out, targets, reduction=reduction)


if __name__ == '__main__':
    # set seeds
    torch.manual_seed(0)
    np.random.seed(0)

    # Parse command-line arguments
    args = parser.parse_args()
    # eps and step size are given on the 0-255 pixel scale
    args.eps /= 255
    args.stepsize_adv /= 255
    # make sure there is no string in args that should be a bool
    assert not any([isinstance(x, str) and x in ['True', 'False'] for x in args.__dict__.values()]), f'args contains a string that should be a bool: {args}'
    assert args.eval_freq % args.log_freq == 0, 'eval_freq must be a multiple of log_freq'

    if args.devices != '':
        # set cuda visible devices
        os.environ['CUDA_VISIBLE_DEVICES'] = args.devices

    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        print(f'Number of GPUs available: {num_gpus}')
    else:
        print('No multiple GPUs available.')

    # set model name and output dir
    random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=5))
    args.finetuned_model_name = f'{args.clip_model_name}_{args.pretrained}_{args.dataset}_{args.loss}_{args.dataset}_{args.experiment_name}_{random_str}'
    args.finetuned_model_name = args.finetuned_model_name.replace('/', '_')
    args.output_dir = os.path.join(args.output_dir, args.finetuned_model_name)
    # run
    main(args)