import torch
import librosa
import yaml
from transformers import Wav2Vec2BertModel, SeamlessM4TFeatureExtractor
import safetensors.torch  # load_model lives in the torch submodule, which a bare "import safetensors" does not load
import accelerate
import soundfile as sf
import math
from einops import rearrange
from modules.audio_tokenizer.rep_codec import RepCodec


class AudioTokenizer(object):
    def __init__(self, **kwargs):
        self.device = kwargs.pop('device')
        print(self.device)
        # Feature statistics (mean and variance) used to normalize the
        # semantic features before quantization.
        feat_stats = kwargs.pop('feat_stats')
        feat_stats = torch.load(feat_stats, map_location='cpu')
        self.feat_mean = feat_stats['mean']
        self.feat_std = torch.sqrt(feat_stats['var'])
        wav2vec_ckpt = kwargs.pop("wav2vec_ckpt")
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(wav2vec_ckpt)
        self.semantic_model.eval()
        self.semantic_model.to(self.device)
        self.semantic_processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

        self.semantic_codec = RepCodec()
        self.semantic_codec.eval()
        pretrained_path = kwargs.pop("semantic_codec_ckpt") 
        safetensors.torch.load_model(self.semantic_codec, pretrained_path)
        self.semantic_codec.to(self.device)

        # Maximum number of feature frames per segment fed to the semantic model.
        self.max_length = 2048

    @torch.no_grad()
    def tokenize(self, speech):
        # Input:
        #   speech: torch tensor of waveforms, shape [B, N_speech], sampled at 16 kHz
        # Output:
        #   semantic_token: torch tensor of semantic tokens, shape [B, N]

        inputs = self.semantic_processor(speech.cpu(), sampling_rate=16000, return_tensors="pt")
        input_features = inputs["input_features"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        # Pad the frame axis to a multiple of max_length, then fold the segments
        # into the batch axis so each forward pass sees at most max_length frames.
        seg_num = math.ceil(input_features.shape[1] / self.max_length)
        pad_num = seg_num * self.max_length - input_features.shape[1]
        input_features = torch.nn.functional.pad(input_features, (0, 0, 0, pad_num, 0, 0), value=0)
        attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_num, 0, 0), value=0)
        input_features = rearrange(input_features, "b (s n) d -> (b s) n d", s=seg_num)
        attention_mask = rearrange(attention_mask, "b (s n) -> (b s) n", s=seg_num)

        feats = self.semantic_model(
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
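        # Use the hidden states at index 17 as the semantic representation.
        feat = feats.hidden_states[17]
        # Unfold the segments back onto the frame axis and drop the padded frames.
        feat = rearrange(feat, "(b s) n d -> b (s n) d", s=seg_num)
        feat = feat[:, :feat.shape[1] - pad_num, :]
        # Normalize with the precomputed statistics, then quantize into tokens.
        feat = (feat - self.feat_mean.to(feat)) / self.feat_std.to(feat)
        semantic_token, _ = self.semantic_codec.quantize(feat)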
        return semantic_token

def get_audio_tokenizer():
    config = dict()
    config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
    config['feat_stats'] = 'resources/audio_tokenizer/stats.pt'
    config['wav2vec_ckpt'] = 'facebook/w2v-bert-2.0'
    config['semantic_codec_ckpt'] = 'resources/audio_tokenizer/model.safetensors'
    audio_tokenizer = AudioTokenizer(**config)
    return audio_tokenizer
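

# A minimal sketch, not part of the original module: it isolates the segment
# arithmetic that AudioTokenizer.tokenize applies before folding segments into
# the batch axis. The helper name `_segment_plan` is hypothetical and exists
# only for illustration; the default of 2048 mirrors self.max_length above.
import math


def _segment_plan(num_frames, max_length=2048):
    # Number of fixed-size segments needed to cover all frames.
    seg_num = math.ceil(num_frames / max_length)
    # Zero frames appended so the frame axis divides evenly into segments.
    pad_num = seg_num * max_length - num_frames
    return seg_num, pad_num


# Usage: 5000 frames need 3 segments of 2048 frames, so 1144 frames are padded
# before the model runs and trimmed again afterwards.
# seg_num, pad_num = _segment_plan(5000)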