# pheme/modules/t2s_model.py
"""T2S model definition.
Copyright PolyAI Limited.
"""
import os
import numpy as np
from torch import nn
from transformers import EvalPrediction, T5Config, T5ForConditionalGeneration
from data.collation import get_text_semantic_token_collater


def compute_custom_metrics(eval_prediction: EvalPrediction):
    """Compute top-1/5/10 token accuracy over non-padded target positions.

    eval_prediction.predictions[0]: decoder output logits,
        shape (n_batch, n_semantic, n_tokens)
    eval_prediction.predictions[1]: encoder outputs,
        shape (n_batch, n_text/n_phone, n_hidden)
    eval_prediction.label_ids: target semantic tokens,
        shape (n_batch, n_semantic), with padded positions set to -100
    """
    logits = eval_prediction.predictions[0]
    labels = eval_prediction.label_ids

    # Padded positions (label == -100) are excluded from every metric.
    mask = labels == -100

    # Top-1: the argmax prediction matches the label.
    top_1 = np.argmax(logits, axis=-1) == labels
    top_1[mask] = False

    # Top-5 / top-10: the label is among the k highest-scoring tokens.
    top_5 = np.argsort(logits, axis=-1)[:, :, -5:]
    top_5 = np.any(top_5 == np.expand_dims(labels, axis=-1), axis=-1)
    top_5[mask] = False

    top_10 = np.argsort(logits, axis=-1)[:, :, -10:]
    top_10 = np.any(top_10 == np.expand_dims(labels, axis=-1), axis=-1)
    top_10[mask] = False

    n_valid = np.sum(~mask)
    top_1_accuracy = np.sum(top_1) / n_valid
    top_5_accuracy = np.sum(top_5) / n_valid
    top_10_accuracy = np.sum(top_10) / n_valid

    return {
        "top_1_accuracy": top_1_accuracy,
        "top_5_accuracy": top_5_accuracy,
        "top_10_accuracy": top_10_accuracy,
    }
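

# Usage sketch (illustrative; the original file does not show how this callback
# is wired up). `compute_custom_metrics` has the signature expected by the
# Hugging Face Trainer's `compute_metrics` argument, e.g.:
#
#     trainer = Trainer(
#         model=model.t2s,
#         args=training_args,          # a transformers.TrainingArguments
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#         compute_metrics=compute_custom_metrics,
#     )
#
# The Trainer then calls it with an EvalPrediction whose predictions[0] holds
# the decoder logits and whose label_ids mark padding with -100.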


class T2S(nn.Module):
    """Text-to-semantic (T2S) model: a T5 encoder-decoder that maps
    text/phoneme tokens to semantic tokens."""

    def __init__(self, hp):
        super().__init__()
        # Symbol table shared between text and semantic tokens; its size
        # defines the model vocabulary.
        self.text_tokens_file = "ckpt/unique_text_tokens.k2symbols"
        self.collater = get_text_semantic_token_collater(self.text_tokens_file)
        self.model_size = hp.model_size
        self.vocab_size = len(self.collater.idx2token)

        self.config = self._define_model_config(self.model_size)
        print(f"{self.config = }")
        self.t2s = T5ForConditionalGeneration(self.config)

    def _define_model_config(self, model_size):
        """Return a T5Config for the requested model size."""
        if model_size == "test":
            # Minimal configuration, intended only for quick tests.
            d_ff = 16
            d_model = 8
            d_kv = 32
            num_heads = 1
            num_decoder_layers = 1
            num_layers = 1
        elif model_size == "tiny":
            # n_params = 16M
            d_ff = 1024
            d_model = 256
            d_kv = 32
            num_heads = 4
            num_decoder_layers = 4
            num_layers = 4
        elif model_size == "t5small":
            # n_params = 60M
            d_ff = 2048
            d_model = 512
            d_kv = 64
            num_heads = 8
            num_decoder_layers = 6
            num_layers = 6
        elif model_size == "large":
            # n_params = 100M
            d_ff = 2048
            d_model = 512
            d_kv = 64
            num_heads = 8
            num_decoder_layers = 14
            num_layers = 14
        elif model_size == "Large":
            # n_params = 114M
            d_ff = 4096
            d_model = 512
            d_kv = 64
            num_heads = 8
            num_decoder_layers = 6
            num_layers = 10
        else:
            raise ValueError(f"unknown model_size: {model_size}")

        config = T5Config(
            d_ff=d_ff,
            d_model=d_model,
            d_kv=d_kv,
            num_heads=num_heads,
            num_decoder_layers=num_decoder_layers,
            num_layers=num_layers,
            decoder_start_token_id=0,
            eos_token_id=2,
            vocab_size=self.vocab_size,
        )
        return config
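

# Illustrative usage sketch (not part of the original module). It assumes that
# the symbol file hard-coded in T2S.__init__ ("ckpt/unique_text_tokens.k2symbols")
# is available locally, and that `hp` only needs to expose `model_size`, the
# sole attribute read above.
if __name__ == "__main__":
    from types import SimpleNamespace

    hp = SimpleNamespace(model_size="t5small")
    model = T2S(hp)
    n_params = sum(p.numel() for p in model.t2s.parameters())
    print(f"T5 backbone parameters: {n_params / 1e6:.1f}M")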