import math
import random
from typing import List

import cv2
import einops
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import seed_everything
from transparent_background import Remover

from dataset.opencv_transforms.functional import to_tensor, center_crop
from vtdm.model import create_model
from vtdm.util import tensor2vid

# Background remover used both for input preparation and output matting;
# instantiated once at import time.
remover = Remover(jit=False)


def pil_to_cv2(pil_image: Image.Image) -> np.ndarray:
    """Convert a PIL image to an OpenCV array (BGR channel order)."""
    cv_image = np.array(pil_image)
    cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
    return cv_image

def prepare_white_image(input_image: Image.Image) -> Image.Image:
    # Remove the background, producing an RGBA cutout of the subject.
    output = remover.process(input_image, type='rgba')

    # Pad to a square canvas (fully transparent, despite the name) with the
    # subject centered, so later resizing does not distort the aspect ratio.
    width, height = output.size
    max_side = max(width, height)
    white_image = Image.new('RGBA', (max_side, max_side), (0, 0, 0, 0))
    x_offset = (max_side - width) // 2
    y_offset = (max_side - height) // 2
    white_image.paste(output, (x_offset, y_offset))

    return white_image


class MultiViewGenerator:
    def __init__(self, checkpoint_path, config_path="inference.yaml"):
        # Build the denoising model from the config, load the checkpoint
        # weights, then move it to the GPU in half precision.
        self.models = {}
        denoising_model = create_model(config_path).cpu()
        denoising_model.init_from_ckpt(checkpoint_path)
        denoising_model = denoising_model.cuda().half()
        self.models["denoising_model"] = denoising_model

    def denoising(self, frames, args):
        """Run one diffusion sampling pass and return the decoded frames."""
        with torch.no_grad():
            C, T, H, W = frames.shape
            # Assemble the conditioning batch: the input clip plus the camera
            # elevation, fps and motion-bucket ids the model expects.
            batch = {"video": frames.unsqueeze(0)}
            batch["elevation"] = (
                torch.Tensor([args["elevation"]]).to(torch.int64).to(frames.device)
            )
            batch["fps_id"] = torch.Tensor([7]).to(torch.int64).to(frames.device)
            batch["motion_bucket_id"] = (
                torch.Tensor([127]).to(torch.int64).to(frames.device)
            )
            batch = self.models["denoising_model"].add_custom_cond(batch, infer=True)

            # Build the conditional and unconditional embeddings for
            # classifier-free guidance; the image conditions are zeroed in
            # the unconditional branch.
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                c, uc = self.models[
                    "denoising_model"
                ].conditioner.get_unconditional_conditioning(
                    batch,
                    force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
                )

            additional_model_inputs = {
                "image_only_indicator": torch.zeros(2, T).to(
                    self.models["denoising_model"].device
                ),
                "num_video_frames": batch["num_video_frames"],
            }

            def denoiser(input, sigma, c):
                return self.models["denoising_model"].denoiser(
                    self.models["denoising_model"].model,
                    input,
                    sigma,
                    c,
                    **additional_model_inputs
                )

            # Sample latents from Gaussian noise, then decode them back to
            # pixel space with the first-stage autoencoder.
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                randn = torch.randn(
                    [T, 4, H // 8, W // 8], device=self.models["denoising_model"].device
                )
                samples = self.models["denoising_model"].sampler(denoiser, randn, cond=c, uc=uc)

            samples = self.models["denoising_model"].decode_first_stage(samples.half())
            samples = einops.rearrange(samples, "(b t) c h w -> b c t h w", t=T)

        return tensor2vid(samples)

    def video_pipeline(self, frames, args) -> List[Image.Image]:
        num_iter = args["num_iter"]
        out_list = []

        # Optionally chain several sampling passes: each later pass drops its
        # first frame, which overlaps with the previous pass's last frame.
        for _ in range(num_iter):
            with torch.no_grad():
                results = self.denoising(frames, args)

            if len(out_list) == 0:
                out_list = out_list + results
            else:
                out_list = out_list + results[1:]

            # Feed the last output back as the conditioning frame of the
            # next pass, normalized to [-1, 1].
            img = out_list[-1]
            img = to_tensor(img)
            img = (img - 0.5) * 2.0
            frames[:, 0] = img

        # Matte each generated frame so the outputs are RGBA cutouts.
        result = []

        for frame in out_list:
            input_image = Image.fromarray(frame)
            output_image = remover.process(input_image, type='rgba')
            result.append(output_image)

        return result

    def process(self, white_image: Image.Image, args) -> List[Image.Image]:
        # Replicate the input image across the clip so every frame starts
        # from the same view.
        img = pil_to_cv2(white_image)
        frame_list = [img] * args["clip_size"]

        # Scale so the image covers the target resolution, then center-crop
        # to exactly the model's input size.
        h, w = frame_list[0].shape[0:2]
        rate = max(
            args["input_resolution"][0] * 1.0 / h, args["input_resolution"][1] * 1.0 / w
        )
        frame_list = [
            cv2.resize(f, [math.ceil(w * rate), math.ceil(h * rate)]) for f in frame_list
        ]
        frame_list = [
            center_crop(f, [args["input_resolution"][0], args["input_resolution"][1]])
            for f in frame_list
        ]
        frame_list = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frame_list]

        # Convert to float tensors in [-1, 1] and stack along the time axis,
        # giving a C x T x H x W clip on the GPU.
        frame_list = [to_tensor(f) for f in frame_list]
        frame_list = [(f - 0.5) * 2.0 for f in frame_list]
        frames = torch.stack(frame_list, 1)
        frames = frames.cuda()

        # Tell the model how many frames to sample and at what resolution.
        self.models["denoising_model"].num_samples = args["clip_size"]
        self.models["denoising_model"].image_size = args["input_resolution"]

        return self.video_pipeline(frames, args)

    def infer(self, white_image: Image.Image) -> List[Image.Image]:
        seed = random.randint(0, 65535)
        seed_everything(seed)

        # Default inference settings: a 25-frame clip at 512x512 in a single
        # sampling pass.
        params = {
            "clip_size": 25,
            "input_resolution": [512, 512],
            "num_iter": 1,
            "aes": 6.0,
            "mv": [0.0, 0.0, 0.0, 10.0],
            "elevation": 0,
        }

        return self.process(white_image, params)
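
# Example usage: a minimal driver sketch. The checkpoint and image paths
# below are illustrative placeholders, not files guaranteed to ship with
# this repo.
if __name__ == "__main__":
    generator = MultiViewGenerator(checkpoint_path="ckpts/model.ckpt")

    # Remove the background and pad the subject to a centered square.
    source = Image.open("input.png")
    white_image = prepare_white_image(source)

    # Generate the multi-view frames and save each RGBA output.
    views = generator.infer(white_image)
    for i, view in enumerate(views):
        view.save(f"view_{i:03d}.png")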