File size: 3,725 Bytes
8b9f861
 
98c1e0d
8b9f861
0009612
 
 
8b9f861
0009612
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b9f861
98c1e0d
 
 
 
 
 
 
 
 
 
d4991cc
98c1e0d
9aa8011
98c1e0d
9aa8011
0009612
98c1e0d
 
 
 
 
9aa8011
d4991cc
98c1e0d
 
d4991cc
98c1e0d
 
 
d4991cc
9aa8011
0009612
 
98c1e0d
9aa8011
0009612
2344f4c
98c1e0d
 
 
 
 
2c8ca5e
98c1e0d
 
 
d4991cc
98c1e0d
 
9aa8011
98c1e0d
d4991cc
98c1e0d
 
 
 
 
 
d4991cc
98c1e0d
9aa8011
98c1e0d
 
8b9f861
fd29508
98c1e0d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
from refacer import Refacer
import argparse
import os
import requests
import tempfile
import shutil

# Hugging Face URL from which the inswapper_128.onnx model file is downloaded
model_url = "https://huggingface.co/ofter/4x-UltraSharp/resolve/main/inswapper_128.onnx"
# Hard-coded absolute destination of the downloaded model file.
# NOTE(review): presumably the app directory of the hosting environment -- adjust when running elsewhere.
model_path = "/home/user/app/inswapper_128.onnx"  # absolute path for the model in your environment

# Function to download the model if not exists
def download_model():
    """Download the model from ``model_url`` to ``model_path`` if it is missing.

    The response body is streamed to a temporary ``.part`` file next to the
    target and moved into place only once it has been written completely, so
    an interrupted or failed download never leaves a truncated file at
    ``model_path`` that a later run would mistake for a valid model.

    Status messages are printed. On a non-200 response an error is printed and
    the function returns without creating ``model_path``. Exceptions raised by
    ``requests`` (connection failure, timeout) or by file I/O propagate to the
    caller.
    """
    if os.path.exists(model_path):
        print("Model already exists.")
        return

    print("Downloading the inswapper_128.onnx model...")

    # Make sure the destination directory exists before writing into it.
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
    partial_path = model_path + ".part"

    # stream=True avoids holding the whole model in memory; the
    # (connect, read) timeout keeps a stalled connection from hanging startup.
    with requests.get(model_url, stream=True, timeout=(10, 60)) as response:
        if response.status_code != 200:
            print(f"Error: Model download failed. Status code: {response.status_code}")
            return

        try:
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        except BaseException:
            # Drop the incomplete file, then let the original error surface.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise

    # Atomic rename: model_path only ever appears fully written.
    os.replace(partial_path, model_path)
    print("Model downloaded successfully!")

# Argument parser
# Arguments are parsed before the model download so that `--help` or an
# invalid flag exits immediately instead of first fetching the model.
parser = argparse.ArgumentParser(description='Refacer')
parser.add_argument("--max_num_faces", type=int, help="Max number of faces on UI", default=5)
parser.add_argument("--force_cpu", help="Force CPU mode", default=False, action="store_true")
parser.add_argument("--share_gradio", help="Share Gradio", default=False, action="store_true")
# NOTE(review): the launch call at the bottom of this file passes
# server_name="0.0.0.0" literally, so --server_name currently has no effect.
parser.add_argument("--server_name", type=str, help="Server IP address", default="127.0.0.1")
parser.add_argument("--server_port", type=int, help="Server port", default=7860)
parser.add_argument("--colab_performance", help="Use in colab for better performance", default=False, action="store_true")
args = parser.parse_args()

# Download the model when the script runs (skipped if the file already
# exists), before the Refacer instance is constructed
download_model()

# Initialize the Refacer class
refacer = Refacer(force_cpu=args.force_cpu, colab_performance=args.colab_performance)

# Number of face slots shown in the UI and expected by run()
num_faces = args.max_num_faces

# Run function for refacing video
def run(*inputs):
    """Reface a video using the face pairs configured in the UI.

    Positional inputs arrive in the order they are wired to the button:
    the video path first, then ``num_faces`` origin images, then
    ``num_faces`` destination images, then the thresholds.

    A face slot is used only when both its origin and its destination are
    set. Returns the path reported by ``refacer.reface`` for the processed
    video.
    """
    source_video = inputs[0]

    # Boundaries of the three consecutive per-face groups.
    dest_start = num_faces + 1
    threshold_start = (num_faces * 2) + 1
    origin_imgs = inputs[1:dest_start]
    dest_imgs = inputs[dest_start:threshold_start]
    cutoffs = inputs[threshold_start:]

    # Keep only the slots where both images were provided.
    face_pairs = [
        {
            'origin': origin_imgs[slot],
            'destination': dest_imgs[slot],
            'threshold': cutoffs[slot],
        }
        for slot in range(num_faces)
        if origin_imgs[slot] is not None and dest_imgs[slot] is not None
    ]

    # Hand the work to refacer; the returned path goes straight back to Gradio,
    # with no extra ffmpeg step or temp-file copy in between.
    output_path = refacer.reface(source_video, face_pairs)
    print(f"Refaced video can be found at {output_path}")
    return output_path

# Per-face input components, collected in slot order so they can be passed to
# the click handler as three consecutive groups.
origin, destination, thresholds = [], [], []

with gr.Blocks() as demo:
    # Title
    with gr.Row():
        gr.Markdown("# Refacer")

    # Input video and the non-interactive result video
    with gr.Row():
        video = gr.Video(label="Original video", format="mp4")
        video2 = gr.Video(label="Refaced video", interactive=False, format="mp4")

    # One tab per face slot: origin/destination images plus a threshold slider
    for i in range(num_faces):
        with gr.Tab(f"Face #{i+1}"):
            with gr.Row():
                origin.append(gr.Image(label="Face to replace"))
                destination.append(gr.Image(label="Destination face"))
            with gr.Row():
                thresholds.append(gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.2))

    with gr.Row():
        button = gr.Button("Reface", variant="primary")

    # Input order must match what run() expects:
    # video, all origins, all destinations, all thresholds.
    button.click(
        fn=run,
        inputs=[video, *origin, *destination, *thresholds],
        outputs=[video2],
    )

# Start the queued app. The host is fixed to "0.0.0.0" here; only the port
# and the share flag come from the command line.
demo.queue().launch(
    show_error=True,
    share=args.share_gradio,
    server_name="0.0.0.0",
    server_port=args.server_port,
)