xmutly committed
Commit 5218bc4 · verified · 1 Parent(s): 399c272

Upload adversarial_training_clip_with_object_token.py

train/adversarial_training_clip_with_object_token.py CHANGED
@@ -108,6 +108,8 @@ def main(args):
         assert str(args.start_step) in args.optimizer_state
         assert args.pretrained in ['', 'none']
         args.pretrained = args.optimizer_state.replace('_opt', '')
+        args.pretrained_proj_head = args.optimizer_state.replace('_opt', '_proj_head')
+
         model, _, _ = load_clip_model(args.clip_model_name, args.pretrained)
 
         # Remove the Normalize transform by creating a new Compose object
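The two added lines derive the resume paths by suffix substitution, matching the naming used by the checkpoint writer below (`step_N.pt`, `step_N_proj_head.pt`, `step_N_opt.pt`). A minimal sketch of that convention, with a made-up step number and directory:

```python
# Hypothetical resume paths following the '_opt' suffix convention above.
optimizer_state = 'out/checkpoints/step_5000_opt.pt'

pretrained = optimizer_state.replace('_opt', '')                      # out/checkpoints/step_5000.pt
pretrained_proj_head = optimizer_state.replace('_opt', '_proj_head')  # out/checkpoints/step_5000_proj_head.pt

# The assert above guards against resuming from the wrong step.
assert str(5000) in optimizer_state
```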
@@ -128,6 +130,9 @@ def main(args):
         cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
         model_slots = DINOSAURpp(cfg_dict)
         proj_head = torch.nn.Linear(256, 1024) # slot-num to slot-num
+        if args.optimizer_state != '':
+            proj_head.load_state_dict(torch.load(args.pretrained_proj_head))
+
 
 
         # get data
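Despite the inline comment, the shapes here map the 256-d slot features (`slot_dim` in `cfg_dict`) to a 1024-d space, presumably the CLIP embedding width. The new guard restores the head's weights only when resuming; a sketch of that step with an explicit `map_location`, using a stubbed argparse namespace with hypothetical values:

```python
import torch
from types import SimpleNamespace

# Stand-in for the script's argparse namespace; paths are hypothetical.
args = SimpleNamespace(
    optimizer_state='out/checkpoints/step_5000_opt.pt',
    pretrained_proj_head='out/checkpoints/step_5000_proj_head.pt',
)

proj_head = torch.nn.Linear(256, 1024)  # 256-d slot features -> 1024-d embedding

# Restore only when resuming; a fresh run keeps the random init.
if args.optimizer_state != '':
    state = torch.load(args.pretrained_proj_head, map_location='cpu')
    proj_head.load_state_dict(state)  # raises on key/shape mismatch
```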
@@ -505,13 +510,13 @@ def train_one_epoch(
         wandb.log(log_data)
 
         # save 10 models over the course of training
-        if args.save_checkpoints and (step_total % (args.steps // 1) == 0):
+        if args.save_checkpoints and (step_total % (args.steps // 10) == 0):
             # save model and optimizer state_dict
             torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
             torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_proj_head.pt')
             torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
         # every 200 steps, save a fallback model, which gets overwritten
-        if step_total % 2 == 0:
+        if step_total % 2000 == 0:
             torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
             torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_proj_head.pt')
             torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
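Both cadence fixes matter. `args.steps // 1` made the modulus equal to `args.steps`, so the periodic branch fired only at the final step rather than the ten times its comment promises; `// 10` restores that intent. Likewise `% 2` wrote a fallback every other step, while `% 2000` is presumably the intended interval (the nearby comment still says 200); note also that the fallback filename embeds `step_total`, so these files accumulate rather than overwrite. A quick sanity check of the new cadence under a hypothetical step budget:

```python
# Hypothetical total step budget standing in for args.steps.
steps = 20000
save_every = steps // 10  # periodic checkpoints: ten over the run

save_steps = [s for s in range(1, steps + 1) if s % save_every == 0]
assert len(save_steps) == 10          # saves at 2000, 4000, ..., 20000

fallback_steps = [s for s in range(1, steps + 1) if s % 2000 == 0]
print(len(fallback_steps))            # 10 fallback saves at this budget
```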
@@ -523,7 +528,7 @@ def train_one_epoch(
         if step_total >= args.steps:
             break
 
-    torch.cuda.empty_cache()
+    # torch.cuda.empty_cache()
     return step_total
 
 
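Dropping the per-epoch `torch.cuda.empty_cache()` is a reasonable cleanup: the call returns cached allocator blocks to the driver, so later allocations must go through `cudaMalloc` again, adding overhead to a steady-state training loop without lowering peak memory use. If it is ever wanted for diagnosing fragmentation, a guarded variant keeps it out of normal runs (`DEBUG_MEMORY` is a hypothetical flag):

```python
import torch

DEBUG_MEMORY = False  # hypothetical opt-in flag, off in normal training

# empty_cache() frees cached allocator blocks back to the driver; it does
# not reduce peak usage, so reserve it for memory debugging.
if DEBUG_MEMORY and torch.cuda.is_available():
    torch.cuda.empty_cache()
```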