Spaces:
Running
Running
import os | |
import shutil | |
import gradio as gr | |
from PIL import Image | |
import subprocess | |
#os.chdir('Restormer') | |
from demo_dagan import * | |
# Download sample images | |
examples = [['project/cartoon2.jpg','project/video1.mp4'], | |
['project/cartoon3.jpg','project/video2.mp4'], | |
['project/celeb1.jpg','project/video1.mp4'], | |
['project/celeb2.jpg','project/video2.mp4'], | |
] | |
inference_on = ['Full Resolution Image', 'Downsampled Image'] | |
title = "DaGAN" | |
description = """ | |
Gradio demo for <b>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</b>, CVPR 2022L. <a href='https://arxiv.org/abs/2203.06605'>[Paper]</a><a href='https://github.com/harlanhong/CVPR2022-DaGAN'>[Github Code]</a>\n | |
""" | |
##With Restormer, you can perform: (1) Image Denoising, (2) Defocus Deblurring, (3) Motion Deblurring, and (4) Image Deraining. | |
##To use it, simply upload your own image, or click one of the examples provided below. | |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.06605'>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</a> | <a href='https://github.com/harlanhong/CVPR2022-DaGAN'>Github Repo</a></p>" | |
def inference(source_image, video): | |
if not os.path.exists('temp'): | |
os.system('mkdir temp') | |
cmd = f"ffmpeg -y -ss 00:00:00 -i {video} -to 00:00:08 -c copy video_input.mp4" | |
subprocess.run(cmd.split()) | |
driving_video = "video_input.mp4" | |
output = "rst.mp4" | |
with open("config/vox-adv-256.yaml") as f: | |
config = yaml.load(f) | |
generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params']) | |
config['model_params']['common_params']['num_channels'] = 4 | |
kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params']) | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
g_checkpoint = torch.load("generator.pt", map_location=device) | |
kp_checkpoint = torch.load("kp_detector.pt", map_location=device) | |
ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items()) | |
generator.load_state_dict(ckp_generator) | |
ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in kp_checkpoint.items()) | |
kp_detector.load_state_dict(ckp_kp_detector) | |
depth_encoder = depth.ResnetEncoder(18, False) | |
depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4)) | |
loaded_dict_enc = torch.load('encoder.pth') | |
loaded_dict_dec = torch.load('depth.pth') | |
filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()} | |
depth_encoder.load_state_dict(filtered_dict_enc) | |
ckp_depth_decoder= {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()} | |
depth_decoder.load_state_dict(ckp_depth_decoder) | |
depth_encoder.eval() | |
depth_decoder.eval() | |
# device = torch.device('cpu') | |
# stx() | |
generator = generator.to(device) | |
kp_detector = kp_detector.to(device) | |
depth_encoder = depth_encoder.to(device) | |
depth_decoder = depth_decoder.to(device) | |
generator.eval() | |
kp_detector.eval() | |
depth_encoder.eval() | |
depth_decoder.eval() | |
img_multiple_of = 8 | |
with torch.inference_mode(): | |
if torch.cuda.is_available(): | |
torch.cuda.ipc_collect() | |
torch.cuda.empty_cache() | |
source_image = imageio.imread(source_image) | |
reader = imageio.get_reader(driving_video) | |
fps = reader.get_meta_data()['fps'] | |
driving_video = [] | |
try: | |
for im in reader: | |
driving_video.append(im) | |
except RuntimeError: | |
pass | |
reader.close() | |
source_image = resize(source_image, (256, 256))[..., :3] | |
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video] | |
i = find_best_frame(source_image, driving_video) | |
print ("Best frame: " + str(i)) | |
driving_forward = driving_video[i:] | |
driving_backward = driving_video[:(i+1)][::-1] | |
sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False) | |
sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False) | |
predictions = predictions_backward[::-1] + predictions_forward[1:] | |
sources = sources_backward[::-1] + sources_forward[1:] | |
drivings = drivings_backward[::-1] + drivings_forward[1:] | |
depth_gray = depth_backward[::-1] + depth_forward[1:] | |
imageio.mimsave(output, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps) | |
imageio.mimsave("gray.mp4", depth_gray, fps=fps) | |
# merge the gray video | |
animation = np.array(imageio.mimread(output,memtest=False)) | |
gray = np.array(imageio.mimread("gray.mp4",memtest=False)) | |
src_dst = animation[:,:,:512,:] | |
animate = animation[:,:,512:,:] | |
merge = np.concatenate((src_dst,gray,animate),2) | |
imageio.mimsave(output, merge, fps=fps) | |
return output | |
gr.Interface( | |
inference, | |
[ | |
gr.inputs.Image(type="filepath", label="Source Image"), | |
gr.inputs.Video(type='mp4',label="Driving Video"), | |
], | |
gr.outputs.Video(type="mp4", label="Output Video"), | |
title=title, | |
description=description, | |
article=article, | |
theme ="huggingface", | |
examples=examples, | |
allow_flagging=False, | |
).launch(debug=False,enable_queue=True) | |