arnavmehta7 committed on
Commit
5840a7d
1 Parent(s): 7fc5566

Update inference.py

Files changed (1)
  1. inference.py +77 -16
inference.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import logging
 import json
-from typing import Optional
+from typing import Optional, Dict, Type, Union, List, Tuple
 from pathlib import Path
 from dataclasses import dataclass
 import os
@@ -18,6 +18,8 @@ from mars5.minbpe.codebook import CodebookTokenizer
 from mars5.ar_generate import ar_generate
 from mars5.utils import nuke_weight_norm
 from mars5.trim import trim
+from huggingface_hub import ModelHubMixin, hf_hub_download
+from safetensors import safe_open
 import tempfile
 import logging
 
@@ -64,9 +66,7 @@ class InferenceConfig():
     beam_width: int = 1 # only beam width of 1 is currently supported
     ref_audio_pad: float = 0
 
-
-class Mars5TTS(nn.Module):
-
+class Mars5TTS(nn.Module, ModelHubMixin):
     def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
         super().__init__()
 
@@ -79,16 +79,16 @@ class Mars5TTS(nn.Module):
 
         # save and load text tokenizer
        self.texttok = RegexTokenizer(GPT4_SPLIT_PATTERN)
-        tfn = tempfile.mkstemp(suffix='texttok.model')[1]
-        Path(tfn).write_text(ar_ckpt['vocab']['texttok.model'])
-        self.texttok.load(tfn)
-        os.remove(tfn)
+        texttok_data = io.BytesIO(ar_ckpt['vocab']['texttok.model'].encode('utf-8'))
+        self.texttok.load(texttok_data)
+        texttok_data.close()
+
         # save and load speech tokenizer
-        sfn = tempfile.mkstemp(suffix='speechtok.model')[1]
         self.speechtok = CodebookTokenizer(GPT4_SPLIT_PATTERN)
-        Path(sfn).write_text(ar_ckpt['vocab']['speechtok.model'])
-        self.speechtok.load(sfn)
-        os.remove(sfn)
+        speechtok_data = io.BytesIO(ar_ckpt['vocab']['speechtok.model'].encode('utf-8'))
+        self.speechtok.load(speechtok_data)
+        speechtok_data.close()
+
         # keep track of tokenization things.
         self.n_vocab = len(self.texttok.vocab) + len(self.speechtok.vocab)
         self.n_text_vocab = len(self.texttok.vocab) + 1
@@ -111,7 +111,42 @@ class Mars5TTS(nn.Module):
         self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device).eval()
         nuke_weight_norm(self.codec)
         nuke_weight_norm(self.vocos)
-
+
+    @classmethod
+    def _from_pretrained(
+        cls: Type["Mars5TTS"],
+        *,
+        model_id: str,
+        revision: Optional[str],
+        cache_dir: Optional[Union[str, Path]],
+        force_download: bool,
+        proxies: Optional[Dict],
+        local_files_only: bool,
+        token: Optional[Union[str, bool]],
+        device: str = None,
+        **model_kwargs,
+    ) -> "Mars5TTS":
+        # Download files from Hub
+        ar_ckpt_path = hf_hub_download(repo_id=model_id, filename="mars5_ar.safetensors", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token)
+        nar_ckpt_path = hf_hub_download(repo_id=model_id, filename="mars5_nar.safetensors", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token)
+
+        ar_ckpt = {}
+        with safe_open(ar_ckpt_path, framework='pt', device='cpu') as f:
+            metadata = f.metadata()
+            ar_ckpt['vocab'] = {'texttok.model': metadata['texttok.model'], 'speechtok.model': metadata['speechtok.model']}
+            ar_ckpt['model'] = {}
+            for k in f.keys(): ar_ckpt['model'][k] = f.get_tensor(k)
+        nar_ckpt = {}
+        with safe_open(nar_ckpt_path, framework='pt', device='cpu') as f:
+            metadata = f.metadata()
+            nar_ckpt['vocab'] = {'texttok.model': metadata['texttok.model'], 'speechtok.model': metadata['speechtok.model']}
+            nar_ckpt['model'] = {}
+            for k in f.keys(): nar_ckpt['model'][k] = f.get_tensor(k)
+
+
+        # Init
+        return cls(ar_ckpt=ar_ckpt, nar_ckpt=nar_ckpt, device=device)
+
     @torch.inference_mode
     def vocode(self, tokens: Tensor) -> Tensor:
         """ Vocodes tokens of shape (seq_len, n_q) """
@@ -126,6 +161,33 @@ class Mars5TTS(nn.Module):
         wav_diffusion = self.vocos.decode(features, bandwidth_id=bandwidth_id)
         return wav_diffusion.cpu().squeeze()[None]
 
+    @torch.inference_mode
+    def get_speaker_embedding(self, ref_audio: Tensor) -> Tensor:
+        """ Given `ref_audio` (bs, T) audio tensor, compute the implicit speaker embedding of shape (bs, dim). """
+        if ref_audio.dim() == 1: ref_audio = ref_audio[None]
+        spk_reference = self.codec.encode(ref_audio[None].to(self.device))[0][0]
+        spk_reference = spk_reference.permute(0, 2, 1)
+        bs = spk_reference.shape[0]
+        if bs != 1:
+            raise AssertionError(f"Speaker embedding extraction only implemented for bs=1 currently.")
+        spk_seq = self.codeclm.ref_chunked_emb(spk_reference) # (bs, sl, dim)
+        spk_ref_emb = self.codeclm.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim)
+
+        spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim)
+        # add pos encoding
+        spk_seq = self.codeclm.pos_embedding(spk_seq)
+        # codebook goes from indices 0->1023, padding is idx 1024 (the 1025th entry)
+        src_key_padding_mask = construct_padding_mask(spk_reference[:, :, 0], 1024)
+        src_key_padding_mask = torch.cat((
+            # append a zero here since we DO want to attend to initial position.
+            torch.zeros(src_key_padding_mask.shape[0], 1, dtype=bool, device=src_key_padding_mask.device),
+            src_key_padding_mask
+            ),
+            dim=1)
+        # pass through transformer
+        res = self.codeclm.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim).
+        return res.squeeze(1)
+
     @torch.inference_mode
     def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
             cfg: Optional[InferenceConfig] = InferenceConfig()) -> Tensor:
@@ -183,12 +245,12 @@ class Mars5TTS(nn.Module):
         first_codec_idx = prompt.shape[-1] - n_speech_inp + 1
 
         # ---> perform AR code generation
-
         logging.debug(f"Raw acoustic prompt length: {raw_prompt_acoustic_len}")
 
         ar_codes = ar_generate(self.texttok, self.speechtok, self.codeclm,
                                prompt, spk_ref_codec, first_codec_idx,
-                               max_len=cfg.generate_max_len_override if cfg.generate_max_len_override > 1 else 2000,
+                               max_len=cfg.generate_max_len_override if cfg.generate_max_len_override > 1 else 2000,
+                               fp16=True if torch.cuda.is_available() else False,
                                temperature=cfg.temperature, topk=cfg.top_k, top_p=cfg.top_p, typical_p=cfg.typical_p,
                                alpha_frequency=cfg.freq_penalty, alpha_presence=cfg.presence_penalty, penalty_window=cfg.rep_penalty_window,
                                eos_penalty_decay=cfg.eos_penalty_decay, eos_penalty_factor=cfg.eos_penalty_factor,
@@ -211,7 +273,6 @@ class Mars5TTS(nn.Module):
         x_padding_mask = torch.zeros((1, _x.shape[1]), dtype=torch.bool, device=_x.device)
 
         # ---> perform DDPM NAR inference
-
         T = self.default_T
         diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=T, device=self.device)
 
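With ModelHubMixin added to the base classes, checkpoints can be fetched directly from the Hugging Face Hub: the inherited from_pretrained resolves the repo and forwards its arguments to the _from_pretrained classmethod above. A minimal usage sketch, not part of this commit; the repo id below is an assumption and should be whichever repo actually hosts mars5_ar.safetensors and mars5_nar.safetensors.

    import torch
    from inference import Mars5TTS

    # Downloads both safetensors checkpoints and constructs the model.
    # Extra kwargs such as `device` are passed through to _from_pretrained.
    mars5 = Mars5TTS.from_pretrained(
        "CAMB-AI/MARS5-TTS",  # assumed repo id
        device="cuda" if torch.cuda.is_available() else "cpu",
    )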
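The new get_speaker_embedding takes a mono waveform, encodes it with the codec, and pools it through the speaker encoder into a single vector per utterance. A sketch of calling it, assuming the EnCodec codec's 24 kHz sample rate and a hypothetical reference.wav; mars5 is the model from the previous sketch.

    import torchaudio

    wav, sr = torchaudio.load("reference.wav")            # (channels, T), hypothetical file
    wav = torchaudio.functional.resample(wav, sr, 24000)  # 24 kHz assumed for EnCodec
    wav = wav.mean(dim=0)                                 # mono, shape (T,)

    emb = mars5.get_speaker_embedding(wav)                # (1, dim)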
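Two notes on the new loading path. First, the rewritten tokenizer setup calls io.BytesIO, so it relies on `import io` being present at module level (not visible in these hunks) and on the tokenizers' load() accepting a file-like object rather than a path. Second, _from_pretrained expects the tokenizer vocabularies inside the safetensors metadata, which stores only string values; the producer side is not part of this commit, but a checkpoint in that layout could be written roughly as below (all names are stand-ins, not the repo's actual export code).

    import torch
    from safetensors.torch import save_file

    state_dict = {"w": torch.zeros(2, 2)}  # stand-in for the real AR model weights
    texttok_str = "..."                    # stand-in for the serialized text tokenizer
    speechtok_str = "..."                  # stand-in for the serialized speech tokenizer

    # Metadata must be a flat str -> str mapping; it is what f.metadata() returns on load.
    save_file(state_dict, "mars5_ar.safetensors",
              metadata={"texttok.model": texttok_str, "speechtok.model": speechtok_str})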