import math

import torch
import safetensors.torch
from einops import rearrange
from transformers import Wav2Vec2BertModel, SeamlessM4TFeatureExtractor

from modules.audio_tokenizer.rep_codec import RepCodec

class AudioTokenizer:
    def __init__(self, **kwargs):
        self.device = kwargs.pop('device')

        # Normalization statistics (mean/var) for the semantic features.
        feat_stats = torch.load(kwargs.pop('feat_stats'), map_location='cpu')
        self.feat_mean = feat_stats['mean']
        self.feat_std = torch.sqrt(feat_stats['var'])

        # Semantic encoder: Wav2Vec2-BERT with its matching feature extractor.
        wav2vec_ckpt = kwargs.pop("wav2vec_ckpt")
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(wav2vec_ckpt)
        self.semantic_model.eval()
        self.semantic_model.to(self.device)
        self.semantic_processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

        # Semantic codec (RepCodec) that quantizes hidden states into tokens.
        self.semantic_codec = RepCodec()
        self.semantic_codec.eval()
        pretrained_path = kwargs.pop("semantic_codec_ckpt")
        safetensors.torch.load_model(self.semantic_codec, pretrained_path)
        self.semantic_codec.to(self.device)

        # Maximum number of feature frames per segment fed to the encoder.
        self.max_length = 2048

    @torch.no_grad()
    def tokenize(self, speech):
        """Convert a batch of 16 kHz waveforms into semantic tokens.

        Args:
            speech: torch tensor of shape [B, N_speech].
        Returns:
            semantic_token: torch tensor of shape [B, N].
        """
        inputs = self.semantic_processor(speech.cpu(), sampling_rate=16000, return_tensors="pt")
        input_features = inputs["input_features"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # Split long inputs into segments of at most self.max_length frames,
        # zero-padding the last segment so all segments have equal length.
        seg_num = math.ceil(input_features.shape[1] / self.max_length)
        pad_num = seg_num * self.max_length - input_features.shape[1]
        input_features = torch.nn.functional.pad(input_features, (0, 0, 0, pad_num, 0, 0), value=0)
        attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_num, 0, 0), value=0)
        input_features = rearrange(input_features, "b (s n) d -> (b s) n d", s=seg_num)
        attention_mask = rearrange(attention_mask, "b (s n) -> (b s) n", s=seg_num)

        feats = self.semantic_model(
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Use the hidden states of layer 17 as the semantic representation.
        feat = feats.hidden_states[17]

        # Stitch the segments back together and drop the padded frames.
        feat = rearrange(feat, "(b s) n d -> b (s n) d", s=seg_num)
        feat = feat[:, :feat.shape[1] - pad_num, :]

        # Normalize with the precomputed statistics, then quantize to tokens.
        feat = (feat - self.feat_mean.to(feat)) / self.feat_std.to(feat)
        semantic_token, _ = self.semantic_codec.quantize(feat)
        return semantic_token

def get_audio_tokenizer():
    config = {
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'feat_stats': 'resources/audio_tokenizer/stats.pt',
        'wav2vec_ckpt': 'facebook/w2v-bert-2.0',
        'semantic_codec_ckpt': 'resources/audio_tokenizer/model.safetensors',
    }
    audio_tokenizer = AudioTokenizer(**config)
    return audio_tokenizer
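

# Minimal usage sketch (not part of the original module): assumes a 16 kHz
# mono recording at 'example.wav' (an illustrative placeholder path) and that
# the checkpoints referenced in get_audio_tokenizer() exist under resources/.
if __name__ == "__main__":
    import librosa

    tokenizer = get_audio_tokenizer()
    # librosa.load resamples to 16 kHz mono; add a batch dim to get [B, N_speech].
    wav, _ = librosa.load("example.wav", sr=16000, mono=True)
    speech = torch.from_numpy(wav).unsqueeze(0)
    tokens = tokenizer.tokenize(speech)
    print(tokens.shape)  # expected: [1, N] semantic token ids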