# yarngpt_utils.py
import re

import torch
import torchaudio
from outetts.wav_tokenizer.decoder import WavTokenizer
from transformers import AutoTokenizer


class AudioTokenizer:
    """Wraps the YarnGPT text tokenizer and the WavTokenizer neural audio codec."""

    def __init__(self, hf_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(hf_path)
        # Load the codec via the classmethod the WavTokenizer codebase exposes
        # (config path first, then checkpoint), then move it to the device.
        self.wav_tokenizer = WavTokenizer.from_pretrained0802(
            wav_tokenizer_config_path, wav_tokenizer_model_path
        ).to(self.device)
        self.speakers = ["idera", "emma", "jude", "osagie", "tayo", "zainab",
                         "joke", "regina", "remi", "umar", "chinenye"]

    def create_prompt(self, text, speaker_name=None):
        # Fall back to a random speaker when none (or an unknown one) is given.
        if speaker_name is None or speaker_name not in self.speakers:
            speaker_name = self.speakers[torch.randint(0, len(self.speakers), (1,)).item()]
        # Build a chat-style prompt similar to the original YarnGPT format.
        prompt = (
            f"<|system|>\nYou are a helpful assistant that speaks in {speaker_name}'s voice.\n"
            f"<|user|>\nSpeak this text: {text}\n<|assistant|>"
        )
        return prompt
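
    # For example, create_prompt("Bawo ni?", "idera") renders as (illustrative):
    #   <|system|>
    #   You are a helpful assistant that speaks in idera's voice.
    #   <|user|>
    #   Speak this text: Bawo ni?
    #   <|assistant|>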

    def tokenize_prompt(self, prompt):
        # Encode the prompt and move the input ids to the model's device.
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        return input_ids

    def get_codes(self, output):
        # Decode the generated token sequence back to a string.
        decoded_str = self.tokenizer.decode(output[0])
        # Keep only the part after the final <|assistant|> marker.
        speech_part = decoded_str.split("<|assistant|>")[-1].strip()
        # Extract audio code tokens, assuming the format "<audio_001>";
        # e.g. "<audio_12><audio_345>" yields [12, 345].
        audio_codes = []
        for match in re.finditer(r"<audio_(\d+)>", speech_part):
            audio_codes.append(int(match.group(1)))
        return audio_codes

    def get_audio(self, codes):
        # WavTokenizer decodes features, not raw ids: reshape the code list to
        # (n_quantizers, batch, time), map codes to features, then decode.
        codes = torch.tensor(codes, device=self.device).unsqueeze(0).unsqueeze(0)
        features = self.wav_tokenizer.codes_to_features(codes)
        return self.wav_tokenizer.decode(features, bandwidth_id=torch.tensor([0], device=self.device))
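

# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Assumes a YarnGPT checkpoint on the Hugging Face Hub and locally downloaded
# WavTokenizer config/weights; the model id and file paths below are placeholders.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    audio_tok = AudioTokenizer(
        hf_path="saheedniyi/YarnGPT",                  # assumed Hub model id
        wav_tokenizer_model_path="wavtokenizer.ckpt",  # placeholder path
        wav_tokenizer_config_path="wavtokenizer.yaml", # placeholder path
    )
    model = AutoModelForCausalLM.from_pretrained(
        "saheedniyi/YarnGPT", torch_dtype="auto"
    ).to(audio_tok.device)

    prompt = audio_tok.create_prompt("Good morning, how are you?", "idera")
    input_ids = audio_tok.tokenize_prompt(prompt)
    output = model.generate(input_ids, max_length=4000, do_sample=True,
                            temperature=0.1, repetition_penalty=1.1)
    codes = audio_tok.get_codes(output)
    audio = audio_tok.get_audio(codes)
    # WavTokenizer operates at 24 kHz; save the waveform with torchaudio.
    torchaudio.save("output.wav", audio.cpu(), sample_rate=24000)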