|
import dataclasses |
|
import pathlib |
|
import libf0 |
|
import librosa |
|
import numpy as np |
|
import resampy |
|
import torch |
|
import torchcrepe |
|
import torchfcpe |
|
import os |
|
|
|
|
|
from rvc.lib.predictors.RMVPE import RMVPE0Predictor |
|
from rvc.configs.config import Config |
|
|
|
# Module-level configuration instance; F0Extractor reads `device` and `is_half` from it.
config = Config() |
|
|
|
|
|
@dataclasses.dataclass
class F0Extractor:
    """Load an audio file and estimate its fundamental frequency (F0).

    The estimator is chosen by ``method``: ``"crepe"`` (torchcrepe),
    ``"fcpe"`` (torchfcpe) or ``"rmvpe"`` (RMVPE0Predictor). The audio is
    loaded once, when the instance is created.
    """

    # Path of the audio file handed to ``librosa.load``.
    wav_path: pathlib.Path
    # Sample rate requested from ``librosa.load``; overwritten with the rate
    # librosa reports once loading has finished.
    sample_rate: int = 44100
    # Hop in samples; used by ``hop_size`` and to size the "fcpe" output.
    hop_length: int = 512
    # Frequency bounds forwarded to the "crepe" and "fcpe" back-ends.
    f0_min: int = 50
    f0_max: int = 1600
    # Name of the back-end: "crepe", "fcpe" or "rmvpe".
    method: str = "rmvpe"
    # Loaded audio samples; filled in by ``__post_init__``, not by the caller.
    x: np.ndarray = dataclasses.field(init=False)

    def __post_init__(self):
        """Read ``wav_path`` with librosa and keep the samples and the rate."""
        samples, rate = librosa.load(self.wav_path, sr=self.sample_rate)
        self.x = samples
        self.sample_rate = rate

    @property
    def hop_size(self) -> float:
        """Return ``hop_length / sample_rate``.

        Presumably the hop duration in seconds (samples over samples-per-second).
        """
        samples_per_hop = self.hop_length
        return samples_per_hop / self.sample_rate

    @property
    def wav16k(self) -> np.ndarray:
        """Return the loaded audio resampled from ``sample_rate`` to 16000.

        The resampling is performed again on every access; nothing is cached.
        """
        source_rate = self.sample_rate
        return resampy.resample(self.x, source_rate, 16000)

    def extract_f0(self) -> np.ndarray:
        """Estimate F0 with the configured back-end and return it in cents.

        Returns:
            The result of ``libf0.hz_to_cents`` applied to the back-end's
            output, with the frequency of MIDI note 0 as the reference.

        Raises:
            ValueError: If ``method`` is not "crepe", "fcpe" or "rmvpe".
        """
        chosen = self.method

        if chosen == "rmvpe":
            # RMVPE is fed the 16 kHz signal; weights come from a fixed
            # relative path.
            predictor = RMVPE0Predictor(
                os.path.join("rvc", "models", "predictors", "rmvpe.pt"),
                is_half=config.is_half,
                device=config.device,
            )
            pitch_hz = predictor.infer_from_audio(self.wav16k, thred=0.03)
        elif chosen == "crepe":
            # torchcrepe is fed the 16 kHz signal with a leading batch axis
            # and a fixed hop of 160 samples.
            batch = torch.FloatTensor(self.wav16k).unsqueeze(0)
            batch = batch.to(config.device)
            prediction = torchcrepe.predict(
                batch,
                sample_rate=16000,
                hop_length=160,
                batch_size=512,
                fmin=self.f0_min,
                fmax=self.f0_max,
                device=config.device,
            )
            pitch_hz = prediction[0].cpu().numpy()
        elif chosen == "fcpe":
            # FCPE is fed the signal at its loaded sample rate and asked to
            # interpolate its output to one value per hop_length samples
            # (plus one).
            mono = librosa.to_mono(self.x)
            frame_count = (len(mono) // self.hop_length) + 1
            model_input = torch.from_numpy(mono).float()
            model_input = model_input.unsqueeze(0).unsqueeze(-1)
            model_input = model_input.to(config.device)
            fcpe_model = torchfcpe.spawn_bundled_infer_model(device=config.device)
            prediction = fcpe_model.infer(
                model_input,
                sr=self.sample_rate,
                decoder_mode="local_argmax",
                threshold=0.006,
                f0_min=self.f0_min,
                f0_max=self.f0_max,
                interp_uv=False,
                output_interp_target_length=frame_count,
            )
            pitch_hz = prediction.squeeze().cpu().numpy()
        else:
            raise ValueError(f"Unknown method: {self.method}")

        return libf0.hz_to_cents(pitch_hz, librosa.midi_to_hz(0))

    def plot_f0(self, f0):
        """Show ``f0`` as a line plot titled with the current method name.

        matplotlib is imported here rather than at module level, so it is
        only needed when plotting is actually requested.
        """
        import matplotlib.pyplot as plt

        plt.figure(figsize=(10, 4))
        plt.plot(f0)
        plt.title(self.method)
        plt.xlabel("Time (frames)")
        plt.ylabel("F0 (cents)")
        plt.show()
|
|