from typing import Dict, Type

import torch

from nota_wav2lip.models import NotaWav2Lip, Wav2Lip, Wav2LipBase

# Maps a model name to its implementing class.
MODEL_REGISTRY: Dict[str, Type[Wav2LipBase]] = {
    'wav2lip': Wav2Lip,
    'nota_wav2lip': NotaWav2Lip,
}


def _load(checkpoint_path, device):
    # Load a checkpoint onto the requested device. For CPU, map_location
    # remaps CUDA-saved tensors to CPU storage so the load succeeds.
    assert device in ['cpu', 'cuda']
    print(f"Load checkpoint from: {checkpoint_path}")
    if device == 'cuda':
        return torch.load(checkpoint_path)
    return torch.load(checkpoint_path, map_location=lambda storage, _: storage)


def load_model(model_name: str, device, checkpoint, **kwargs) -> Wav2LipBase:
    # Look up the model class by name, instantiate it, restore its weights,
    # and return it on the target device in eval mode.
    cls = MODEL_REGISTRY[model_name.lower()]
    assert issubclass(cls, Wav2LipBase)
    model = cls(**kwargs)
    state_dict = _load(checkpoint, device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    return model.eval()


def count_params(model):
    # Total number of parameters in the model.
    return sum(p.numel() for p in model.parameters())
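

# A minimal usage sketch, not part of the module itself: the checkpoint path
# below is a hypothetical placeholder, not a file shipped with the repo.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # 'checkpoints/wav2lip.pth' is an assumed path; substitute your own weights.
    model = load_model('wav2lip', device=device, checkpoint='checkpoints/wav2lip.pth')
    print(f"Loaded {type(model).__name__} with {count_params(model):,} parameters")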