|
|
|
import re

import torch
import torchaudio
from transformers import AutoTokenizer

from outetts.wav_tokenizer.decoder import WavTokenizer
|
|
|
class AudioTokenizer:
    """Bridge between a HF text tokenizer and a WavTokenizer audio codec for TTS.

    Responsibilities visible here:
      * build a chat-style prompt asking the model to speak *text* in a
        named speaker's voice (``create_prompt``),
      * tokenize that prompt for the model (``tokenize_prompt``),
      * pull the ``<audio_N>`` codes out of generated text (``get_codes``),
      * decode those codes back to a waveform (``get_audio``).
    """

    def __init__(self, hf_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
        """Load the text tokenizer and the WavTokenizer codec.

        Args:
            hf_path: HuggingFace model id or local path for the text tokenizer.
            wav_tokenizer_model_path: checkpoint path for the WavTokenizer.
            wav_tokenizer_config_path: config path for the WavTokenizer.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(hf_path)
        self.wav_tokenizer = WavTokenizer(
            checkpoint_path=wav_tokenizer_model_path,
            config_path=wav_tokenizer_config_path,
            device=self.device,
        )
        # Known speaker voices; create_prompt falls back to a random one of
        # these when the caller passes an unknown (or no) speaker name.
        self.speakers = ["idera", "emma", "jude", "osagie", "tayo", "zainab",
                         "joke", "regina", "remi", "umar", "chinenye"]

    def create_prompt(self, text, speaker_name=None):
        """Return the chat-formatted TTS prompt for *text*.

        Unknown or missing speaker names fall back to a uniformly random
        known speaker. (``None`` is never in ``self.speakers``, so a single
        membership test covers both the missing and the unknown case.)
        """
        if speaker_name not in self.speakers:
            speaker_name = self.speakers[torch.randint(0, len(self.speakers), (1,)).item()]

        prompt = f"<|system|>\nYou are a helpful assistant that speaks in {speaker_name}'s voice.\n<|user|>\nSpeak this text: {text}\n<|assistant|>"
        return prompt

    def tokenize_prompt(self, prompt):
        """Encode *prompt* to a tensor of token ids on ``self.device``."""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        return input_ids

    def get_codes(self, output):
        """Extract the integer audio codes from a generated batch.

        Args:
            output: batch of generated token-id sequences; only ``output[0]``
                is decoded.

        Returns:
            List of ints: the ``N`` from every ``<audio_N>`` marker found
            after the final ``<|assistant|>`` tag, in order of appearance.
        """
        decoded_str = self.tokenizer.decode(output[0])

        # Only the assistant's reply carries audio codes; drop the prompt part.
        speech_part = decoded_str.split("<|assistant|>")[-1].strip()

        return [int(code) for code in re.findall(r"<audio_(\d+)>", speech_part)]

    def get_audio(self, codes):
        """Decode a list of audio codes to a waveform via the WavTokenizer."""
        audio = self.wav_tokenizer.decode(torch.tensor(codes, device=self.device))
        return audio