fffiloni's picture
Update app.py
b86bc3a
raw
history blame
3.88 kB
import os
# Startup provisioning: fetch the LaMa checkpoint into the working directory,
# then install the runtime packages this app imports below.
# NOTE(review): os.system exit codes are ignored here, so a failed download or
# install only surfaces later as an import or runtime error.
_SETUP_COMMANDS = (
    "wget https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt",
    "pip install imageio",
    "pip install albumentations==0.5.2",
    "pip install opencv-python",
    "pip install ffmpeg-python",
    "pip install moviepy",
)
for _command in _SETUP_COMMANDS:
    os.system(_command)
import cv2
import paddlehub as hub
import gradio as gr
import torch
from PIL import Image, ImageOps
import numpy as np
import imageio
from moviepy.editor import *
# Working directories that predict.py reads from (data/) and writes to
# (dataout/). exist_ok keeps an app restart from crashing with FileExistsError.
os.makedirs("data", exist_ok=True)
os.makedirs("dataout", exist_ok=True)
# Move the freshly downloaded checkpoint into models/. If there is no new
# download, keep an already-installed checkpoint, but still fail loudly when
# no checkpoint is available at all.
os.makedirs("models", exist_ok=True)
if os.path.exists("best.ckpt"):
    os.replace("best.ckpt", "models/best.ckpt")
elif not os.path.exists("models/best.ckpt"):
    raise FileNotFoundError(
        "LaMa checkpoint not found: expected best.ckpt or models/best.ckpt"
    )
def get_frames(video_in):
    """Downscale a video and split it into JPEG frame files.

    The input is resized to a height of 256 px and written to
    ``video_resized.mp4`` in the current directory. Videos above 30 fps are
    re-encoded at 30 fps; all others keep their frame rate. The resized file
    is then decoded with OpenCV, and every frame is written as
    ``kang<index>.jpg``.

    Args:
        video_in: Path to the input video file.

    Returns:
        A ``(frames, fps)`` tuple. ``frames`` is the ordered list of written
        frame file names. ``fps`` is the frame rate OpenCV reports for the
        resized video.
    """
    frames = []

    # Resize the video and cap its frame rate at 30 fps.
    clip = VideoFileClip(video_in)
    try:
        if clip.fps > 30:
            print("video rate is over 30, resetting to 30")
            target_fps = 30
        else:
            print("video rate is OK")
            target_fps = clip.fps
        clip_resized = clip.resize(height=256)
        clip_resized.write_videofile("video_resized.mp4", fps=target_fps)
        print("video resized to 256 height")
    finally:
        # Release the reader(s) moviepy opened for the source file.
        clip.close()

    # Decode the resized video with OpenCV and dump every frame to disk.
    cap = cv2.VideoCapture("video_resized.mp4")
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        print("video fps: " + str(fps))
        i = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_name = "kang" + str(i) + ".jpg"
            cv2.imwrite(frame_name, frame)
            frames.append(frame_name)
            i += 1
    finally:
        cap.release()
    cv2.destroyAllWindows()
    print("broke the video into frames")
    return frames, fps
def create_video(frames, fps, type):
    """Assemble a sequence of image files into an MP4 video.

    Args:
        frames: Ordered list of image file paths, one per video frame.
        fps: Frame rate of the output video.
        type: Name prefix. The video is written to ``<type>_result.mp4`` in
            the current directory. (The parameter shadows the builtin
            ``type`` but is kept for caller compatibility.)

    Returns:
        The file name of the written video.
    """
    print("building video result")
    output_path = type + "_result.mp4"
    sequence = ImageSequenceClip(frames, fps=fps)
    sequence.write_videofile(output_path, fps=fps)
    return output_path
def magic_lama(img):
    """Inpaint the watermark region of a single frame with LaMa.

    The frame is written to ``./data/data.png`` and the inverted mask
    ``./masks/modelscope-mask.png`` to ``./data/data_mask.png``.
    ``predict.py`` is then run on CPU with ``./data`` as input and
    ``./dataout`` as output.

    Args:
        img: Path to the input frame image.

    Returns:
        The path ``./dataout/data_mask.png``, where the inpainted frame is
        expected.

    Raises:
        RuntimeError: If ``predict.py`` exits with a non-zero status or does
            not produce the expected output file.
    """
    import subprocess

    with Image.open(img) as frame, Image.open("./masks/modelscope-mask.png") as mask:
        inverted_mask = ImageOps.invert(mask)
        imageio.imwrite("./data/data.png", frame)
        imageio.imwrite("./data/data_mask.png", inverted_mask)

    output_path = "./dataout/data_mask.png"
    # Remove the previous frame's result so a failed run cannot silently
    # return stale output.
    if os.path.exists(output_path):
        os.remove(output_path)

    # The ./data and ./dataout paths above are relative to the working
    # directory, so the absolute paths handed to predict.py are derived from
    # it too (originally hard-coded as /home/user/app/).
    app_dir = os.path.join(os.getcwd(), "")
    result = subprocess.run(
        [
            "python",
            "predict.py",
            "model.path=" + app_dir,
            "indir=" + os.path.join(app_dir, "data", ""),
            "outdir=" + os.path.join(app_dir, "dataout", ""),
            "device=cpu",
        ]
    )
    if result.returncode != 0:
        raise RuntimeError(f"predict.py failed with exit code {result.returncode}")
    if not os.path.exists(output_path):
        raise RuntimeError(f"predict.py did not produce {output_path}")
    return output_path
def infer(video_in):
    """Remove the Modelscope watermark from every frame of a video.

    Args:
        video_in: Path to the uploaded video file.

    Returns:
        A ``(final_vid, files)`` tuple. ``final_vid`` is the path of the
        cleaned video. ``files`` is a one-element list with the same path,
        for the download widget.

    Raises:
        ValueError: If no frames could be read from the video.
    """
    # 1. Break the video into frames and get its FPS.
    frames_list, fps = get_frames(video_in)
    if not frames_list:
        raise ValueError("no frames could be read from the input video")

    # Every frame is processed. The old trim-by-duration option
    # (trim_value * fps) was never wired up, so its branch was dead code.
    n_frame = len(frames_list)
    print("set stop frames to: " + str(n_frame))

    # 2. Inpaint each frame and collect the cleaned frame files.
    result_frames = []
    for index, frame_path in enumerate(frames_list, start=1):
        lama_path = magic_lama(frame_path)
        cleaned_path = f"cleaned_frame_{frame_path}"
        with Image.open(lama_path) as lama_frame:
            imageio.imwrite(cleaned_path, lama_frame)
        result_frames.append(cleaned_path)
        print(f"frame {index}/{n_frame}: done;")

    # 3. Reassemble the cleaned frames at the frame rate get_frames reported.
    final_vid = create_video(result_frames, fps, "cleaned")
    files = [final_vid]
    return final_vid, files
# Gradio UI: one uploaded video in; the cleaned video plus a downloadable
# copy of it out (matching the two values infer returns).
inputs = [gr.Video(label="Input", source="upload", type="filepath")]
outputs = [
    gr.Video(label="output"),
    gr.Files(label="Download Video"),
]
title = "LaMa Video Watermark Remover"
description = (
    "LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions. "
    "<br />This demo is meant to be used as a watermark remover on Modelscope "
    "generated videos. <br />Simply upload your Modelscope video and hit Submit"
)
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2109.07161' target='_blank'>Resolution-robust Large Mask Inpainting with Fourier Convolutions</a> | <a href='https://github.com/saic-mdal/lama' target='_blank'>Github Repo</a></p>"
demo = gr.Interface(
    infer,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
)
demo.launch()