File size: 2,102 Bytes
c816d1a
 
 
06a1ff8
c816d1a
 
 
 
 
 
 
06a1ff8
388787c
c816d1a
388787c
 
c816d1a
 
06a1ff8
c816d1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# yarngpt_utils.py
import torch
import torchaudio
import re
from outetts.wav_tokenizer.decoder import WavTokenizer
from transformers import AutoTokenizer

class AudioTokenizer:
    """Bridge between a HF text tokenizer and a WavTokenizer audio codec for YarnGPT.

    Responsibilities: build a chat-style TTS prompt for a chosen speaker voice,
    encode it to token ids, pull the generated audio codes back out of the
    model output, and decode those codes into a waveform.
    """

    def __init__(self, hf_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
        # Run everything (token ids and codec) on GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(hf_path)

        # WavTokenizer is constructed from explicit checkpoint/config paths
        # plus the target device (keyword form matches its expected signature).
        self.wav_tokenizer = WavTokenizer(
            checkpoint_path=wav_tokenizer_model_path,
            config_path=wav_tokenizer_config_path,
            device=self.device,
        )

        # Known speaker voices; anything else falls back to a random pick.
        self.speakers = ["idera", "emma", "jude", "osagie", "tayo", "zainab",
                         "joke", "regina", "remi", "umar", "chinenye"]

    def create_prompt(self, text, speaker_name=None):
        """Return a chat-formatted prompt asking for *text* in *speaker_name*'s voice.

        A missing or unrecognised speaker is replaced by one chosen at random
        (torch RNG, so it respects torch seeding). Note `None` is never in the
        speaker list, so a single membership test covers both cases.
        """
        if speaker_name not in self.speakers:
            pick = torch.randint(0, len(self.speakers), (1,)).item()
            speaker_name = self.speakers[pick]
        return f"<|system|>\nYou are a helpful assistant that speaks in {speaker_name}'s voice.\n<|user|>\nSpeak this text: {text}\n<|assistant|>"

    def tokenize_prompt(self, prompt):
        """Encode *prompt* to a ``(1, seq_len)`` id tensor on the target device."""
        return self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

    def get_codes(self, output):
        """Extract the integer audio codes from a generated token sequence.

        Decodes the first sequence in *output*, keeps only the assistant's
        reply, and collects every ``<audio_NNN>`` token's numeric id in order.
        """
        decoded = self.tokenizer.decode(output[0])
        reply = decoded.split("<|assistant|>")[-1].strip()
        return [int(m.group(1)) for m in re.finditer(r"<audio_(\d+)>", reply)]

    def get_audio(self, codes):
        """Decode a list of audio codes into a waveform via the WavTokenizer."""
        code_tensor = torch.tensor(codes, device=self.device)
        return self.wav_tokenizer.decode(code_tensor)