Upload adversarial_training_clip_with_object_token.py
train/adversarial_training_clip_with_object_token.py
CHANGED
@@ -265,7 +265,7 @@ def main(args):
 
     # save final model
     torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/final.pt')
-    torch.save(unwrap_model(proj_head).
+    torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/final_proj_head.pt')
 
     torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/final_opt.pt')
 
@@ -505,13 +505,15 @@ def train_one_epoch(
             wandb.log(log_data)
 
         # save 10 models over the course of training
-        if args.save_checkpoints and (step_total % (args.steps //
+        if args.save_checkpoints and (step_total % (args.steps // 1) == 0):
            # save model and optimizer state_dict
            torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
+           torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_proj_head.pt')
            torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
            # every 200 steps, save a fallback model, which gets overwritten
-           if step_total %
+           if step_total % 2 == 0:
               torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
+              torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_proj_head.pt')
              torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
              # remove old fallback models
              for file in os.listdir(f'{args.output_dir}/checkpoints'):
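The diff context ends at the for file in os.listdir(...) line, so the body of the cleanup loop is not shown. A minimal sketch of what such a cleanup could look like, assuming the goal is to keep only the fallback files for the current step; the helper name remove_old_fallbacks and its exact logic are illustrative, not taken from the script:

import os

def remove_old_fallbacks(ckpt_dir, keep_step):
    # Keep only the fallback files written for keep_step; delete older ones.
    keep = {
        f'fallback_{keep_step}.pt',
        f'fallback_{keep_step}_proj_head.pt',
        f'fallback_{keep_step}_opt.pt',
    }
    for file in os.listdir(ckpt_dir):
        if file.startswith('fallback_') and file not in keep:
            os.remove(os.path.join(ckpt_dir, file))

Called as remove_old_fallbacks(f'{args.output_dir}/checkpoints', step_total), it would sit where the loop appears in the diff, right after the fallback files for the current step are written.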
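For reference, the save cadence implied by the new conditions can be checked with a few lines; args.steps = 20000 is an assumed value used only for this worked example:

steps = 20000                 # stand-in for args.steps (assumed value)
save_every = steps // 1       # 20000: step_{...}.pt checkpoints fire when step_total % 20000 == 0
fallback_every = 2            # fallback_{...}.pt checkpoints fire when step_total % 2 == 0

step_saves = sum(1 for s in range(1, steps + 1) if s % save_every == 0)          # 1
fallback_saves = sum(1 for s in range(1, steps + 1) if s % fallback_every == 0)  # 10000
print(step_saves, fallback_saves)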
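A minimal sketch of how the three final checkpoints written above might be loaded back; the function name load_final_checkpoint is hypothetical, and it assumes model, proj_head, and optimizer are the same (unwrapped) objects whose state_dicts were saved, with model exposing the same .model attribute used at save time:

import torch

def load_final_checkpoint(model, proj_head, optimizer, ckpt_dir, device='cpu'):
    # Reload the three state_dicts written at the end of training.
    model.model.load_state_dict(torch.load(f'{ckpt_dir}/final.pt', map_location=device))
    proj_head.load_state_dict(torch.load(f'{ckpt_dir}/final_proj_head.pt', map_location=device))
    optimizer.load_state_dict(torch.load(f'{ckpt_dir}/final_opt.pt', map_location=device))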