Spaces:
Runtime error
Runtime error
"""Vocoder wrapper. | |
Copyright PolyAI Limited. | |
""" | |
import enum | |
import numpy as np | |
import soundfile as sf | |
import torch | |
import torch.nn as nn | |
from speechtokenizer import SpeechTokenizer | |
class VocoderType(enum.Enum):
    """Registry of the vocoder back-ends this module can build.

    Every member's value is a ``(name, compression_ratio)`` tuple; ``Enum``
    unpacks that tuple into ``__init__``, so each member exposes a
    ``compression_ratio`` attribute next to the usual ``name`` / ``value``.
    """

    SPEECHTOKENIZER = ("SPEECHTOKENIZER", 320)

    def __init__(self, name, compression_ratio):
        # NOTE(review): compression_ratio is presumably the number of audio
        # samples per code frame -- confirm against the tokenizer config.
        self.compression_ratio = compression_ratio
        self._name_ = name

    def get_vocoder(self, ckpt_path, config_path, **kwargs):
        """Instantiate the vocoder wrapper that corresponds to this member.

        Args:
            ckpt_path: Checkpoint path. When falsy, the wrapper is built
                with its own default checkpoint *and* config paths.
            config_path: Config path; only forwarded when ``ckpt_path`` is
                truthy.
            **kwargs: Accepted for call-site compatibility; not used here.

        Returns:
            An ``STWrapper`` instance.

        Raises:
            ValueError: If this member is not ``SPEECHTOKENIZER``.
        """
        if self.name != "SPEECHTOKENIZER":
            raise ValueError(f"Unknown vocoder type {self.name}")

        # No checkpoint given: fall back to the wrapper defaults for both
        # paths (config_path is deliberately not forwarded in this case).
        if not ckpt_path:
            return STWrapper()
        return STWrapper(ckpt_path, config_path)
class STWrapper(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained ``SpeechTokenizer``.

    Tensors handed to :meth:`encode` / :meth:`decode` are moved to the
    wrapped model's device for the call, and the result is moved back to the
    device the input came from.
    """

    def __init__(
        self,
        ckpt_path: str = './ckpt/speechtokenizer/SpeechTokenizer.pt',
        config_path: str = './ckpt/speechtokenizer/config.json',
    ):
        """Load the SpeechTokenizer model from disk.

        Args:
            ckpt_path: Path to the model checkpoint.
            config_path: Path to the model's JSON config.
        """
        super().__init__()
        # load_from_checkpoint takes the config path first, then the
        # checkpoint path -- the reverse of this constructor's argument order.
        self.model = SpeechTokenizer.load_from_checkpoint(
            config_path, ckpt_path)

    @property
    def device(self) -> torch.device:
        """Device of the wrapped model's first parameter.

        Exposed as a property because ``encode`` and ``decode`` read it as
        ``self.device``; as a plain method, ``tensor.to(self.device)`` would
        receive a bound method and fail.
        """
        return next(self.model.parameters()).device

    def eval(self):
        """Put the wrapped model in evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.eval`` so calls can be chained.
        """
        self.model.eval()
        return self

    def decode(self, codes: torch.Tensor, verbose: bool = False):
        """Decode codes to audio with the wrapped model.

        Args:
            codes: Code tensor accepted by ``SpeechTokenizer.decode``.
            verbose: Currently unused.

        Returns:
            The decoded audio tensor, on the device ``codes`` came from.
        """
        original_device = codes.device
        codes = codes.to(self.device)
        audio_array = self.model.decode(codes)
        return audio_array.to(original_device)

    def decode_to_file(self, codes_path, out_path) -> None:
        """Decode codes stored in a ``.npy`` file and write them as audio.

        Args:
            codes_path: Path of the NumPy file holding the codes.
            out_path: Destination audio file, written with ``soundfile`` at
                the wrapped model's ``sample_rate``.
        """
        codes = np.load(codes_path)
        codes = torch.from_numpy(codes)
        wav = self.decode(codes).cpu().numpy()
        sf.write(out_path, wav, samplerate=self.model.sample_rate)

    def encode(self, wav, verbose=False, n_quantizers: int = None):
        """Encode a waveform tensor into codes with the wrapped model.

        Args:
            wav: Waveform tensor accepted by ``SpeechTokenizer.encode``.
            verbose: Currently unused.
            n_quantizers: Currently unused; the model's default is applied.

        Returns:
            The code tensor, on the device ``wav`` came from.
        """
        original_device = wav.device
        wav = wav.to(self.device)
        codes = self.model.encode(wav)  # codes: (n_q, B, T)
        return codes.to(original_device)

    def encode_to_file(self, wav_path, out_path) -> None:
        """Encode an audio file and save the codes as a NumPy file.

        Args:
            wav_path: Audio file, read with ``soundfile`` as float32. The
                file's own sample rate is discarded.
            out_path: Destination passed to ``np.save``.
        """
        wav, _ = sf.read(wav_path, dtype='float32')
        # Two leading singleton dims are added.
        # NOTE(review): this looks like it assumes mono input, giving
        # (1, 1, T) -- confirm; multi-channel files gain an extra axis.
        wav = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0)
        codes = self.encode(wav).cpu().numpy()
        np.save(out_path, codes)

    def remove_weight_norm(self):
        """No-op, kept so callers can invoke it on this wrapper."""
        pass