ldkong committed on
Commit
b61fd46
·
1 Parent(s): bb64240

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -3
app.py CHANGED
@@ -610,7 +610,6 @@ def name2seq(file_name):
610
 
611
  images = np.asarray(images, dtype='f') / 256.0
612
  images = images.transpose((0, 3, 1, 2))
613
- print(images.shape)
614
  images = torch.Tensor(images).unsqueeze(dim=0)
615
  return images
616
 
@@ -649,6 +648,58 @@ def display_image(file_name):
649
  imageio.imwrite('image.png', image)
650
 
651
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
  def run(domain_source, action_source, hair_source, top_source, bottom_source, domain_target, action_target, hair_target, top_target, bottom_target):
653
 
654
  # == Source Avatar ==
@@ -676,7 +727,6 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
676
  file_name_source = './Sprite/frames/domain_1/' + action_source + '/'
677
  file_name_source = file_name_source + 'front' + '_' + str(body_source) + str(bottom_source) + str(top_source) + str(hair_source) + '_'
678
 
679
- gif = display_gif_pad(file_name_source, 'avatar_source.gif')
680
 
681
  # == Target Avatar ==
682
  # body
@@ -703,7 +753,10 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
703
  file_name_target = './Sprite/frames/domain_2/' + action_target + '/'
704
  file_name_target = file_name_target + 'front' + '_' + str(body_target) + str(bottom_target) + str(top_target) + str(hair_target) + '_'
705
 
706
- gif_target = display_gif_pad(file_name_target, 'avatar_target.gif')
 
 
 
707
 
708
 
709
  # == Load Model ==
@@ -711,6 +764,59 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
711
  model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
712
  model.eval()
713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
714
  return 'demo.gif'
715
 
716
 
 
610
 
611
  images = np.asarray(images, dtype='f') / 256.0
612
  images = images.transpose((0, 3, 1, 2))
 
613
  images = torch.Tensor(images).unsqueeze(dim=0)
614
  return images
615
 
 
648
  imageio.imwrite('image.png', image)
649
 
650
 
651
def concat(file_name, num_frames=8, gif_filename='demo.gif'):
    """Assemble per-frame PNG images into a single animated GIF.

    Reads ``file_name + '<i>.png'`` for i in ``range(num_frames)`` and
    writes them out as one GIF.

    Args:
        file_name: Common path prefix of the frame images
            (e.g. ``'MyPlot_'`` reads ``MyPlot_0.png`` .. ``MyPlot_7.png``).
        num_frames: Number of frames to read; defaults to 8 to preserve
            the original hard-coded behavior.
        gif_filename: Output GIF path; defaults to ``'demo.gif'``.

    Returns:
        The result of ``imageio.mimsave`` (``None``) — kept so existing
        callers that bind the return value are unaffected.
    """
    images = []
    for frame in range(num_frames):
        # Frames are named '<prefix><index>.png', e.g. 'MyPlot_0.png'.
        image_filename = file_name + ('%d' % frame) + '.png'
        images.append(imageio.imread(image_filename))
    return imageio.mimsave(gif_filename, images)
662
+
663
+
664
def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt):
    """Render a 2x4 comparison grid for one frame and save it as a PNG.

    Row 0 shows the source avatar, row 1 the target avatar. Columns are:
    original input, reconstruction, decode with Zt only (Zf zeroed), and
    the cross-combined reconstructions with Zf/Zt exchanged.

    Args:
        frame_id: Frame index; used to build the output name
            ``'MyPlot_<frame_id>.png'``.
        src_orig, tar_orig: Original frames (H, W, C arrays for imshow).
        src_recon, tar_recon: Reconstructed frames.
        src_Zt, tar_Zt: Decodes using only the time-varying code Zt.
        src_Zf_tar_Zt, tar_Zf_src_Zt: Decodes with Zf and Zt swapped
            between source and target.

    Returns:
        None; the figure is written to disk as a side effect.
    """
    fig, axs = plt.subplots(2, 4, sharex=True, sharey=True, figsize=(10, 5))

    # (row, col, image, column title). Titles only on row 0; the leading
    # newlines push the title text clear of the figure edge.
    # NOTE: the last column intentionally shows the cross-combined frames
    # with tar_Zf_src_Zt on top, matching the original layout.
    panels = [
        (0, 0, src_orig, "\n\n\nOriginal\nInput"),
        (1, 0, tar_orig, None),
        (0, 1, src_recon, "\n\n\nReconstructed\nOutput"),
        (1, 1, tar_recon, None),
        (0, 2, src_Zt, "\n\n\nOutput\nw/ Zt"),
        (1, 2, tar_Zt, None),
        (0, 3, tar_Zf_src_Zt, "\n\n\nExchange\nZt and Zf"),
        (1, 3, src_Zf_tar_Zt, None),
    ]
    for row, col, img, title in panels:
        axs[row, col].imshow(img)
        if title is not None:
            axs[row, col].set_title(title)
        axs[row, col].axis('off')

    plt.subplots_adjust(hspace=0.06, wspace=0.05)

    save_name = 'MyPlot_{}.format'.format(frame_id) if False else 'MyPlot_{}.png'.format(frame_id)

    plt.savefig(save_name, dpi=200, format='png', bbox_inches='tight', pad_inches=0.0)

    # Fix: close the figure after saving. The original leaked one open
    # figure per call, which accumulates memory (and triggers matplotlib's
    # "more than 20 figures" warning) when called once per frame.
    plt.close(fig)
701
+
702
+
703
  def run(domain_source, action_source, hair_source, top_source, bottom_source, domain_target, action_target, hair_target, top_target, bottom_target):
704
 
705
  # == Source Avatar ==
 
727
  file_name_source = './Sprite/frames/domain_1/' + action_source + '/'
728
  file_name_source = file_name_source + 'front' + '_' + str(body_source) + str(bottom_source) + str(top_source) + str(hair_source) + '_'
729
 
 
730
 
731
  # == Target Avatar ==
732
  # body
 
753
  file_name_target = './Sprite/frames/domain_2/' + action_target + '/'
754
  file_name_target = file_name_target + 'front' + '_' + str(body_target) + str(bottom_target) + str(top_target) + str(hair_target) + '_'
755
 
756
+
757
+ # == Load Input ==
758
+ images_source = name2seq(file_name_source)
759
+ images_target = name2seq(file_name_target)
760
 
761
 
762
  # == Load Model ==
 
764
  model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
765
  model.eval()
766
 
767
+
768
+ # == Forward ==
769
+ with torch.no_grad():
770
+ f_mean, f_logvar, f_post, z_post_mean, z_post_logvar, z_post, z_prior_mean, z_prior_logvar, z_prior, recon_x, pred_domain_all, pred_video_class = model(x, [0]*3)
771
+
772
+ src_orig_sample = x[0, :, :, :, :]
773
+ src_recon_sample = recon_x[0, :, :, :, :]
774
+ src_f_post = f_post[0, :].unsqueeze(0)
775
+ src_z_post = z_post[0, :, :].unsqueeze(0)
776
+
777
+ tar_orig_sample = x[1, :, :, :, :]
778
+ tar_recon_sample = recon_x[1, :, :, :, :]
779
+ tar_f_post = f_post[1, :].unsqueeze(0)
780
+ tar_z_post = z_post[1, :, :].unsqueeze(0)
781
+
782
+
783
+ # == Visualize ==
784
+ for frame in range(8):
785
+
786
+ # original frame
787
+ src_orig = src_orig_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
788
+ tar_orig = tar_orig_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
789
+
790
+ # reconstructed frame
791
+ src_recon = src_recon_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
792
+ tar_recon = tar_recon_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
793
+
794
+ # Zt
795
+ f_expand_src = 0 * src_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
796
+ zf_src = torch.cat((src_z_post, f_expand_src), dim=2)
797
+ recon_x_src = model.decoder_frame(zf_src)
798
+ src_Zt = recon_x_src.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
799
+
800
+ f_expand_tar = 0 * tar_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
801
+ zf_tar = torch.cat((tar_z_post, f_expand_tar), dim=2) # batch,frames,(z_dim+f_dim)
802
+ recon_x_tar = model.decoder_frame(zf_tar)
803
+ tar_Zt = recon_x_tar.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
804
+
805
+ # Zf_Zt
806
+ f_expand_src = src_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
807
+ zf_srcZf_tarZt = torch.cat((tar_z_post, f_expand_src), dim=2) # batch,frames,(z_dim+f_dim)
808
+ recon_x_srcZf_tarZt = model.decoder_frame(zf_srcZf_tarZt)
809
+ src_Zf_tar_Zt = recon_x_srcZf_tarZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
810
+
811
+ f_expand_tar = tar_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
812
+ zf_tarZf_srcZt = torch.cat((src_z_post, f_expand_tar), dim=2) # batch,frames,(z_dim+f_dim)
813
+ recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
814
+ tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
815
+
816
+ MyPlot(frame, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt)
817
+
818
+ a = concat('MyPlot_')
819
+
820
  return 'demo.gif'
821
 
822