Update app.py
Browse files
app.py
CHANGED
@@ -610,7 +610,6 @@ def name2seq(file_name):
|
|
610 |
|
611 |
images = np.asarray(images, dtype='f') / 256.0
|
612 |
images = images.transpose((0, 3, 1, 2))
|
613 |
-
print(images.shape)
|
614 |
images = torch.Tensor(images).unsqueeze(dim=0)
|
615 |
return images
|
616 |
|
@@ -649,6 +648,58 @@ def display_image(file_name):
|
|
649 |
imageio.imwrite('image.png', image)
|
650 |
|
651 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
652 |
def run(domain_source, action_source, hair_source, top_source, bottom_source, domain_target, action_target, hair_target, top_target, bottom_target):
|
653 |
|
654 |
# == Source Avatar ==
|
@@ -676,7 +727,6 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
|
|
676 |
file_name_source = './Sprite/frames/domain_1/' + action_source + '/'
|
677 |
file_name_source = file_name_source + 'front' + '_' + str(body_source) + str(bottom_source) + str(top_source) + str(hair_source) + '_'
|
678 |
|
679 |
-
gif = display_gif_pad(file_name_source, 'avatar_source.gif')
|
680 |
|
681 |
# == Target Avatar ==
|
682 |
# body
|
@@ -703,7 +753,10 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
|
|
703 |
file_name_target = './Sprite/frames/domain_2/' + action_target + '/'
|
704 |
file_name_target = file_name_target + 'front' + '_' + str(body_target) + str(bottom_target) + str(top_target) + str(hair_target) + '_'
|
705 |
|
706 |
-
|
|
|
|
|
|
|
707 |
|
708 |
|
709 |
# == Load Model ==
|
@@ -711,6 +764,59 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
|
|
711 |
model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
|
712 |
model.eval()
|
713 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
714 |
return 'demo.gif'
|
715 |
|
716 |
|
|
|
610 |
|
611 |
images = np.asarray(images, dtype='f') / 256.0
|
612 |
images = images.transpose((0, 3, 1, 2))
|
|
|
613 |
images = torch.Tensor(images).unsqueeze(dim=0)
|
614 |
return images
|
615 |
|
|
|
648 |
imageio.imwrite('image.png', image)
|
649 |
|
650 |
|
651 |
+
def concat(file_name):
|
652 |
+
images = []
|
653 |
+
|
654 |
+
for frame in range(8):
|
655 |
+
frame_name = '%d' % (frame)
|
656 |
+
image_filename = file_name + frame_name + '.png'
|
657 |
+
image = imageio.imread(image_filename)
|
658 |
+
images.append(image)
|
659 |
+
|
660 |
+
gif_filename = 'demo.gif'
|
661 |
+
return imageio.mimsave(gif_filename, images)
|
662 |
+
|
663 |
+
|
664 |
+
def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt):
|
665 |
+
|
666 |
+
fig, axs = plt.subplots(2, 4, sharex=True, sharey=True, figsize=(10, 5))
|
667 |
+
|
668 |
+
axs[0, 0].imshow(src_orig)
|
669 |
+
axs[0, 0].set_title("\n\n\nOriginal\nInput")
|
670 |
+
axs[0, 0].axis('off')
|
671 |
+
|
672 |
+
axs[1, 0].imshow(tar_orig)
|
673 |
+
axs[1, 0].axis('off')
|
674 |
+
|
675 |
+
axs[0, 1].imshow(src_recon)
|
676 |
+
axs[0, 1].set_title("\n\n\nReconstructed\nOutput")
|
677 |
+
axs[0, 1].axis('off')
|
678 |
+
|
679 |
+
axs[1, 1].imshow(tar_recon)
|
680 |
+
axs[1, 1].axis('off')
|
681 |
+
|
682 |
+
axs[0, 2].imshow(src_Zt)
|
683 |
+
axs[0, 2].set_title("\n\n\nOutput\nw/ Zt")
|
684 |
+
axs[0, 2].axis('off')
|
685 |
+
|
686 |
+
axs[1, 2].imshow(tar_Zt)
|
687 |
+
axs[1, 2].axis('off')
|
688 |
+
|
689 |
+
axs[0, 3].imshow(tar_Zf_src_Zt)
|
690 |
+
axs[0, 3].set_title("\n\n\nExchange\nZt and Zf")
|
691 |
+
axs[0, 3].axis('off')
|
692 |
+
|
693 |
+
axs[1, 3].imshow(src_Zf_tar_Zt)
|
694 |
+
axs[1, 3].axis('off')
|
695 |
+
|
696 |
+
plt.subplots_adjust(hspace=0.06, wspace=0.05)
|
697 |
+
|
698 |
+
save_name = 'MyPlot_{}.png'.format(frame_id)
|
699 |
+
|
700 |
+
plt.savefig(save_name, dpi=200, format='png', bbox_inches='tight', pad_inches=0.0)
|
701 |
+
|
702 |
+
|
703 |
def run(domain_source, action_source, hair_source, top_source, bottom_source, domain_target, action_target, hair_target, top_target, bottom_target):
|
704 |
|
705 |
# == Source Avatar ==
|
|
|
727 |
file_name_source = './Sprite/frames/domain_1/' + action_source + '/'
|
728 |
file_name_source = file_name_source + 'front' + '_' + str(body_source) + str(bottom_source) + str(top_source) + str(hair_source) + '_'
|
729 |
|
|
|
730 |
|
731 |
# == Target Avatar ==
|
732 |
# body
|
|
|
753 |
file_name_target = './Sprite/frames/domain_2/' + action_target + '/'
|
754 |
file_name_target = file_name_target + 'front' + '_' + str(body_target) + str(bottom_target) + str(top_target) + str(hair_target) + '_'
|
755 |
|
756 |
+
|
757 |
+
# == Load Input ==
|
758 |
+
images_source = name2seq(file_name_source)
|
759 |
+
images_target = name2seq(file_name_target)
|
760 |
|
761 |
|
762 |
# == Load Model ==
|
|
|
764 |
model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
|
765 |
model.eval()
|
766 |
|
767 |
+
|
768 |
+
# == Forward ==
|
769 |
+
with torch.no_grad():
|
770 |
+
f_mean, f_logvar, f_post, z_post_mean, z_post_logvar, z_post, z_prior_mean, z_prior_logvar, z_prior, recon_x, pred_domain_all, pred_video_class = model(x, [0]*3)
|
771 |
+
|
772 |
+
src_orig_sample = x[0, :, :, :, :]
|
773 |
+
src_recon_sample = recon_x[0, :, :, :, :]
|
774 |
+
src_f_post = f_post[0, :].unsqueeze(0)
|
775 |
+
src_z_post = z_post[0, :, :].unsqueeze(0)
|
776 |
+
|
777 |
+
tar_orig_sample = x[1, :, :, :, :]
|
778 |
+
tar_recon_sample = recon_x[1, :, :, :, :]
|
779 |
+
tar_f_post = f_post[1, :].unsqueeze(0)
|
780 |
+
tar_z_post = z_post[1, :, :].unsqueeze(0)
|
781 |
+
|
782 |
+
|
783 |
+
# == Visualize ==
|
784 |
+
for frame in range(8):
|
785 |
+
|
786 |
+
# original frame
|
787 |
+
src_orig = src_orig_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
788 |
+
tar_orig = tar_orig_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
789 |
+
|
790 |
+
# reconstructed frame
|
791 |
+
src_recon = src_recon_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
792 |
+
tar_recon = tar_recon_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
793 |
+
|
794 |
+
# Zt
|
795 |
+
f_expand_src = 0 * src_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
|
796 |
+
zf_src = torch.cat((src_z_post, f_expand_src), dim=2)
|
797 |
+
recon_x_src = model.decoder_frame(zf_src)
|
798 |
+
src_Zt = recon_x_src.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
799 |
+
|
800 |
+
f_expand_tar = 0 * tar_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
|
801 |
+
zf_tar = torch.cat((tar_z_post, f_expand_tar), dim=2) # batch,frames,(z_dim+f_dim)
|
802 |
+
recon_x_tar = model.decoder_frame(zf_tar)
|
803 |
+
tar_Zt = recon_x_tar.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
804 |
+
|
805 |
+
# Zf_Zt
|
806 |
+
f_expand_src = src_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
|
807 |
+
zf_srcZf_tarZt = torch.cat((tar_z_post, f_expand_src), dim=2) # batch,frames,(z_dim+f_dim)
|
808 |
+
recon_x_srcZf_tarZt = model.decoder_frame(zf_srcZf_tarZt)
|
809 |
+
src_Zf_tar_Zt = recon_x_srcZf_tarZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
810 |
+
|
811 |
+
f_expand_tar = tar_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
|
812 |
+
zf_tarZf_srcZt = torch.cat((src_z_post, f_expand_tar), dim=2) # batch,frames,(z_dim+f_dim)
|
813 |
+
recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
|
814 |
+
tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
815 |
+
|
816 |
+
MyPlot(frame, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt)
|
817 |
+
|
818 |
+
a = concat('MyPlot_')
|
819 |
+
|
820 |
return 'demo.gif'
|
821 |
|
822 |
|