# Spaces: Running on Zero
import logging
import math

import accelerate
import librosa
import safetensors
import safetensors.torch
import soundfile as sf
import torch
import yaml
from einops import rearrange
from transformers import Wav2Vec2BertModel, SeamlessM4TFeatureExtractor

from modules.audio_tokenizer.rep_codec import RepCodec
class AudioTokenizer(object):
    """Convert speech waveforms into discrete semantic tokens.

    Pipeline (see ``tokenize``): SeamlessM4T feature extraction ->
    Wav2Vec2-BERT encoder (one selected hidden layer) -> normalisation with
    precomputed mean/variance statistics -> ``RepCodec.quantize``.

    Required keyword arguments:
        device: torch device the encoder and codec are moved to.
        feat_stats: path to a ``torch.save``-d dict with ``'mean'`` and
            ``'var'`` entries used to normalise encoder features.
        wav2vec_ckpt: name or path passed to
            ``Wav2Vec2BertModel.from_pretrained``.
        semantic_codec_ckpt: path to the RepCodec safetensors weights.

    Optional keyword arguments:
        processor_ckpt: name or path for
            ``SeamlessM4TFeatureExtractor.from_pretrained``
            (default ``"facebook/w2v-bert-2.0"``).
        max_length: number of feature frames per encoder segment
            (default 2048); must be positive.
        output_layer: index into the encoder's ``hidden_states`` whose
            output is quantised (default 17).

    Any other keyword arguments are ignored.
    """

    DEFAULT_PROCESSOR_CKPT = "facebook/w2v-bert-2.0"
    DEFAULT_MAX_LENGTH = 2048
    DEFAULT_OUTPUT_LAYER = 17

    def __init__(self, **kwargs):
        self.device = kwargs.pop('device')
        logging.getLogger(__name__).info("AudioTokenizer device: %s", self.device)

        # Normalisation statistics stay on CPU here; tokenize() moves them to
        # the features' device/dtype at use time.
        feat_stats = torch.load(kwargs.pop('feat_stats'), map_location='cpu')
        self.feat_mean = feat_stats['mean']
        self.feat_std = torch.sqrt(feat_stats['var'])

        # Semantic encoder and its matching feature extractor.
        wav2vec_ckpt = kwargs.pop("wav2vec_ckpt")
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(wav2vec_ckpt)
        self.semantic_model.eval()
        self.semantic_model.to(self.device)
        self.semantic_processor = SeamlessM4TFeatureExtractor.from_pretrained(
            kwargs.pop("processor_ckpt", self.DEFAULT_PROCESSOR_CKPT)
        )

        # Quantiser turning normalised encoder features into tokens.
        self.semantic_codec = RepCodec()
        self.semantic_codec.eval()
        pretrained_path = kwargs.pop("semantic_codec_ckpt")
        safetensors.torch.load_model(self.semantic_codec, pretrained_path)
        self.semantic_codec.to(self.device)

        self.max_length = int(kwargs.pop("max_length", self.DEFAULT_MAX_LENGTH))
        if self.max_length < 1:
            raise ValueError(f"max_length must be positive, got {self.max_length}")
        self.output_layer = int(kwargs.pop("output_layer", self.DEFAULT_OUTPUT_LAYER))

    # Inference only: without no_grad, autograd would record the entire
    # encoder forward pass (its parameters require grad by default).
    @torch.no_grad()
    def tokenize(self, speech):
        """Quantise a batch of waveforms into semantic tokens.

        Args:
            speech: torch tensor, shape [B, N_speech]; handed to the feature
                extractor with ``sampling_rate=16000``.

        Returns:
            The first value returned by ``self.semantic_codec.quantize``
            (semantic tokens, shape [B, N]).

        Raises:
            ValueError: if feature extraction yields zero frames.
        """
        inputs = self.semantic_processor(speech.cpu(), sampling_rate=16000, return_tensors="pt")
        input_features = inputs["input_features"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        batch_size, num_frames, feat_dim = input_features.shape
        if num_frames == 0:
            raise ValueError("speech is too short to produce any feature frames")

        # Right-pad the frame axis to a multiple of max_length so it can be
        # cut into seg_num equal segments that are encoded as batch rows.
        seg_len = self.max_length
        seg_num = math.ceil(num_frames / seg_len)
        pad_num = seg_num * seg_len - num_frames
        input_features = torch.nn.functional.pad(input_features, (0, 0, 0, pad_num), value=0)
        attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_num), value=0)

        # "b (s n) d -> (b s) n d": the segments of one utterance become
        # consecutive rows of the encoder batch.
        input_features = input_features.reshape(batch_size * seg_num, seg_len, feat_dim)
        attention_mask = attention_mask.reshape(batch_size * seg_num, seg_len)

        outputs = self.semantic_model(
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        feat = outputs.hidden_states[self.output_layer]

        # "(b s) n d -> b (s n) d", then drop the frames produced by padding.
        feat = feat.reshape(batch_size, -1, feat.shape[-1])
        feat = feat[:, :feat.shape[1] - pad_num, :]

        feat = (feat - self.feat_mean.to(feat)) / self.feat_std.to(feat)
        semantic_token, _ = self.semantic_codec.quantize(feat)
        return semantic_token
def get_audio_tokenizer():
    """Construct an AudioTokenizer wired to the project's bundled resources.

    Selects ``'cuda'`` when ``torch.cuda.is_available()`` is true, else
    ``'cpu'``. Statistics and codec weights come from the relative paths under
    ``resources/audio_tokenizer/``; the encoder checkpoint is
    ``facebook/w2v-bert-2.0``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return AudioTokenizer(
        device=device,
        feat_stats='resources/audio_tokenizer/stats.pt',
        wav2vec_ckpt='facebook/w2v-bert-2.0',
        semantic_codec_ckpt='resources/audio_tokenizer/model.safetensors',
    )