from typing import Iterable, Iterator, List, Sequence, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from tqdm import tqdm

from config import hparams as hp
from nota_wav2lip.models.util import count_params, load_model


class Wav2LipInferenceImpl:
    """Thin inference wrapper around a Wav2Lip-style generator.

    Loads the model, batches (mel chunk, face crop) pairs, and yields output frames
    with the synthesized mouth region pasted back in.
    """

    def __init__(self, model_name: str, hp_inference_model: DictConfig, device='cpu'):
        self.model: nn.Module = load_model(
            model_name,
            device=device,
            **hp_inference_model
        )
        self.device = device
        self._params: str = self._format_param(count_params(self.model))

    @property
    def params(self):
        return self._params

    @staticmethod
    def _format_param(num_params: int) -> str:
        params_in_million = num_params / 1e6
        return f"{params_in_million:.1f}M"

    @staticmethod
    def _reset_batch() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[List[int]]]:
        return [], [], [], []

    def get_data_iterator(
        self,
        audio_iterable: Iterable[np.ndarray],
        video_iterable: List[Tuple[np.ndarray, List[int]]]
    ) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]]:
        """Pair each mel chunk with a (frame, face-box) entry and yield model-ready batches.

        Frames are cycled when there are more mel chunks than frames. Each batch carries
        the normalized image/mel arrays plus the original frames and face boxes needed to
        paste the generated mouth region back into place.
        """
        img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()

        for i, m in enumerate(audio_iterable):
            idx = i % len(video_iterable)
            _frame_to_save, coords = video_iterable[idx]
            frame_to_save = _frame_to_save.copy()
            face = frame_to_save[coords[0]:coords[1], coords[2]:coords[3]].copy()

            face: np.ndarray = cv2.resize(face, (hp.face.img_size, hp.face.img_size))

            img_batch.append(face)
            mel_batch.append(m)
            frame_batch.append(frame_to_save)
            coords_batch.append(coords)

            if len(img_batch) >= hp.inference.batch_size:
                img_batch = np.asarray(img_batch)
                mel_batch = np.asarray(mel_batch)

                # Mask the lower half of each face so the model must synthesize the mouth region.
                img_masked = img_batch.copy()
                img_masked[:, hp.face.img_size // 2:] = 0

                # Stack masked and reference faces along the channel axis and scale to [0, 1];
                # give the mel spectrograms a trailing channel dimension.
                img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

                yield img_batch, mel_batch, frame_batch, coords_batch
                img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()

        if len(img_batch) > 0:
            # Flush the final, partially filled batch using the same preparation steps as above.
            img_batch = np.asarray(img_batch)
            mel_batch = np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, hp.face.img_size // 2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch

    @torch.no_grad()
    def inference_with_iterator(
        self,
        audio_iterable: Sequence[np.ndarray],
        video_iterable: List[Tuple[np.ndarray, List[int]]]
    ) -> Iterator[np.ndarray]:
        """Run lip-sync inference batch by batch and yield completed output frames.

        `audio_iterable` is hinted as `Sequence` because its length drives the tqdm
        progress bar below; a plain iterator would not support `len()`.
        """
        data_iterator = self.get_data_iterator(audio_iterable, video_iterable)

        for (img_batch, mel_batch, frames, coords) in \
            tqdm(data_iterator, total=int(np.ceil(float(len(audio_iterable)) / hp.inference.batch_size))):

            # NHWC numpy batches -> NCHW float tensors on the target device.
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device)

            preds: torch.Tensor = self.forward(mel_batch, img_batch)

            # Back to NHWC in the [0, 255] range, then resize each generated face to its box.
            preds = preds.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            for pred, frame, coord in zip(preds, frames, coords):
                y1, y2, x1, x2 = coord
                pred = cv2.resize(pred.astype(np.uint8), (x2 - x1, y2 - y1))

                frame[y1:y2, x1:x2] = pred
                yield frame

    @torch.no_grad()
    def forward(self, audio_sequences: torch.Tensor, face_sequences: torch.Tensor) -> torch.Tensor:
        """Run the generator on batched mel and masked-face tensors without tracking gradients."""
        return self.model(audio_sequences, face_sequences)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
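

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal sketch, assuming the project exposes a "wav2lip" entry under
# hp.inference.model and a matching checkpoint; the mel chunks and (frame, face-box)
# pairs below are dummy stand-ins for the outputs of the project's preprocessing.
if __name__ == "__main__":
    engine = Wav2LipInferenceImpl(
        model_name="wav2lip",                            # assumed model key
        hp_inference_model=hp.inference.model.wav2lip,   # assumed config node
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    print(f"Loaded generator with {engine.params} parameters")

    # Dummy inputs: 16 mel chunks of shape (80, 16) and 4 frames whose face box
    # spans the whole 256x256 frame; real inputs come from the preprocessing step.
    mel_chunks = [np.random.rand(80, 16).astype(np.float32) for _ in range(16)]
    face_frames = [
        (np.zeros((256, 256, 3), dtype=np.uint8), [0, 256, 0, 256])
        for _ in range(4)
    ]

    for out_frame in engine.inference_with_iterator(mel_chunks, face_frames):
        pass  # each out_frame is a full frame with the generated mouth pasted in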