"""Collators for T2S and S2A.
Copyright PolyAI Limited.
"""
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
from utils.symbol_table import SymbolTable


class GlobalCollater:
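    """Pads speaker embeddings, acoustic code sequences and semantic tokens
    to the longest item in the batch and stacks them into tensors."""
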
def __init__(self, n_codes, n_semantic_codes):
        self.n_codes = n_codes  # pad value for acoustic code sequences
        self.sem_mask_id = n_semantic_codes  # pad value for semantic tokens

def collate(self, batch):
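        """Pad every field of ``batch`` to the longest quantization sequence
        and convert all fields except file names to tensors."""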
output = {
'speaker': [],
'tts_quantize_input': [],
'tts_quantize_output': [],
'quantize_mask': [],
'f_names': [],
'semantic_tokens': [],
'quantization_lengths': [],
}
        # Find the longest quantization sequence; record each item's length
max_len_q = 0
for _, q_s, q_e, _, _ in batch:
if len(q_s) > max_len_q:
max_len_q = len(q_s)
output['quantization_lengths'].append(len(q_s))
# Pad each element, create mask
for spkr, qs, qe, itm_name, s_tokens in batch:
            # Pad acoustic codes; the mask is True at padded positions
q_mask = np.array(
[False] * len(qs) + [True] * (max_len_q - len(qs)))
qs = np.pad(
qs,
[[0, max_len_q-len(qs)], [0, 0]],
constant_values=self.n_codes
)
qe = np.pad(
qe,
[[0, max_len_q-len(qe)], [0, 0]],
constant_values=self.n_codes
)
            # Flatten semantic tokens and pad with the semantic mask id
s_tokens = s_tokens.flatten()
s_tokens = np.pad(
s_tokens,
(0, max_len_q-len(s_tokens)),
constant_values=self.sem_mask_id
)
            # Pad speaker embeddings with 512-dim zero vectors
spkr = np.concatenate(
(spkr, np.zeros((max_len_q - len(spkr), 512))))
# Aggregate
output['speaker'].append(spkr)
output['tts_quantize_input'].append(qs)
output['tts_quantize_output'].append(qe)
output['quantize_mask'].append(q_mask)
output['f_names'].append(itm_name)
output["semantic_tokens"].append(s_tokens)
for k in output.keys():
if k == 'f_names':
continue
output[k] = np.array(output[k])
if 'mask' in k:
output[k] = torch.BoolTensor(output[k])
elif k in [
'tts_quantize_input', 'tts_quantize_output',
'semantic_tokens', 'quantization_lengths'
]:
output[k] = torch.LongTensor(output[k])
else:
output[k] = torch.FloatTensor(output[k])
return output


class TextTokenCollater:
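    """Encodes symbol sequences as token-id tensors, wrapping each sequence
    with optional bos/eos and speaker-turn markers."""
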
def __init__(
self,
text_tokens: List[str],
add_eos: bool = True,
add_bos: bool = True,
pad_symbol: str = "<pad>",
bos_symbol: str = "<bos>",
eos_symbol: str = "<eos>",
spkr_1_symbol: str = "spkr_1",
spkr_2_symbol: str = "spkr_2",
):
self.pad_symbol = pad_symbol
self.add_eos = add_eos
self.add_bos = add_bos
self.bos_symbol = bos_symbol
self.eos_symbol = eos_symbol
self.spkr_1_symbol = spkr_1_symbol
self.spkr_2_symbol = spkr_2_symbol
        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + [spkr_1_symbol]
            + [spkr_2_symbol]
            + sorted(text_tokens)
        )
        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = list(unique_tokens)

    def __call__(
        self, texts: List[str], texts_2: Optional[List[str]] = None
    ) -> torch.Tensor:
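        """Encode each text (plus an optional second speaker turn) into ids.

        Note: all encoded sequences in a call must have equal length, since
        they are stacked into a single array without padding.
        """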
tokens_seqs = [[p for p in text] for text in texts]
if texts_2 is None:
seqs = [
([self.bos_symbol] if self.add_bos else [])
+ [self.spkr_1_symbol]
+ list(seq)
+ ([self.eos_symbol] if self.add_eos else [])
for seq in tokens_seqs
]
else:
tokens_seqs_2 = [[p for p in text] for text in texts_2]
            seqs = [
                ([self.bos_symbol] if self.add_bos else [])
                + [self.spkr_1_symbol]
                + list(seq)
                + [self.spkr_2_symbol]
                + list(seq_2)
                + ([self.eos_symbol] if self.add_eos else [])
                for seq, seq_2 in zip(tokens_seqs, tokens_seqs_2)
            ]
tokens_batch = torch.from_numpy(
np.array(
[[self.token2idx[token] for token in seq] for seq in seqs],
dtype=np.int64,
)
)
return tokens_batch


def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
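    """Build a TextTokenCollater from a SymbolTable file on disk."""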
text_tokens_path = Path(text_tokens_file)
unique_tokens = SymbolTable.from_file(text_tokens_path)
collater = TextTokenCollater(
unique_tokens.symbols, add_bos=True, add_eos=True
)
return collater


def get_text_semantic_token_collater(
        text_tokens_file: str, n_semantic_tokens: int = 1024
) -> TextTokenCollater:
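    """Build a collater whose vocabulary also includes the string forms of
    semantic token ids 0..n_semantic_tokens - 1."""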
text_tokens_path = Path(text_tokens_file)
unique_tokens = SymbolTable.from_file(text_tokens_path)
for semantic_idx in range(n_semantic_tokens):
unique_tokens.add(str(semantic_idx))
collater = TextTokenCollater(
unique_tokens.symbols, add_bos=True, add_eos=True
)
return collater


if __name__ == '__main__':
text_tokens_file = 'ckpt/unique_text_tokens.k2symbols'
collater = get_text_semantic_token_collater(text_tokens_file)
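    # Minimal smoke test, assuming the symbol file above exists; the input
    # (a hypothetical example) is one sequence of semantic-token id strings.
    tokens = collater([['0', '1', '2']])
    print(tokens.shape)  # torch.Size([1, 6]): bos + spkr_1 + 3 ids + eos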