xmutly commited on
Commit
399c272
·
verified ·
1 Parent(s): 737b8b4

Upload adversarial_training_clip_with_object_token.py

Browse files
train/adversarial_training_clip_with_object_token.py CHANGED
@@ -265,7 +265,7 @@ def main(args):
265
 
266
  # save final model
267
  torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/final.pt')
268
- torch.save(unwrap_model(proj_head).model.state_dict(), f'{args.output_dir}/checkpoints/final_proj_head.pt')
269
 
270
  torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/final_opt.pt')
271
 
@@ -505,13 +505,15 @@ def train_one_epoch(
505
  wandb.log(log_data)
506
 
507
  # save 10 models over the course of training
508
- if args.save_checkpoints and (step_total % (args.steps // 10) == 0):
509
  # save model and optimizer state_dict
510
  torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
 
511
  torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
512
  # every 200 steps, save a fallback model, which gets overwritten
513
- if step_total % 200 == 0:
514
  torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
 
515
  torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
516
  # remove old fallback models
517
  for file in os.listdir(f'{args.output_dir}/checkpoints'):
 
265
 
266
  # save final model
267
  torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/final.pt')
268
+ torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/final_proj_head.pt')
269
 
270
  torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/final_opt.pt')
271
 
 
505
  wandb.log(log_data)
506
 
507
  # save 10 models over the course of training
508
+ if args.save_checkpoints and (step_total % (args.steps // 1) == 0):
509
  # save model and optimizer state_dict
510
  torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
511
+ torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_proj_head.pt')
512
  torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
513
  # every 200 steps, save a fallback model, which gets overwritten
514
+ if step_total % 2 == 0:
515
  torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
516
+ torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_proj_head.pt')
517
  torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
518
  # remove old fallback models
519
  for file in os.listdir(f'{args.output_dir}/checkpoints'):