# lipSync/nota_wav2lip/inference.py

from typing import Iterable, Iterator, List, Sequence, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from tqdm import tqdm

from config import hparams as hp
from nota_wav2lip.models.util import count_params, load_model


class Wav2LipInferenceImpl:
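    """Thin inference wrapper around a Wav2Lip lip-sync model.

    Loads the model via ``load_model``, batches (mel, face-crop) pairs, and
    pastes the predicted mouth regions back into the full video frames.
    """
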
def __init__(self, model_name: str, hp_inference_model: DictConfig, device='cpu'):
self.model: nn.Module = load_model(
model_name,
device=device,
**hp_inference_model
)
self.device = device
self._params: str = self._format_param(count_params(self.model))

    @property
def params(self):
return self._params

    @staticmethod
def _format_param(num_params: int) -> str:
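        # Express the parameter count in millions, e.g. 36_200_000 -> "36.2M".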
params_in_million = num_params / 1e6
return f"{params_in_million:.1f}M"

    @staticmethod
def _reset_batch() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[List[int]]]:
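        # Fresh (img, mel, frame, coords) accumulators for the next batch.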
return [], [], [], []

    def get_data_iterator(
self,
audio_iterable: Iterable[np.ndarray],
video_iterable: List[Tuple[np.ndarray, List[int]]]
) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]]:
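        """Yield model-ready batches assembled from mel chunks and face crops.

        Video frames are cycled so a short clip can cover a longer audio
        track. Each yielded batch contains, with B samples and
        S = hp.face.img_size:
          - img_batch: (B, S, S, 6) floats in [0, 1]; a lower-half-masked
            copy of each face concatenated channel-wise with the original,
            the input format Wav2Lip expects.
          - mel_batch: (B, mel_h, mel_w, 1) mel-spectrogram windows.
          - frame_batch / coords_batch: the full frames and face boxes,
            needed later to paste predictions back.
        """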
img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()
for i, m in enumerate(audio_iterable):
idx = i % len(video_iterable)
_frame_to_save, coords = video_iterable[idx]
frame_to_save = _frame_to_save.copy()
face = frame_to_save[coords[0]:coords[1], coords[2]:coords[3]].copy()
face: np.ndarray = cv2.resize(face, (hp.face.img_size, hp.face.img_size))
img_batch.append(face)
mel_batch.append(m)
frame_batch.append(frame_to_save)
coords_batch.append(coords)
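            # Emit a full batch once batch_size samples have accumulated.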
if len(img_batch) >= hp.inference.batch_size:
img_batch = np.asarray(img_batch)
mel_batch = np.asarray(mel_batch)
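                # Zero the lower half of each face so the model must in-paint the mouth region.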
img_masked = img_batch.copy()
img_masked[:, hp.face.img_size // 2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
yield img_batch, mel_batch, frame_batch, coords_batch
img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()
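
        # Flush whatever is left over as a final, smaller batch (same collation as above).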
if len(img_batch) > 0:
img_batch = np.asarray(img_batch)
mel_batch = np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, hp.face.img_size // 2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
yield img_batch, mel_batch, frame_batch, coords_batch

    @torch.no_grad()
def inference_with_iterator(
self,
        audio_iterable: Sequence[np.ndarray],
video_iterable: List[Tuple[np.ndarray, List[int]]]
) -> Iterator[np.ndarray]:
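        """Run lip-sync on batches from ``get_data_iterator`` and yield frames.

        ``audio_iterable`` must support ``len()``, which sizes the progress
        bar. Each predicted face is resized back to its original crop size
        and pasted into the source frame before the frame is yielded.
        """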
data_iterator = self.get_data_iterator(audio_iterable, video_iterable)
for (img_batch, mel_batch, frames, coords) in \
tqdm(data_iterator, total=int(np.ceil(float(len(audio_iterable)) / hp.inference.batch_size))):
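            # NHWC numpy batches -> NCHW float tensors on the target device.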
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device)
preds: torch.Tensor = self.forward(mel_batch, img_batch)
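            # Back to NHWC, rescaled to [0, 255] for OpenCV.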
preds = preds.cpu().numpy().transpose(0, 2, 3, 1) * 255.
for pred, frame, coord in zip(preds, frames, coords):
y1, y2, x1, x2 = coord
pred = cv2.resize(pred.astype(np.uint8), (x2 - x1, y2 - y1))
frame[y1:y2, x1:x2] = pred
yield frame

    @torch.no_grad()
def forward(self, audio_sequences: torch.Tensor, face_sequences: torch.Tensor) -> torch.Tensor:
return self.model(audio_sequences, face_sequences)

    def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
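

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes a
# registered model name "wav2lip", a matching DictConfig at
# hp.inference.model, and that `mel_chunks` / `frames_and_coords` come from
# the repository's audio and face-detection preprocessing; the exact config
# keys and helpers may differ in this codebase.
#
#     model = Wav2LipInferenceImpl(
#         "wav2lip",
#         hp_inference_model=hp.inference.model,
#         device="cuda" if torch.cuda.is_available() else "cpu",
#     )
#     print(f"Loaded Wav2Lip with {model.params} parameters")
#
#     # mel_chunks: List[np.ndarray], one mel window per output frame.
#     # frames_and_coords: List[Tuple[np.ndarray, List[int]]] of
#     #     (full frame, [y1, y2, x1, x2]) face boxes.
#     h, w = frames_and_coords[0][0].shape[:2]
#     writer = cv2.VideoWriter("result.avi", cv2.VideoWriter_fourcc(*"DIVX"), 25, (w, h))
#     for frame in model.inference_with_iterator(mel_chunks, frames_and_coords):
#         writer.write(frame)
#     writer.release()
# ---------------------------------------------------------------------------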