vumichien commited on
Commit
c675260
·
1 Parent(s): d27d313

First commit

Browse files
Files changed (1) hide show
  1. VQ-Trans/checkpoints/train_vq.py +171 -0
VQ-Trans/checkpoints/train_vq.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ import torch
5
+ import torch.optim as optim
6
+ from torch.utils.tensorboard import SummaryWriter
7
+
8
+ import models.vqvae as vqvae
9
+ import utils.losses as losses
10
+ import options.option_vq as option_vq
11
+ import utils.utils_model as utils_model
12
+ from dataset import dataset_VQ, dataset_TM_eval
13
+ import utils.eval_trans as eval_trans
14
+ from options.get_eval_option import get_opt
15
+ from models.evaluator_wrapper import EvaluatorModelWrapper
16
+ import warnings
17
+ warnings.filterwarnings('ignore')
18
+ from utils.word_vectorizer import WordVectorizer
19
+
20
+ def update_lr_warm_up(optimizer, nb_iter, warm_up_iter, lr):
21
+
22
+ current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
23
+ for param_group in optimizer.param_groups:
24
+ param_group["lr"] = current_lr
25
+
26
+ return optimizer, current_lr
27
+
28
+ ##### ---- Exp dirs ---- #####
29
+ args = option_vq.get_args_parser()
30
+ torch.manual_seed(args.seed)
31
+
32
+ args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
33
+ os.makedirs(args.out_dir, exist_ok = True)
34
+
35
+ ##### ---- Logger ---- #####
36
+ logger = utils_model.get_logger(args.out_dir)
37
+ writer = SummaryWriter(args.out_dir)
38
+ logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
39
+
40
+
41
+
42
+ w_vectorizer = WordVectorizer('./glove', 'our_vab')
43
+
44
+ if args.dataname == 'kit' :
45
+ dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt'
46
+ args.nb_joints = 21
47
+
48
+ else :
49
+ dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
50
+ args.nb_joints = 22
51
+
52
+ logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints')
53
+
54
+ wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
55
+ eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
56
+
57
+
58
+ ##### ---- Dataloader ---- #####
59
+ train_loader = dataset_VQ.DATALoader(args.dataname,
60
+ args.batch_size,
61
+ window_size=args.window_size,
62
+ unit_length=2**args.down_t)
63
+
64
+ train_loader_iter = dataset_VQ.cycle(train_loader)
65
+
66
+ val_loader = dataset_TM_eval.DATALoader(args.dataname, False,
67
+ 32,
68
+ w_vectorizer,
69
+ unit_length=2**args.down_t)
70
+
71
+ ##### ---- Network ---- #####
72
+ net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
73
+ args.nb_code,
74
+ args.code_dim,
75
+ args.output_emb_width,
76
+ args.down_t,
77
+ args.stride_t,
78
+ args.width,
79
+ args.depth,
80
+ args.dilation_growth_rate,
81
+ args.vq_act,
82
+ args.vq_norm)
83
+
84
+
85
+ if args.resume_pth :
86
+ logger.info('loading checkpoint from {}'.format(args.resume_pth))
87
+ ckpt = torch.load(args.resume_pth, map_location='cpu')
88
+ net.load_state_dict(ckpt['net'], strict=True)
89
+ net.train()
90
+ net.cuda()
91
+
92
+ ##### ---- Optimizer & Scheduler ---- #####
93
+ optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
94
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma)
95
+
96
+
97
+ Loss = losses.ReConsLoss(args.recons_loss, args.nb_joints)
98
+
99
+ ##### ------ warm-up ------- #####
100
+ avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
101
+
102
+ for nb_iter in range(1, args.warm_up_iter):
103
+
104
+ optimizer, current_lr = update_lr_warm_up(optimizer, nb_iter, args.warm_up_iter, args.lr)
105
+
106
+ gt_motion = next(train_loader_iter)
107
+ gt_motion = gt_motion.cuda().float() # (bs, 64, dim)
108
+
109
+ pred_motion, loss_commit, perplexity = net(gt_motion)
110
+ loss_motion = Loss(pred_motion, gt_motion)
111
+ loss_vel = Loss.forward_vel(pred_motion, gt_motion)
112
+
113
+ loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel
114
+
115
+ optimizer.zero_grad()
116
+ loss.backward()
117
+ optimizer.step()
118
+
119
+ avg_recons += loss_motion.item()
120
+ avg_perplexity += perplexity.item()
121
+ avg_commit += loss_commit.item()
122
+
123
+ if nb_iter % args.print_iter == 0 :
124
+ avg_recons /= args.print_iter
125
+ avg_perplexity /= args.print_iter
126
+ avg_commit /= args.print_iter
127
+
128
+ logger.info(f"Warmup. Iter {nb_iter} : lr {current_lr:.5f} \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}")
129
+
130
+ avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
131
+
132
+ ##### ---- Training ---- #####
133
+ avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
134
+ best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, eval_wrapper=eval_wrapper)
135
+
136
+ for nb_iter in range(1, args.total_iter + 1):
137
+
138
+ gt_motion = next(train_loader_iter)
139
+ gt_motion = gt_motion.cuda().float() # bs, nb_joints, joints_dim, seq_len
140
+
141
+ pred_motion, loss_commit, perplexity = net(gt_motion)
142
+ loss_motion = Loss(pred_motion, gt_motion)
143
+ loss_vel = Loss.forward_vel(pred_motion, gt_motion)
144
+
145
+ loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel
146
+
147
+ optimizer.zero_grad()
148
+ loss.backward()
149
+ optimizer.step()
150
+ scheduler.step()
151
+
152
+ avg_recons += loss_motion.item()
153
+ avg_perplexity += perplexity.item()
154
+ avg_commit += loss_commit.item()
155
+
156
+ if nb_iter % args.print_iter == 0 :
157
+ avg_recons /= args.print_iter
158
+ avg_perplexity /= args.print_iter
159
+ avg_commit /= args.print_iter
160
+
161
+ writer.add_scalar('./Train/L1', avg_recons, nb_iter)
162
+ writer.add_scalar('./Train/PPL', avg_perplexity, nb_iter)
163
+ writer.add_scalar('./Train/Commit', avg_commit, nb_iter)
164
+
165
+ logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}")
166
+
167
+ avg_recons, avg_perplexity, avg_commit = 0., 0., 0.,
168
+
169
+ if nb_iter % args.eval_iter==0 :
170
+ best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper=eval_wrapper)
171
+