File size: 6,222 Bytes
7df097b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import random

import gradio as gr
import huggingface_hub
import imageio
import numpy as np
import onnxruntime as rt
from numpy.random import RandomState
from skimage import transform


class Model:
    """ONNX Runtime wrapper for the mapping and synthesis networks of ``skytnt/waifu-gan``.

    Both graphs are downloaded from the Hugging Face Hub and loaded when the
    instance is constructed.
    """

    def __init__(self):
        self.g_synthesis = None  # InferenceSession for g_synthesis.onnx, set by load_models()
        self.g_mapping = None  # InferenceSession for g_mapping.onnx, set by load_models()
        self.load_models()

    def load_models(self):
        """Download both ONNX files and open one inference session per graph.

        Execution providers are given in priority order: CUDA first, CPU as fallback.
        """
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        g_mapping_path = huggingface_hub.hf_hub_download("skytnt/waifu-gan", "g_mapping.onnx")
        g_synthesis_path = huggingface_hub.hf_hub_download("skytnt/waifu-gan", "g_synthesis.onnx")
        self.g_mapping = rt.InferenceSession(g_mapping_path, providers=providers)
        self.g_synthesis = rt.InferenceSession(g_synthesis_path, providers=providers)

    def get_img(self, w):
        """Run the synthesis network on latent ``w`` and return the first image as uint8.

        The output's axes are reordered with ``transpose(0, 2, 3, 1)`` (presumably
        NCHW -> NHWC, TODO confirm). Values are mapped with ``x * 127.5 + 128`` and
        clipped to [0, 255], which assumes the network emits roughly [-1, 1].
        """
        img = self.g_synthesis.run(None, {'w': w})[0]
        return (img.transpose(0, 2, 3, 1) * 127.5 + 128).clip(0, 255).astype(np.uint8)[0]

    def get_w(self, z, psi1, psi2):
        """Map noise ``z`` to latent ``w``, passing both truncation values as a float32 ``psi``."""
        return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi1, psi2], dtype=np.float32)})[0]

    def gen_video(self, w1, w2, path, frame_num=10):
        """Write an MP4 that linearly morphs from latent ``w1`` to latent ``w2``.

        Args:
            w1, w2: latents as returned by ``get_w``; must be broadcast-compatible.
            path: output file path handed to ``imageio.get_writer``.
            frame_num: number of frames, both endpoints included. The clip is
                encoded at ``frame_num // 2`` fps (at least 1), i.e. about two seconds.
        """
        # frame_num // 2 is 0 for frame_num < 2, which is not a valid frame rate.
        fps = max(1, frame_num // 2)
        video = imageio.get_writer(path, mode='I', fps=fps, codec='libx264', bitrate='16M')
        try:
            for w in self._interpolate_w(w1, w2, frame_num):
                video.append_data(self.get_img(w))
        finally:
            # Finalize the file and release the encoder even if synthesis fails mid-way.
            video.close()

    @staticmethod
    def _interpolate_w(w1, w2, frame_num):
        """Yield ``frame_num`` latents evenly spaced from ``w1`` to ``w2`` (inclusive).

        The interpolation weights share the latents' dtype. A float64 weight would
        upcast float32 latents under NumPy 2 promotion rules (NEP 50), and the ONNX
        session would then receive a different input dtype than ``get_w`` produced.
        """
        w1 = np.asarray(w1)
        w2 = np.asarray(w2)
        dtype = np.result_type(w1, w2)
        for t in np.linspace(0, 1, frame_num, dtype=dtype):
            yield (1 - t) * w1 + t * w2


def get_thumbnail(img):
    """Letterbox ``img`` into a 192x288 grey canvas used by the image-selection previews.

    The image is resized to 192 rows x 128 columns (range preserved) and pasted
    into columns 80..207, which leaves grey (value 200) bars on both sides.
    """
    rows, cols, thumb_cols, left = 192, 288, 128, 80
    canvas = np.full((rows, cols, 3), 200, dtype=np.uint8)
    scaled = transform.resize(img, (rows, thumb_cols), preserve_range=True)
    canvas[:, left:left + thumb_cols] = scaled
    return canvas


def gen_fn(method, seed, psi1, psi2):
    """Generate one image for the "generate image" tab.

    Args:
        method: index of the selected "method" radio option. 0 ("random") draws
            a fresh seed; any other value uses ``seed`` unchanged.
        seed: seed from the slider, used when ``method`` is not 0.
        psi1, psi2: truncation values forwarded to ``model.get_w``.

    Returns:
        Tuple ``(image, seed_used, w, thumbnail)``.
    """
    seed = random.randint(0, 2 ** 32 - 1) if method == 0 else seed
    noise = RandomState(int(seed)).randn(1, 1024)
    latent = model.get_w(noise.astype(dtype=np.float32), psi1, psi2)
    picture = model.get_img(latent)
    thumb = get_thumbnail(picture)
    return picture, seed, latent, thumb


def gen_video_fn(w1, w2, frame):
    """Render the interpolation video for the "generate video" tab.

    Args:
        w1, w2: latents stored by the two "Select" buttons; ``None`` until selected.
        frame: frame count from the slider (converted with ``int``).

    Returns:
        Path to the written MP4, or ``None`` if either latent is missing.
    """
    import os
    import tempfile

    if w1 is None or w2 is None:
        return None
    # Give every request its own directory. A single fixed "video.mp4" in the
    # working directory is shared by all requests, so concurrent users could
    # overwrite (or be served a half-written copy of) each other's video.
    # NOTE(review): these temp dirs are not removed here; consider a periodic sweep.
    path = os.path.join(tempfile.mkdtemp(prefix="waifu-gan-"), "video.mp4")
    model.gen_video(w1, w2, path, int(frame))
    return path


if __name__ == '__main__':
    # Script entry point: load the model once, then build and launch the two-tab Gradio UI.
    # `model` is the module-level global read by gen_fn and gen_video_fn.
    model = Model()

    app = gr.Blocks()
    with app:
        # NOTE: component placement follows statement order inside these context managers.
        gr.Markdown("# Waifu GAN\n\n"
                    "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=skytnt.waifu-gan)\n\n")
        with gr.Tabs():
            # Tab 1: generate a single image from a random or user-chosen seed.
            with gr.TabItem("generate image"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            # type="index": gen_fn receives 0 for "random", 1 for "use seed".
                            gen_input1 = gr.Radio(label="method", value="random",
                                                  choices=["random", "use seed"], type="index")
                            gen_input2 = gr.Slider(minimum=0, maximum=2 ** 32 - 1, step=1, value=0,
                                                   label="seed")
                        gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 1")
                        gen_input4 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 2")
                        with gr.Group():
                            gen_submit = gr.Button("Generate", variant="primary")
                    with gr.Column():
                        gen_output1 = gr.Image(label="output image")
                        # Hidden state filled by gen_fn: latent w and thumbnail of the latest image.
                        select_img_input_w1 = gr.Variable()
                        select_img_input_img1 = gr.Variable()

            # Tab 2: pick two previously generated images and morph between their latents.
            with gr.TabItem("generate video"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("## generate video between 2 images")
                        with gr.Row():
                            with gr.Column():
                                gr.Markdown("please select image 1")
                                # Only one choice exists; its index is passed to the click handler but ignored.
                                select_img1_dropdown = gr.Radio(label="source", value="current generated image",
                                                                choices=["current generated image"], type="index")
                                with gr.Group():
                                    select_img1_button = gr.Button("Select", variant="primary")
                                select_img1_output_img = gr.Image(label="selected image 1")
                                select_img1_output_w = gr.Variable()  # latent chosen as the video's start
                            with gr.Column():
                                gr.Markdown("please select image 2")
                                select_img2_dropdown = gr.Radio(label="source", value="current generated image",
                                                                choices=["current generated image"], type="index")
                                with gr.Group():
                                    select_img2_button = gr.Button("Select", variant="primary")
                                select_img2_output_img = gr.Image(label="selected image 2")
                                select_img2_output_w = gr.Variable()  # latent chosen as the video's end
                        generate_video_frame = gr.Slider(minimum=10, maximum=30, step=1, label="frame", value=15)
                        with gr.Group():
                            generate_video_button = gr.Button("Generate", variant="primary")
                    with gr.Column():
                        generate_video_output = gr.Video(label="output video")
        # gen_fn returns (image, seed, w, thumbnail). The seed is written back to the
        # slider so a randomly drawn seed is visible and reusable.
        gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3, gen_input4],
                         [gen_output1, gen_input2, select_img_input_w1, select_img_input_img1])
        # Each "Select" copies the current thumbnail/latent into that slot; the radio index `i` is unused.
        select_img1_button.click(lambda i, img1, w1: (img1, w1),
                                 [select_img1_dropdown, select_img_input_img1, select_img_input_w1],
                                 [select_img1_output_img, select_img1_output_w])
        select_img2_button.click(lambda i, img1, w1: (img1, w1),
                                 [select_img2_dropdown, select_img_input_img1, select_img_input_w1],
                                 [select_img2_output_img, select_img2_output_w])
        # gen_video_fn returns None (no video) until both latents have been selected.
        generate_video_button.click(gen_video_fn,
                                    [select_img1_output_w, select_img2_output_w, generate_video_frame],
                                    [generate_video_output])
    app.launch()