ldkong committed
Commit bb64240 · 1 Parent(s): 3a316b0

Update app.py

Files changed (1)
  1. app.py +536 -0
app.py CHANGED
@@ -85,6 +85,536 @@ parser.add_argument('--data_threads', type=int, default=5, help='number of data
85
  opt = parser.parse_args(args=[])
86
 
87
 
88
+ class GradReverse(Function):
89
+ @staticmethod
90
+ def forward(ctx, x, beta):
91
+ ctx.beta = beta
92
+ return x.view_as(x)
93
+
94
+ @staticmethod
95
+ def backward(ctx, grad_output):
96
+ grad_input = grad_output.neg() * ctx.beta
97
+ return grad_input, None
98
+
99
+
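# Illustrative sanity check of the gradient-reversal layer above (editorial sketch, assuming
# torch is already imported at the top of app.py): the forward pass is the identity, while
# the backward pass negates the incoming gradient and scales it by beta.
_grl_x = torch.ones(3, requires_grad=True)
GradReverse.apply(_grl_x, 0.5).sum().backward()
assert torch.allclose(_grl_x.grad, torch.full((3,), -0.5))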
100
+ class TransferVAE_Video(nn.Module):
101
+
102
+ def __init__(self, opt):
103
+ super(TransferVAE_Video, self).__init__()
104
+ self.f_dim = opt.f_dim
105
+ self.z_dim = opt.z_dim
106
+ self.fc_dim = opt.fc_dim
107
+ self.channels = opt.channels
108
+ self.input_type = opt.input_type
109
+ self.frames = opt.num_segments
110
+ self.use_bn = opt.use_bn
111
+ self.frame_aggregation = opt.frame_aggregation
112
+ self.batch_size = opt.batch_size
113
+ self.use_attn = opt.use_attn
114
+ self.dropout_rate = opt.dropout_rate
115
+ self.num_class = opt.num_class
116
+ self.prior_sample = opt.prior_sample
117
+
118
+ if self.input_type == 'image':
119
+ import dcgan_64
120
+ self.encoder = dcgan_64.encoder(self.fc_dim, self.channels)
121
+ self.decoder = dcgan_64.decoder_woSkip(self.z_dim + self.f_dim, self.channels)
122
+ self.fc_output_dim = self.fc_dim
123
+ elif self.input_type == 'feature':
124
+ if opt.backbone == 'resnet101':
125
+ model_backbone = getattr(torchvision.models, opt.backbone)(True)  # instantiated only to read off the backbone's feature dimension
126
+ self.input_dim = model_backbone.fc.in_features
127
+ elif opt.backbone == 'I3Dpretrain':
128
+ self.input_dim = 2048
129
+ elif opt.backbone == 'I3Dfinetune':
130
+ self.input_dim = 2048
131
+ self.add_fc = opt.add_fc
132
+ self.enc_fc_layer1 = nn.Linear(self.input_dim, self.fc_dim)
133
+ self.dec_fc_layer1 = nn.Linear(self.fc_dim, self.input_dim)
134
+ self.fc_output_dim = self.fc_dim
135
+
136
+ if self.use_bn == 'shared':
137
+ self.bn_enc_layer1 = nn.BatchNorm1d(self.fc_output_dim)
138
+ self.bn_dec_layer1 = nn.BatchNorm1d(self.input_dim)
139
+ elif self.use_bn == 'separated':
140
+ self.bn_S_enc_layer1 = nn.BatchNorm1d(self.fc_output_dim)
141
+ self.bn_T_enc_layer1 = nn.BatchNorm1d(self.fc_output_dim)
142
+ self.bn_S_dec_layer1 = nn.BatchNorm1d(self.input_dim)
143
+ self.bn_T_dec_layer1 = nn.BatchNorm1d(self.input_dim)
144
+
145
+ if self.add_fc > 1:
146
+ self.enc_fc_layer2 = nn.Linear(self.fc_dim, self.fc_dim)
147
+ self.dec_fc_layer2 = nn.Linear(self.fc_dim, self.fc_dim)
148
+ self.fc_output_dim = self.fc_dim
149
+ ## use batch normalization or not (and if so, whether source and target share the same BatchNorm layers)
150
+ if self.use_bn == 'shared':
151
+ self.bn_enc_layer2 = nn.BatchNorm1d(self.fc_output_dim)
152
+ self.bn_dec_layer2 = nn.BatchNorm1d(self.fc_dim)
153
+ elif self.use_bn == 'separated':
154
+ self.bn_S_enc_layer2 = nn.BatchNorm1d(self.fc_output_dim)
155
+ self.bn_T_enc_layer2 = nn.BatchNorm1d(self.fc_output_dim)
156
+ self.bn_S_dec_layer2 = nn.BatchNorm1d(self.fc_dim)
157
+ self.bn_T_dec_layer2 = nn.BatchNorm1d(self.fc_dim)
158
+
159
+ if self.add_fc > 2:
160
+ self.enc_fc_layer3 = nn.Linear(self.fc_dim, self.fc_dim)
161
+ self.dec_fc_layer3 = nn.Linear(self.fc_dim, self.fc_dim)
162
+ self.fc_output_dim = self.fc_dim
163
+ ## use batch normalization or not (and if so, whether source and target share the same BatchNorm layers)
164
+ if self.use_bn == 'shared':
165
+ self.bn_enc_layer3 = nn.BatchNorm1d(self.fc_output_dim)
166
+ self.bn_dec_layer3 = nn.BatchNorm1d(self.fc_dim)
167
+ elif self.use_bn == 'separated':
168
+ self.bn_S_enc_layer3 = nn.BatchNorm1d(self.fc_output_dim)
169
+ self.bn_T_enc_layer3 = nn.BatchNorm1d(self.fc_output_dim)
170
+ self.bn_S_dec_layer3 = nn.BatchNorm1d(self.fc_dim)
171
+ self.bn_T_dec_layer3 = nn.BatchNorm1d(self.fc_dim)
172
+
173
+ self.z_2_out = nn.Linear(self.z_dim + self.f_dim, self.fc_output_dim)
174
+
175
+
176
+ ## nonlinearity and dropout
177
+ self.relu = nn.LeakyReLU(0.1)
178
+ self.dropout_f = nn.Dropout(p=self.dropout_rate)
179
+ self.dropout_v = nn.Dropout(p=self.dropout_rate)
180
+ # -------------------------------
181
+
182
+ ## Disentangled structure
183
+ # -------------------------------
184
+ #self.hidden_dim = opt.rnn_size
185
+ self.hidden_dim = opt.z_dim
186
+ self.f_rnn_layers = opt.f_rnn_layers
187
+
188
+ # The prior of the content code f is a standard Gaussian; the prior of the dynamics z_t is modelled by an LSTM
189
+ self.z_prior_lstm_ly1 = nn.LSTMCell(self.z_dim, self.hidden_dim)
190
+ self.z_prior_lstm_ly2 = nn.LSTMCell(self.hidden_dim, self.hidden_dim)
191
+
192
+ self.z_prior_mean = nn.Linear(self.hidden_dim, self.z_dim)
193
+ self.z_prior_logvar = nn.Linear(self.hidden_dim, self.z_dim)
194
+
195
+ # POSTERIOR DISTRIBUTION NETWORKS
196
+ # content and motion features share one lstm
197
+ self.z_lstm = nn.LSTM(self.fc_output_dim, self.hidden_dim, self.f_rnn_layers, bidirectional=True, batch_first=True)
198
+ self.f_mean = nn.Linear(self.hidden_dim * 2, self.f_dim)
199
+ self.f_logvar = nn.Linear(self.hidden_dim * 2, self.f_dim)
200
+
201
+ self.z_rnn = nn.RNN(self.hidden_dim * 2, self.hidden_dim, batch_first=True)
202
+ # Each timestep produces its own z_t, so no reshaping or feature mixing across frames
203
+ self.z_mean = nn.Linear(self.hidden_dim, self.z_dim)
204
+ self.z_logvar = nn.Linear(self.hidden_dim, self.z_dim)
205
+ # -------------------------------
206
+
207
+ ## z_t constraints
208
+ # -------------------------------
209
+ ## adversarial loss for frame features z_t
210
+ self.fc_feature_domain_frame = nn.Linear(self.z_dim, self.z_dim)
211
+ self.fc_classifier_domain_frame = nn.Linear(self.z_dim, 2)
212
+
213
+ #------ aggregate frame-based features (frame feature --> video feature) ------#
214
+ if self.frame_aggregation == 'rnn':
215
+ self.bilstm = nn.LSTM(self.z_dim, self.z_dim * 2, self.f_rnn_layers, bidirectional=True, batch_first=True)
216
+ self.feat_aggregated_dim = self.z_dim * 2
217
+ elif self.frame_aggregation == 'trn': # 4. TRN (ECCV 2018) ==> fix segment # for both train/val
218
+ self.num_bottleneck = 256 # 256
219
+ self.TRN = RelationModuleMultiScale(self.z_dim, self.num_bottleneck, self.frames)
220
+ self.bn_trn_S = nn.BatchNorm1d(self.num_bottleneck)
221
+ self.bn_trn_T = nn.BatchNorm1d(self.num_bottleneck)
222
+ self.feat_aggregated_dim = self.num_bottleneck
223
+
224
+ ## adversarial loss for video features
225
+ self.fc_feature_domain_video = nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim)
226
+ self.fc_classifier_domain_video = nn.Linear(self.feat_aggregated_dim, 2)
227
+
228
+ ## adversarial loss for each relation of features
229
+ if self.frame_aggregation == 'trn':
230
+ self.relation_domain_classifier_all = nn.ModuleList()
231
+ for i in range(self.frames-1):
232
+ relation_domain_classifier = nn.Sequential(
233
+ nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim),
234
+ nn.ReLU(),
235
+ nn.Linear(self.feat_aggregated_dim, 2)
236
+ )
237
+ self.relation_domain_classifier_all += [relation_domain_classifier]
238
+
239
+ ## classifier for action prediction task
240
+ self.pred_classifier_video = nn.Linear(self.feat_aggregated_dim, self.num_class)
241
+
242
+ ## classifier for prediction domains
243
+ self.fc_feature_domain_latent = nn.Linear(self.f_dim, self.f_dim)
244
+ self.fc_classifier_doamin_latent = nn.Linear(self.f_dim, 2)
245
+
246
+ ## attention option
247
+ if self.use_attn == 'general':
248
+ self.attn_layer = nn.Sequential(
249
+ nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim),
250
+ nn.Tanh(),
251
+ nn.Linear(self.feat_aggregated_dim, 1)
252
+ )
253
+
254
+ def domain_classifier_frame(self, feat, beta):
255
+ feat_fc_domain_frame = GradReverse.apply(feat, beta)
256
+ feat_fc_domain_frame = self.fc_feature_domain_frame(feat_fc_domain_frame)
257
+ feat_fc_domain_frame = self.relu(feat_fc_domain_frame)
258
+ pred_fc_domain_frame = self.fc_classifier_domain_frame(feat_fc_domain_frame)
259
+ return pred_fc_domain_frame
260
+
261
+ def domain_classifier_video(self, feat_video, beta):
262
+ feat_fc_domain_video = GradReverse.apply(feat_video, beta)
263
+ feat_fc_domain_video = self.fc_feature_domain_video(feat_fc_domain_video)
264
+ feat_fc_domain_video = self.relu(feat_fc_domain_video)
265
+ pred_fc_domain_video = self.fc_classifier_domain_video(feat_fc_domain_video)
266
+ return pred_fc_domain_video
267
+
268
+ def domain_classifier_latent(self, f):
269
+ feat_fc_domain_latent = self.fc_feature_domain_latent(f)
270
+ feat_fc_domain_latent = self.relu(feat_fc_domain_latent)
271
+ pred_fc_domain_latent = self.fc_classifier_doamin_latent(feat_fc_domain_latent)
272
+ return pred_fc_domain_latent
273
+
274
+ def domain_classifier_relation(self, feat_relation, beta):
275
+ pred_fc_domain_relation_video = None
276
+ for i in range(len(self.relation_domain_classifier_all)):
277
+ feat_relation_single = feat_relation[:,i,:].squeeze(1) # 128x1x256 --> 128x256
278
+ feat_fc_domain_relation_single = GradReverse.apply(feat_relation_single, beta) # the same beta for all relations (for now)
279
+
280
+ pred_fc_domain_relation_single = self.relation_domain_classifier_all[i](feat_fc_domain_relation_single)
281
+
282
+ if pred_fc_domain_relation_video is None:
283
+ pred_fc_domain_relation_video = pred_fc_domain_relation_single.view(-1,1,2)
284
+ else:
285
+ pred_fc_domain_relation_video = torch.cat((pred_fc_domain_relation_video, pred_fc_domain_relation_single.view(-1,1,2)), 1)
286
+
287
+ pred_fc_domain_relation_video = pred_fc_domain_relation_video.view(-1,2)
288
+
289
+ return pred_fc_domain_relation_video
290
+
291
+ def get_trans_attn(self, pred_domain):
292
+ softmax = nn.Softmax(dim=1)
293
+ logsoftmax = nn.LogSoftmax(dim=1)
294
+ entropy = torch.sum(-softmax(pred_domain) * logsoftmax(pred_domain), 1)
295
+ weights = 1 - entropy
296
+ return weights
297
+
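# Worked example of the weighting above: for a confident 2-way domain prediction such as
# [4.0, 0.0], softmax gives ~[0.982, 0.018] and entropy ~0.090, so the weight is ~0.910;
# for an ambiguous prediction [0.0, 0.0], softmax is [0.5, 0.5], entropy is ln 2 ~0.693,
# and the weight drops to ~0.307. Frames the domain classifier is sure about thus receive
# larger attention weights. (Editorial note; the numbers are illustrative only.)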
298
+ def get_general_attn(self, feat):
299
+ num_segments = feat.size()[1]
300
+ feat = feat.view(-1, feat.size()[-1]) # reshape features: 128x4x256 --> (128x4)x256
301
+ weights = self.attn_layer(feat) # e.g. (128x4)x1
302
+ weights = weights.view(-1, num_segments, weights.size()[-1]) # reshape attention weights: (128x4)x1 --> 128x4x1
303
+ weights = F.softmax(weights, dim=1) # softmax over segments ==> 128x4x1
304
+ return weights
305
+
306
+ def get_attn_feat_relation(self, feat_fc, pred_domain, num_segments):
307
+ if self.use_attn == 'TransAttn':
308
+ weights_attn = self.get_trans_attn(pred_domain)
309
+ elif self.use_attn == 'general':
310
+ weights_attn = self.get_general_attn(feat_fc)
311
+
312
+ weights_attn = weights_attn.view(-1, num_segments-1, 1).repeat(1,1,feat_fc.size()[-1]) # reshape & repeat weights (e.g. 16 x 4 x 256)
313
+ feat_fc_attn = (weights_attn+1) * feat_fc
314
+
315
+ return feat_fc_attn, weights_attn[:,:,0]
316
+
317
+
318
+ def encode_and_sample_post(self, x):
319
+ if isinstance(x, list):
320
+ conv_x = self.encoder_frame(x[0])
321
+ else:
322
+ conv_x = self.encoder_frame(x)
323
+
324
+ # pass the bidirectional lstm
325
+ lstm_out, _ = self.z_lstm(conv_x)
326
+
327
+ # get f:
328
+ backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
329
+ frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
330
+ lstm_out_f = torch.cat((frontal, backward), dim=1)
331
+ f_mean = self.f_mean(lstm_out_f)
332
+ f_logvar = self.f_logvar(lstm_out_f)
333
+ f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
334
+
335
+ # pass to one direction rnn
336
+ features, _ = self.z_rnn(lstm_out)
337
+ z_mean = self.z_mean(features)
338
+ z_logvar = self.z_logvar(features)
339
+ z_post = self.reparameterize(z_mean, z_logvar, random_sampling=False)
340
+
341
+ if isinstance(x, list):
342
+ f_mean_list = [f_mean]
343
+ f_post_list = [f_post]
344
+ for t in range(1,3,1):
345
+ conv_x = self.encoder_frame(x[t])
346
+ lstm_out, _ = self.z_lstm(conv_x)
347
+ # get f:
348
+ backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
349
+ frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
350
+ lstm_out_f = torch.cat((frontal, backward), dim=1)
351
+ f_mean = self.f_mean(lstm_out_f)
352
+ f_logvar = self.f_logvar(lstm_out_f)
353
+ f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
354
+ f_mean_list.append(f_mean)
355
+ f_post_list.append(f_post)
356
+ f_mean = f_mean_list
357
+ f_post = f_post_list
358
+ # f_mean and f_post are lists when x is a triple of inputs, and single tensors otherwise
359
+ return f_mean, f_logvar, f_post, z_mean, z_logvar, z_post
360
+
361
+ def decoder_frame(self,zf):
362
+ if self.input_type == 'image':
363
+ recon_x = self.decoder(zf)
364
+ return recon_x
365
+
366
+ if self.input_type == 'feature':
367
+ zf = self.z_2_out(zf) # batch,frames,(z_dim+f_dim) -> batch,frames,fc_output_dim
368
+ zf = self.relu(zf)
369
+
370
+ if self.add_fc > 2:
371
+ zf = self.dec_fc_layer3(zf)
372
+ if self.use_bn == 'shared':
373
+ zf = self.bn_dec_layer3(zf)
374
+ elif self.use_bn == 'separated':
375
+ zf_src = self.bn_S_dec_layer3(zf[:self.batch_size,:,:])
376
+ zf_tar = self.bn_T_dec_layer3(zf[self.batch_size:,:,:])
377
+ zf = torch.cat([zf_src,zf_tar],axis=0)
378
+ zf = self.relu(zf)
379
+
380
+ if self.add_fc > 1:
381
+ zf = self.dec_fc_layer2(zf)
382
+ if self.use_bn == 'shared':
383
+ zf = self.bn_dec_layer2(zf)
384
+ elif self.use_bn == 'separated':
385
+ zf_src = self.bn_S_dec_layer2(zf[:self.batch_size,:,:])
386
+ zf_tar = self.bn_T_dec_layer2(zf[self.batch_size:,:,:])
387
+ zf = torch.cat([zf_src,zf_tar],axis=0)
388
+ zf = self.relu(zf)
389
+
390
+
391
+ zf = self.dec_fc_layer1(zf)
392
+ if self.use_bn == 'shared':
393
+ zf = self.bn_dec_layer1(zf)  # layer-1 BatchNorm matches dec_fc_layer1's output width
394
+ elif self.use_bn == 'separated':
395
+ zf_src = self.bn_S_dec_layer1(zf[:self.batch_size,:,:])
396
+ zf_tar = self.bn_T_dec_layer1(zf[self.batch_size:,:,:])
397
+ zf = torch.cat([zf_src,zf_tar],axis=0)
398
+ recon_x = self.relu(zf)
399
+ return recon_x
400
+
401
+ def encoder_frame(self, x):
402
+ if self.input_type == 'image':
403
+ # input x is a tensor of shape [batch_size, frames, channels, size, size]
404
+ # (each frame is a [channels, size, size] image)
405
+ # flatten to [batch_size * frames, channels, size, size] before the per-frame encoder
406
+ x_shape = x.shape
407
+ x = x.view(-1, x_shape[-3], x_shape[-2], x_shape[-1])
408
+ x_embed = self.encoder(x)[0]
409
+ # to [batch_size,frames,embed_dim]
410
+
411
+ return x_embed.view(x_shape[0], x_shape[1], -1)
412
+
413
+
414
+ if self.input_type == 'feature':
415
+ # input is [batch_size, frames, input_dim]
416
+ x_embed = self.enc_fc_layer1(x)
417
+ ## use batch normalization or not (and if so, whether source and target share the same BatchNorm layers)
418
+ if self.use_bn == 'shared':
419
+ x_embed = self.bn_enc_layer1(x_embed)
420
+ elif self.use_bn == 'separated':
421
+ x_embed_src = self.bn_S_enc_layer1(x_embed[:self.batch_size,:,:])
422
+ x_embed_tar = self.bn_T_enc_layer1(x_embed[self.batch_size:,:,:])
423
+ x_embed = torch.cat([x_embed_src,x_embed_tar],axis=0)
424
+ x_embed = self.relu(x_embed)
425
+
426
+ if self.add_fc > 1:
427
+ x_embed = self.enc_fc_layer2(x_embed)
428
+ if self.use_bn == 'shared':
429
+ x_embed = self.bn_enc_layer2(x_embed)
430
+ elif self.use_bn == 'separated':
431
+ x_embed_src = self.bn_S_enc_layer2(x_embed[:self.batch_size,:,:])
432
+ x_embed_tar = self.bn_T_enc_layer2(x_embed[self.batch_size:,:,:])
433
+ x_embed = torch.cat([x_embed_src,x_embed_tar],axis=0)
434
+ x_embed = self.relu(x_embed)
435
+
436
+ if self.add_fc > 2:
437
+ x_embed = self.enc_fc_layer3(x_embed)
438
+ if self.use_bn == 'shared':
439
+ x_embed = self.bn_enc_layer3(x_embed)
440
+ elif self.use_bn == 'separated':
441
+ x_embed_src = self.bn_S_enc_layer3(x_embed[:self.batch_size,:,:])
442
+ x_embed_tar = self.bn_T_enc_layer3(x_embed[self.batch_size:,:,:])
443
+ x_embed = torch.cat([x_embed_src,x_embed_tar],axis=0)
444
+ x_embed = self.relu(x_embed)
445
+
446
+ ## [batch_size, frames, output_dim]
447
+ return x_embed
448
+
449
+
450
+ def reparameterize(self, mean, logvar, random_sampling=True):
451
+ # Reparametrization occurs only if random sampling is set to true, otherwise mean is returned
452
+ if random_sampling is True:
453
+ eps = torch.randn_like(logvar)
454
+ std = torch.exp(0.5 * logvar)
455
+ z = mean + eps * std
456
+ return z
457
+ else:
458
+ return mean
459
+
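# Editorial note: when random_sampling is True this is the standard VAE reparameterization
# trick, z = mean + eps * exp(0.5 * logvar) with eps ~ N(0, I), which keeps the sample
# differentiable with respect to mean and logvar.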
460
+ def sample_z_prior_train(self, z_post, random_sampling=True):
461
+ z_out = None # This will ultimately store all z_s in the format [batch_size, frames, z_dim]
462
+ z_means = None
463
+ z_logvars = None
464
+ batch_size = z_post.shape[0]
465
+
466
+ z_t = torch.zeros(batch_size, self.z_dim).cpu()
467
+ h_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
468
+ c_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
469
+ h_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
470
+ c_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
471
+
472
+ for i in range(self.frames):
473
+ # two layer LSTM and two one-layer FC
474
+ h_t_ly1, c_t_ly1 = self.z_prior_lstm_ly1(z_t, (h_t_ly1, c_t_ly1))
475
+ h_t_ly2, c_t_ly2 = self.z_prior_lstm_ly2(h_t_ly1, (h_t_ly2, c_t_ly2))
476
+
477
+ z_mean_t = self.z_prior_mean(h_t_ly2)
478
+ z_logvar_t = self.z_prior_logvar(h_t_ly2)
479
+ z_prior = self.reparameterize(z_mean_t, z_logvar_t, random_sampling)
480
+ if z_out is None:
481
+ # If z_out is none it means z_t is z_1, hence store it in the format [batch_size, 1, z_dim]
482
+ z_out = z_prior.unsqueeze(1)
483
+ z_means = z_mean_t.unsqueeze(1)
484
+ z_logvars = z_logvar_t.unsqueeze(1)
485
+ else:
486
+ # If z_out is not none, z_t is not the initial z and hence append it to the previous z_ts collected in z_out
487
+ z_out = torch.cat((z_out, z_prior.unsqueeze(1)), dim=1)
488
+ z_means = torch.cat((z_means, z_mean_t.unsqueeze(1)), dim=1)
489
+ z_logvars = torch.cat((z_logvars, z_logvar_t.unsqueeze(1)), dim=1)
490
+ z_t = z_post[:,i,:]
491
+ return z_means, z_logvars, z_out
492
+
493
+ # If random sampling is true, reparametrization occurs else z_t is just set to the mean
494
+ def sample_z(self, batch_size, random_sampling=True):
495
+ z_out = None # This will ultimately store all z_s in the format [batch_size, frames, z_dim]
496
+ z_means = None
497
+ z_logvars = None
498
+
499
+ # All states are initially set to 0, especially z_0 = 0
500
+ z_t = torch.zeros(batch_size, self.z_dim).cpu()
501
+ # z_mean_t = torch.zeros(batch_size, self.z_dim)
502
+ # z_logvar_t = torch.zeros(batch_size, self.z_dim)
503
+ h_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
504
+ c_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
505
+ h_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
506
+ c_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
507
+ for _ in range(self.frames):
508
+ # h_t, c_t = self.z_prior_lstm(z_t, (h_t, c_t))
509
+ # two layer LSTM and two one-layer FC
510
+ h_t_ly1, c_t_ly1 = self.z_prior_lstm_ly1(z_t, (h_t_ly1, c_t_ly1))
511
+ h_t_ly2, c_t_ly2 = self.z_prior_lstm_ly2(h_t_ly1, (h_t_ly2, c_t_ly2))
512
+
513
+ z_mean_t = self.z_prior_mean(h_t_ly2)
514
+ z_logvar_t = self.z_prior_logvar(h_t_ly2)
515
+ z_t = self.reparameterize(z_mean_t, z_logvar_t, random_sampling)
516
+ if z_out is None:
517
+ # If z_out is none it means z_t is z_1, hence store it in the format [batch_size, 1, z_dim]
518
+ z_out = z_t.unsqueeze(1)
519
+ z_means = z_mean_t.unsqueeze(1)
520
+ z_logvars = z_logvar_t.unsqueeze(1)
521
+ else:
522
+ # If z_out is not none, z_t is not the initial z and hence append it to the previous z_ts collected in z_out
523
+ z_out = torch.cat((z_out, z_t.unsqueeze(1)), dim=1)
524
+ z_means = torch.cat((z_means, z_mean_t.unsqueeze(1)), dim=1)
525
+ z_logvars = torch.cat((z_logvars, z_logvar_t.unsqueeze(1)), dim=1)
526
+ return z_means, z_logvars, z_out
527
+
528
+ def forward(self, x, beta):
529
+ # beta [beta_relation, beta_video, beta_frame]
530
+ f_mean, f_logvar, f_post, z_mean_post, z_logvar_post, z_post = self.encode_and_sample_post(x)
531
+ if self.prior_sample == 'random':
532
+ z_mean_prior, z_logvar_prior, z_prior = self.sample_z(z_post.size(0),random_sampling=False)
533
+ elif self.prior_sample == 'post':
534
+ z_mean_prior, z_logvar_prior, z_prior = self.sample_z_prior_train(z_post, random_sampling=False)
535
+
536
+
537
+ if isinstance(f_post, list):
538
+ f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
539
+ else:
540
+ f_expand = f_post.unsqueeze(1).expand(-1, self.frames, self.f_dim)
541
+ zf = torch.cat((z_post, f_expand), dim=2) # batch,frames,(z_dim+f_dim)
542
+
543
+ ## reconstruct x
544
+ recon_x = self.decoder_frame(zf)
545
+
546
+ ## For constraints on z_post [batch,frame,z_dim] and f_post [batch,f_dim]
547
+ pred_domain_all = []  # collects domain predictions: (1) z_post (frame level), (2) each z_post relation (if TRN), (3) z_post (video level), (4) f_post
548
+
549
+ #1. adversarial on z_post (frame level)
550
+ z_post_feat = z_post.view(-1, z_post.size()[-1]) # e.g. 32 x 5 x 2048 --> 160 x 2048
551
+ z_post_feat = self.dropout_f(z_post_feat)
552
+ pred_fc_domain_frame = self.domain_classifier_frame(z_post_feat, beta[2])
553
+ pred_fc_domain_frame = pred_fc_domain_frame.view((z_post.size(0), self.frames) + pred_fc_domain_frame.size()[-1:])
554
+ pred_domain_all.append(pred_fc_domain_frame)
555
+
556
+ #2 adversarial on z_post (video level, relation level if trn is used)
557
+
558
+ if self.frame_aggregation == 'rnn':
559
+ self.bilstm.flatten_parameters()
560
+ z_post_video_feat, _ = self.bilstm(z_post)
561
+ backward = z_post_video_feat[:, 0, self.z_dim:2 * self.z_dim]
562
+ frontal = z_post_video_feat[:, self.frames - 1, 0:self.z_dim]
563
+ z_post_video_feat = torch.cat((frontal, backward), dim=1)
564
+ pred_fc_domain_relation = []
565
+ pred_domain_all.append(pred_fc_domain_relation)
566
+
567
+ elif self.frame_aggregation == 'trn':
568
+ z_post_video_relation = self.TRN(z_post) ## [batch, frame-1, self.feat_aggregated_dim]
569
+
570
+ # adversarial branch for each relation
571
+ pred_fc_domain_relation = self.domain_classifier_relation(z_post_video_relation, beta[0])
572
+ pred_domain_all.append(pred_fc_domain_relation.view((z_post.size(0), z_post_video_relation.size()[1]) + pred_fc_domain_relation.size()[-1:]))
573
+
574
+ # transferable attention
575
+ if self.use_attn != 'none': # get the attention weighting
576
+ z_post_video_relation_attn, _ = self.get_attn_feat_relation(z_post_video_relation, pred_fc_domain_relation, self.frames)
577
+
578
+ # sum up relation features (ignore 1-relation)
579
+ z_post_video_feat = torch.sum(z_post_video_relation_attn, 1)
580
+
581
+
582
+ z_post_video_feat = self.dropout_v(z_post_video_feat)
583
+
584
+ pred_fc_domain_video = self.domain_classifier_video(z_post_video_feat, beta[1])
585
+ pred_fc_domain_video = pred_fc_domain_video.view((z_post.size(0),) + pred_fc_domain_video.size()[-1:])
586
+ pred_domain_all.append(pred_fc_domain_video)
587
+
588
+
589
+ #3. video prediction
590
+ pred_video_class = self.pred_classifier_video(z_post_video_feat)
591
+
592
+ #4. domain prediction on f
593
+ if isinstance(f_post, list):
594
+ pred_fc_domain_latent = self.domain_classifier_latent(f_post[0])
595
+ else:
596
+ pred_fc_domain_latent = self.domain_classifier_latent(f_post)
597
+ pred_domain_all.append(pred_fc_domain_latent)
598
+
599
+ return f_mean, f_logvar, f_post, z_mean_post, z_logvar_post, z_post, z_mean_prior, z_logvar_prior, z_prior, recon_x, pred_domain_all, pred_video_class
600
+
601
+
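# Sketch of the call pattern for the model above (editorial; model and x are placeholder
# names): beta packs the three gradient-reversal strengths [beta_relation, beta_video,
# beta_frame], and the last three outputs are the reconstruction, the domain predictions,
# and the action-class logits.
#   with torch.no_grad():
#       outputs = model(x, [1.0, 1.0, 1.0])
#       recon_x, pred_domain_all, pred_video_class = outputs[-3], outputs[-2], outputs[-1]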
602
+ def name2seq(file_name):
603
+ images = []
604
+
605
+ for frame in range(8):
606
+ frame_name = '%d' % (frame)
607
+ image_filename = file_name + frame_name + '.png'
608
+ image = imageio.imread(image_filename)
609
+ images.append(image[:, :, :3])
610
+
611
+ images = np.asarray(images, dtype='f') / 256.0
612
+ images = images.transpose((0, 3, 1, 2))
613
+ print(images.shape)
614
+ images = torch.Tensor(images).unsqueeze(dim=0)
615
+ return images
616
+
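# Usage note (editorial; the path prefix is hypothetical): name2seq('frames/avatar_') reads
# frames/avatar_0.png ... frames/avatar_7.png and returns a float tensor of shape
# [1, 8, 3, H, W] scaled to roughly [0, 1], ready to be passed to the model.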
617
+
618
  def display_gif(file_name, save_name):
619
  images = []
620

@@ -175,6 +705,12 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
705
 
706
  gif_target = display_gif_pad(file_name_target, 'avatar_target.gif')
707
 
708
+
709
+ # == Load Model ==
710
+ model = TransferVAE_Video(opt)
711
+ model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
712
+ model.eval()
713
+
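# Editorial note on the loading step above: map_location=torch.device('cpu') pulls the
# checkpoint onto the CPU so the demo also runs without a GPU, and the weights are read
# from the checkpoint's 'state_dict' entry.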
714
  return 'demo.gif'
715
 
716