File size: 6,735 Bytes
261ef4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import torch
import av
import pims

import numpy as np

from typing import Optional, Tuple
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm


class VideoReader(Dataset):
    """Dataset that serves the frames of a video file one index at a time.

    The file is opened with ``pims.PyAVVideoReader``. Each frame is converted
    to a PIL image and, when a ``transform`` was supplied, passed through it
    before being returned.

    Args:
        path: Location of the video, forwarded to ``pims.PyAVVideoReader``.
        transform: Optional callable applied to every PIL frame.
    """

    def __init__(self, path, transform=None):
        reader = pims.PyAVVideoReader(path)
        self.video = reader
        # Cache the frame rate reported by the reader at open time.
        self.rate = reader.frame_rate
        self.transform = transform

    @property
    def frame_rate(self):
        """Frame rate captured from the underlying reader at construction."""
        return self.rate

    def __len__(self):
        """Number of frames, as reported by the underlying reader."""
        return len(self.video)

    def __getitem__(self, idx):
        """Return frame ``idx`` as a PIL image, or as the transform's output."""
        image = Image.fromarray(np.asarray(self.video[idx]))
        if self.transform is None:
            return image
        return self.transform(image)


class VideoWriter:
    """Encode batches of frame tensors into an H.264 video file through PyAV.

    Args:
        path: Output file path, opened with ``av.open`` in write mode.
        frame_rate: Frame rate of the output stream; it is formatted with
            four decimals before being handed to ``add_stream``.
        bit_rate: Bit rate assigned to the stream, in bits per second.
    """

    def __init__(self, path, frame_rate, bit_rate=1000000):
        self.container = av.open(path, mode="w")
        self.stream = self.container.add_stream("h264", rate=f"{frame_rate:.4f}")
        self.stream.pix_fmt = "yuv420p"
        self.stream.bit_rate = bit_rate

    @staticmethod
    def _to_rgb24(frames):
        """Convert a ``[T, C, H, W]`` tensor into a ``[T, H, W, 3]`` uint8 array.

        Values are clamped to ``[0, 1]`` before scaling to ``[0, 255]`` so that
        slightly out-of-range inputs saturate instead of wrapping around in the
        uint8 conversion. Single-channel input is replicated to three channels.
        """
        frames = frames.clamp(0, 1)
        if frames.size(1) == 1:
            frames = frames.repeat(1, 3, 1, 1)  # convert grayscale to RGB
        return frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()

    def write(self, frames):
        """Encode and mux a batch of frames.

        Args:
            frames: Tensor laid out as ``[T, C, H, W]`` with ``C`` equal to 1
                or 3. Values are expected in ``[0, 1]``; anything outside that
                range is clamped.
        """
        # The stream dimensions are taken from the frames themselves and are
        # re-assigned on every call.
        self.stream.width = frames.size(3)
        self.stream.height = frames.size(2)
        for image in self._to_rgb24(frames):
            frame = av.VideoFrame.from_ndarray(image, format="rgb24")
            self.container.mux(self.stream.encode(frame))

    def close(self):
        """Flush the encoder and close the output container."""
        # Calling encode() with no frame drains the packets still buffered
        # in the encoder.
        self.container.mux(self.stream.encode())
        self.container.close()


def auto_downsample_ratio(h, w, target_size=512):
    """
    Automatically find a downsample ratio so that the largest side of the resolution
    becomes ``target_size`` pixels. Inputs that are already small enough are never
    upsampled: the ratio is capped at 1.

    Args:
        h: Frame height in pixels.
        w: Frame width in pixels.
        target_size: Desired length of the largest side after downsampling (default 512).

    Returns:
        ``target_size / max(h, w)`` when that is below 1, otherwise 1.

    Raises:
        ValueError: If the largest side or ``target_size`` is not positive.
    """
    longest_side = max(h, w)
    if longest_side <= 0 or target_size <= 0:
        raise ValueError(
            f"Frame size and target size must be positive, got h={h}, w={w}, "
            f"target_size={target_size}."
        )
    return min(target_size / longest_side, 1)


def convert_video(
    model,
    input_source: str,
    input_resize: Optional[Tuple[int, int]] = None,
    downsample_ratio: Optional[float] = None,
    output_composition: Optional[str] = None,
    output_alpha: Optional[str] = None,
    output_foreground: Optional[str] = None,
    output_video_mbps: Optional[float] = None,
    seq_chunk: int = 1,
    num_workers: int = 0,
    progress: bool = True,
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
):
    """
    Run a matting model over a video and write the requested outputs as video files.

    Args:
        model: Called as ``model(src, *rec, downsample_ratio)`` on a ``[B, T, C, H, W]`` batch
            and expected to return ``(fgr, pha, *rec)``; the recurrent states ``rec`` are fed
            back on the next call.
        input_source: Path of the input video file, opened through ``VideoReader``.
        input_resize: If provided, the input frames are first resized to (w, h).
        downsample_ratio: The model's downsample_ratio hyperparameter. If not provided, it is
            derived once from the frame size of the first batch via ``auto_downsample_ratio``.
        output_composition: Output video path for the composition, i.e. the foreground blended
            over a black background using the predicted alpha.
        output_alpha: Output video path for the alpha predicted by the model.
        output_foreground: Output video path for the foreground predicted by the model.
        output_video_mbps: Bit rate of the output videos in Mbps. Defaults to 1.
        seq_chunk: Number of frames to process at once. Increase it for better parallelism.
        num_workers: PyTorch's DataLoader workers.
            NOTE(review): the original guidance was to use >0 only for image input; the reader
            here is always a video reader, so confirm it is worker-safe before raising this.
        progress: Show progress bar.
        device: Device the frames are moved to. Taken from the model's first parameter when
            omitted; provide it manually if the model is a frozen TorchScript model.
        dtype: Dtype the frames are cast to. Taken from the model's first parameter when
            omitted; provide it manually if the model is a frozen TorchScript model.

    Raises:
        AssertionError: If ``downsample_ratio`` is outside (0, 1], no output path is given,
            ``seq_chunk`` < 1 or ``num_workers`` < 0.
    """
    # Local import: keeps this function self-contained without touching the
    # file-level import block.
    from contextlib import ExitStack

    assert downsample_ratio is None or (
        downsample_ratio > 0 and downsample_ratio <= 1
    ), "Downsample ratio must be between 0 (exclusive) and 1 (inclusive)."
    assert any(
        [output_composition, output_alpha, output_foreground]
    ), "Must provide at least one output."
    assert seq_chunk >= 1, "Sequence chunk must be >= 1"
    assert num_workers >= 0, "Number of workers must be >= 0"

    # Initialize transform
    if input_resize is not None:
        # Resize expects (h, w) while input_resize is given as (w, h).
        transform = transforms.Compose(
            [transforms.Resize(input_resize[::-1]), transforms.ToTensor()]
        )
    else:
        transform = transforms.ToTensor()

    # Initialize reader
    source = VideoReader(input_source, transform)
    reader = DataLoader(
        source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers
    )

    # Resolve device / dtype before any output file is created. Only the
    # missing value is filled in, so an explicitly passed device or dtype is
    # never overwritten by the model's.
    model = model.eval()
    if device is None or dtype is None:
        param = next(model.parameters())
        if dtype is None:
            dtype = param.dtype
        if device is None:
            device = param.device

    frame_rate = source.frame_rate
    output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
    bit_rate = int(output_video_mbps * 1000000)

    # ExitStack runs every registered close() even when one of them, or the
    # inference loop itself, raises.
    with ExitStack() as cleanup:

        def open_writer(path):
            """Open a writer for ``path`` (or return None) and schedule its close."""
            if path is None:
                return None
            writer = VideoWriter(path=path, frame_rate=frame_rate, bit_rate=bit_rate)
            cleanup.callback(writer.close)
            return writer

        # Initialize writers
        writer_com = open_writer(output_composition)
        writer_pha = open_writer(output_alpha)
        writer_fgr = open_writer(output_foreground)

        if writer_com is not None:
            # Black background, shaped to broadcast over [B, T, C, H, W].
            bgr = (
                torch.tensor([0, 0, 0], device=device, dtype=dtype)
                .div(255)
                .view(1, 1, 3, 1, 1)
            )

        bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
        cleanup.callback(bar.close)

        # Inference
        with torch.no_grad():
            rec = [None] * 4
            for src in reader:
                if downsample_ratio is None:
                    # src is [T, C, H, W] here, so shape[2:] is (H, W).
                    downsample_ratio = auto_downsample_ratio(*src.shape[2:])

                src = src.to(device, dtype, non_blocking=True).unsqueeze(
                    0
                )  # [B, T, C, H, W]
                fgr, pha, *rec = model(src, *rec, downsample_ratio)

                if writer_fgr is not None:
                    writer_fgr.write(fgr[0])
                if writer_pha is not None:
                    writer_pha.write(pha[0])
                if writer_com is not None:
                    com = fgr * pha + bgr * (1 - pha)
                    writer_com.write(com[0])

                bar.update(src.size(1))