"""T2S model definition. | |
Copyright PolyAI Limited. | |
""" | |
import os | |
import numpy as np | |
from torch import nn | |
from transformers import EvalPrediction, T5Config, T5ForConditionalGeneration | |
from data.collation import get_text_semantic_token_collater | |


def compute_custom_metrics(eval_prediction: EvalPrediction):
    # eval_prediction: tuple
    # eval_prediction[0]: decoder output logits, shape (n_batch, n_semantic, n_tokens)
    # eval_prediction[1]: encoder outputs, shape (n_batch, n_text/n_phone, n_hidden)
    logits = eval_prediction.predictions[0]
    labels = eval_prediction.label_ids
    n_vocab = logits.shape[-1]
    mask = labels == -100

    top_1 = np.argmax(logits, axis=-1) == labels
    top_1[mask] = False

    top_5 = np.argsort(logits, axis=-1)[:, :, -5:]
    top_5 = np.any(top_5 == np.expand_dims(labels, axis=-1), axis=-1)
    top_5[mask] = False

    top_10 = np.argsort(logits, axis=-1)[:, :, -10:]
    top_10 = np.any(top_10 == np.expand_dims(labels, axis=-1), axis=-1)
    top_10[mask] = False

    top_1_accuracy = np.sum(top_1) / np.sum(~mask)
    top_5_accuracy = np.sum(top_5) / np.sum(~mask)
    top_10_accuracy = np.sum(top_10) / np.sum(~mask)

    return {
        "top_1_accuracy": top_1_accuracy,
        "top_5_accuracy": top_5_accuracy,
        "top_10_accuracy": top_10_accuracy,
    }
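

# Illustrative smoke test for compute_custom_metrics; a minimal sketch, not part
# of the training pipeline. The helper name and the array shapes are assumptions
# chosen for demonstration, and the second element of `predictions` merely stands
# in for the encoder outputs, which the metrics never read.
def _smoke_test_compute_custom_metrics():
    n_batch, n_semantic, n_tokens = 2, 6, 50
    rng = np.random.default_rng(0)
    logits = rng.standard_normal((n_batch, n_semantic, n_tokens)).astype(np.float32)
    labels = rng.integers(0, n_tokens, size=(n_batch, n_semantic))
    labels[:, -2:] = -100  # mimic padded target positions, ignored by the metrics
    eval_prediction = EvalPrediction(
        predictions=(logits, np.zeros((n_batch, 4, 8), dtype=np.float32)),
        label_ids=labels,
    )
    metrics = compute_custom_metrics(eval_prediction)
    assert set(metrics) == {"top_1_accuracy", "top_5_accuracy", "top_10_accuracy"}
    return metrics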


class T2S(nn.Module):
    def __init__(self, hp):
        super().__init__()
        self.text_tokens_file = "ckpt/unique_text_tokens.k2symbols"
        self.collater = get_text_semantic_token_collater(self.text_tokens_file)
        self.model_size = hp.model_size
        self.vocab_size = len(self.collater.idx2token)

        self.config = self._define_model_config(self.model_size)
        print(f"{self.config = }")
        self.t2s = T5ForConditionalGeneration(self.config)

    def _define_model_config(self, model_size):
        if model_size == "test":
            # minimal config for quick tests / debugging
            d_ff = 16
            d_model = 8
            d_kv = 32
            num_heads = 1
            num_decoder_layers = 1
            num_layers = 1
elif model_size == "tiny": | |
# n_params = 16M | |
d_ff = 1024 | |
d_model = 256 | |
d_kv = 32 | |
num_heads = 4 | |
num_decoder_layers = 4 | |
num_layers = 4 | |
elif model_size == "t5small": | |
# n_params = 60M | |
d_ff = 2048 | |
d_model = 512 | |
d_kv = 64 | |
num_heads = 8 | |
num_decoder_layers = 6 | |
num_layers = 6 | |
elif model_size == "large": | |
# n_params = 100M | |
d_ff = 2048 | |
d_model = 512 | |
d_kv = 64 | |
num_heads = 8 | |
num_decoder_layers = 14 | |
num_layers = 14 | |
elif model_size == "Large": | |
# n_params = 114M | |
d_ff = 4096 | |
d_model = 512 | |
d_kv = 64 | |
num_heads = 8 | |
num_decoder_layers = 6 | |
num_layers = 10 | |
        else:
            raise ValueError(f"unknown model_size: {model_size}")

        config = T5Config(
            d_ff=d_ff,
            d_model=d_model,
            d_kv=d_kv,
            num_heads=num_heads,
            num_decoder_layers=num_decoder_layers,
            num_layers=num_layers,
            decoder_start_token_id=0,
            eos_token_id=2,
            vocab_size=self.vocab_size,
        )
        return config
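

# Illustrative usage sketch (kept as a comment so it never runs on import), under
# the assumptions that `hp` only needs a `model_size` attribute here and that
# "ckpt/unique_text_tokens.k2symbols" is present on disk. Shows the wrapped T5
# model being driven through a single teacher-forced forward pass:
#
#     from types import SimpleNamespace
#     import torch
#
#     hp = SimpleNamespace(model_size="test")
#     model = T2S(hp)
#     input_ids = torch.randint(0, model.vocab_size, (1, 12))  # text/phone tokens
#     labels = torch.randint(0, model.vocab_size, (1, 8))      # semantic tokens
#     outputs = model.t2s(input_ids=input_ids, labels=labels)
#     print(outputs.loss, outputs.logits.shape)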