"""Semantic tokens loading logic. | |
Copyright PolyAI Limited. | |
""" | |
import json
import logging
import random
import re
from logging import getLogger
from pathlib import Path
from typing import List, Pattern, Union

import numpy as np
import torch
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from data.collation import get_text_semantic_token_collater


class TextTokenizer:
    """Phonemize Text."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        logger = getLogger("phonemizer")
        logger.setLevel(logging.ERROR)
        if backend == "espeak":
            phonemizer = EspeakBackend(
                language,
                punctuation_marks=punctuation_marks,
                preserve_punctuation=preserve_punctuation,
                with_stress=with_stress,
                tie=tie,
                language_switch=language_switch,
                words_mismatch=words_mismatch,
                logger=logger,
            )
        else:
            raise NotImplementedError(f"{backend}")

        self.backend = phonemizer
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        fields = []
        for word in phonemized.split(self.separator.word):
            # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
            fields.extend(
                [p for p in pp if p != self.separator.phone]
                + [self.separator.word]
            )
        assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
            self.separator.phone
        )
        return fields[:-1]

    def __call__(self, text, strip=True) -> List[List[str]]:
        if isinstance(text, str):
            text = [text]

        phonemized = self.backend.phonemize(
            text, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(p) for p in phonemized]
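
# Example usage (illustrative sketch, not part of the original module; it
# assumes the espeak backend, i.e. espeak-ng, is installed so EspeakBackend
# can load):
#
#   tokenizer = TextTokenizer()
#   phonemes = tokenizer("hello world")
#   # `phonemes` is a list with one phoneme-symbol list per input string,
#   # with "_" inserted between words (the word separator configured above).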


class Collator:
    def collate(self, batch):
        input_ids = [item["input_ids"] for item in batch]
        output_sequences = [item["labels"] for item in batch]

        # Pad sequences to the maximum length in the batch.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=0
        )
        output_sequences = torch.nn.utils.rnn.pad_sequence(
            output_sequences, batch_first=True, padding_value=-100
        )

        # 1 - token is unmasked, 0 - token is masked.
        attention_mask = input_ids != 0

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": output_sequences,
        }
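
# Collation sketch (illustrative, with made-up token ids): sequences of
# different lengths are right-padded, inputs with 0 and labels with -100 so
# that a cross-entropy loss using ignore_index=-100 skips the padded targets.
#
#   batch = [
#       {"input_ids": torch.tensor([5, 6, 7]), "labels": torch.tensor([8, 9])},
#       {"input_ids": torch.tensor([5]), "labels": torch.tensor([8, 9, 10])},
#   ]
#   out = Collator().collate(batch)
#   # out["input_ids"]      -> [[5, 6, 7], [5, 0, 0]]
#   # out["attention_mask"] -> [[True, True, True], [True, False, False]]
#   # out["labels"]         -> [[8, 9, -100], [8, 9, 10]]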


class ConcatenateSemanticDataset(Dataset):
    def __init__(
        self, manifest_path: List[str], symbol_table_path: str,
        n_samples: int = 0, max_duration=15,
    ):
        self.data = []
        self.phonemizer = TextTokenizer()
        self.text_collater = get_text_semantic_token_collater(
            symbol_table_path)
        self.manifest_path = manifest_path
        self.n_samples = n_samples
        self.max_duration = max_duration

        if manifest_path is not None:
            self._build()

    def __len__(self):
        if self.n_samples:
            return min(self.n_samples, len(self.data))
        return len(self.data)

    def remove_unknown_symbols(self, text: List[str]):
        res = []
        for sym in text:
            if sym not in self.text_collater.token2idx:
                # print(f'{sym} is unk')
                continue
            res.append(sym)
        return res

    def __getitem__(self, idx):
        item = self.data[idx]

        # Phoneme symbols are stored in the manifest as a "|"-separated string.
        input_ids = item["phoneme"].split("|")
        input_ids = self.remove_unknown_symbols(input_ids)

        input_ids_2 = None
        if item.get("phoneme_2"):
            input_ids_2 = item["phoneme_2"].split("|")
            input_ids_2 = [self.remove_unknown_symbols(input_ids_2)]

        input_ids = self.text_collater(
            [input_ids], input_ids_2).to(dtype=torch.long)

        # Semantic token targets are stored as .npy arrays; the collater
        # expects string symbols.
        labels = np.load(item["semantic_path"])
        labels = [str(lbl) for lbl in labels]

        labels_2 = None
        if item.get("semantic_path_2"):
            labels_2 = np.load(item["semantic_path_2"])
            labels_2 = [[str(lbl) for lbl in labels_2]]

        labels = self.text_collater([labels], labels_2).to(dtype=torch.long)

        return {"input_ids": input_ids.squeeze(0), "labels": labels.squeeze(0)}

    # TODO: stream the manifests instead of loading everything into memory.
    def _build(self):
        for manifest_path in self.manifest_path:
            dataset_path = Path(manifest_path).parent
            with open(manifest_path, "r") as manifest_file:
                manifest_data = json.load(manifest_file)

            for key, value in tqdm(manifest_data.items()):
                if float(value["duration"]) > self.max_duration:
                    continue

                text = value["text"]
                phoneme = value["phoneme"]
                npy_path = f"{dataset_path}/audios-speech-tokenizer/semantic/{key.split('.wav')[0]}.npy"  # noqa

                datapoint = {
                    "text": text,
                    "semantic_path": npy_path,
                    "phoneme": phoneme,
                }
                self.data.append(datapoint)

            print(
                f"Total length of the dataset {manifest_path}: "
                f"{len(self.data)}"
            )

        random.shuffle(self.data)
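
# Manifest layout assumed by _build above (illustrative sketch reconstructed
# from the fields the code reads; the file name and values are made up):
#
#   {
#       "LJ001-0001.wav": {
#           "duration": "3.21",
#           "text": "Printing, in the only sense ...",
#           "phoneme": "p|ɹ|ɪ|n|t|ɪ|ŋ|_|...",
#       }
#   }
#
# The matching semantic tokens are then expected at
#   <manifest_dir>/audios-speech-tokenizer/semantic/LJ001-0001.npy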


if __name__ == "__main__":
    # Create an instance of the dataset
    manifest_path = "datasets/ljspeech-training-data/dev.json"
    text_tokens_file = "ckpt/unique_text_tokens.k2symbols"
    seq2seq_dataset = ConcatenateSemanticDataset(
        [manifest_path, manifest_path], text_tokens_file)
    # seq2seq_dataset.phonemize_and_rewrite_manifest()

    batch_size = 1  # Adjust to your desired batch size
    dataloader = DataLoader(
        seq2seq_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=Collator().collate,
    )

    for batch in dataloader:
        print(batch["input_ids"])
        print(batch["labels"])
        print(batch["input_ids"][0].unique().max())
        print(batch["input_ids"][0].unique().min())
        print(batch["input_ids"].shape)
        print(batch["labels"].shape)
        break  # Stop after the first batch if needed