File size: 3,725 Bytes
8b9f861 98c1e0d 8b9f861 0009612 8b9f861 0009612 8b9f861 98c1e0d d4991cc 98c1e0d 9aa8011 98c1e0d 9aa8011 0009612 98c1e0d 9aa8011 d4991cc 98c1e0d d4991cc 98c1e0d d4991cc 9aa8011 0009612 98c1e0d 9aa8011 0009612 2344f4c 98c1e0d 2c8ca5e 98c1e0d d4991cc 98c1e0d 9aa8011 98c1e0d d4991cc 98c1e0d d4991cc 98c1e0d 9aa8011 98c1e0d 8b9f861 fd29508 98c1e0d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import gradio as gr
from refacer import Refacer
import argparse
import os
import requests
import tempfile
import shutil
# Remote location of the inswapper_128.onnx face-swap model (a Hugging Face
# "resolve" URL, which serves the raw file contents).
model_url = (
    "https://huggingface.co/ofter/4x-UltraSharp"
    "/resolve/main/inswapper_128.onnx"
)
# Absolute on-disk location of the model. "/home/user/app" is presumably the
# app directory of the hosting environment (e.g. a Hugging Face Space) — confirm.
model_path = (
    "/home/user/app"
    "/inswapper_128.onnx"
)
# Function to download the model if not exists
def download_model():
if not os.path.exists(model_path):
print("Downloading the inswapper_128.onnx model...")
response = requests.get(model_url)
if response.status_code == 200:
with open(model_path, 'wb') as f:
f.write(response.content)
print("Model downloaded successfully!")
else:
print(f"Error: Model download failed. Status code: {response.status_code}")
else:
print("Model already exists.")
# Download the model when the script runs
download_model()
# Argument parser
parser = argparse.ArgumentParser(description='Refacer')
parser.add_argument("--max_num_faces", type=int, help="Max number of faces on UI", default=5)
parser.add_argument("--force_cpu", help="Force CPU mode", default=False, action="store_true")
parser.add_argument("--share_gradio", help="Share Gradio", default=False, action="store_true")
parser.add_argument("--server_name", type=str, help="Server IP address", default="127.0.0.1")
parser.add_argument("--server_port", type=int, help="Server port", default=7860)
parser.add_argument("--colab_performance", help="Use in colab for better performance", default=False, action="store_true")
args = parser.parse_args()
# Initialize the Refacer class
refacer = Refacer(force_cpu=args.force_cpu, colab_performance=args.colab_performance)
num_faces = args.max_num_faces
# Run function for refacing video
def run(*vars):
video_path = vars[0]
origins = vars[1:(num_faces+1)]
destinations = vars[(num_faces+1):(num_faces*2)+1]
thresholds = vars[(num_faces*2)+1:]
faces = []
for k in range(0, num_faces):
if origins[k] is not None and destinations[k] is not None:
faces.append({
'origin': origins[k],
'destination': destinations[k],
'threshold': thresholds[k]
})
# Call refacer to process video and get refaced video path
refaced_video_path = refacer.reface(video_path, faces) # Get refaced video path
print(f"Refaced video can be found at {refaced_video_path}")
# Directly return the path to the Gradio UI without using ffmpeg or temp files
return refaced_video_path # Gradio will handle the video display
# Build the Gradio interface. These lists collect one widget per face slot.
# Their order matters: run() slices its positional inputs as
# [video, *origin, *destination, *thresholds].
origin, destination, thresholds = [], [], []

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Refacer")

    # Uploaded source video next to the (read-only) result.
    with gr.Row():
        video = gr.Video(label="Original video", format="mp4")
        video2 = gr.Video(label="Refaced video", interactive=False, format="mp4")

    # One tab per face slot. Each tab holds the face to look for, its
    # replacement, and the threshold forwarded to refacer.reface() via run().
    for face_no in range(1, num_faces + 1):
        with gr.Tab(f"Face #{face_no}"):
            with gr.Row():
                origin.append(gr.Image(label="Face to replace"))
                destination.append(gr.Image(label="Destination face"))
            with gr.Row():
                thresholds.append(
                    gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.2)
                )

    with gr.Row():
        button = gr.Button("Reface", variant="primary")

    # Clicking runs the refacer and shows the returned video in `video2`.
    button.click(
        fn=run,
        inputs=[video, *origin, *destination, *thresholds],
        outputs=[video2],
    )

# Enable Gradio's request queue and start the server.
# NOTE(review): server_name is hard-coded to "0.0.0.0" here, so the
# --server_name command-line option is parsed but has no effect.
demo.queue().launch(
    show_error=True,
    share=args.share_gradio,
    server_name="0.0.0.0",
    server_port=args.server_port,
)
|