|
import logging |
|
from functools import cache |
|
|
|
import torch |
|
|
|
from ..denoiser.denoiser import Denoiser |
|
|
|
from ..inference import inference |
|
from .hparams import HParams |
|
|
|
# Per-module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
@cache
def load_denoiser(run_dir, device):
    """Build a ``Denoiser`` and return it in eval mode on ``device``.

    Results are memoized per ``(run_dir, device)`` pair by ``functools.cache``.
    Repeated calls with equal arguments therefore return the *same* module
    instance, so any in-place change a caller makes to it is seen by later
    callers.

    Args:
        run_dir: Training run directory, or ``None`` to build a denoiser from
            default ``HParams()`` without loading any checkpoint weights.
            When not ``None`` it must support the ``/`` operator (e.g. a
            ``pathlib.Path``). Hyperparameters come from
            ``HParams.load(run_dir)``, and weights come from
            ``run_dir/ds/G/default/mp_rank_00_model_states.pt``, under the
            ``"module"`` key.
        device: Target passed to ``denoiser.to``. It must be hashable because
            it is part of the cache key.

    Returns:
        The ``Denoiser``, after ``eval()`` and ``to(device)`` have been applied.
    """
    if run_dir is None:
        # No checkpoint to load. Fall through to the shared eval()/to(device)
        # tail so this branch returns a model in the same state as the
        # checkpoint branch (previously it skipped both calls).
        denoiser = Denoiser(HParams())
    else:
        hp = HParams.load(run_dir)
        denoiser = Denoiser(hp)
        # Path layout resembles a DeepSpeed checkpoint (rank-00 model states).
        path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
        # Load onto CPU first; the model is moved to `device` below.
        # NOTE(review): torch.load unpickles the file. Only point run_dir at
        # trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")["module"]
        denoiser.load_state_dict(state_dict)
    denoiser.eval()
    denoiser.to(device)
    return denoiser
|
|
|
|
|
@torch.inference_mode()
def denoise(dwav, sr, run_dir, device):
    """Run the denoiser for ``run_dir`` over ``dwav``.

    Executes under ``torch.inference_mode()``. The model comes from
    ``load_denoiser``, which memoizes it per ``(run_dir, device)``.

    Args:
        dwav: Input waveform, forwarded unchanged to ``inference``
            (presumably a tensor; confirm against ``inference``).
        sr: Sample rate of ``dwav``, forwarded to ``inference``.
        run_dir: Run directory for ``load_denoiser``. ``None`` builds a model
            from default hyperparameters without loading checkpoint weights.
        device: Device used both for loading the model and for ``inference``.

    Returns:
        Whatever ``inference`` returns for this model and input.
    """
    model = load_denoiser(run_dir, device)
    result = inference(model=model, dwav=dwav, sr=sr, device=device)
    return result
|
|