# pheme/modules/t2s_model.py
# taras-sereda's picture
# minimal set of files to run inference; pheme-small checkpoint
# 96ee597
# raw
# history blame
# 3.53 kB
"""T2S model definition.
Copyright PolyAI Limited.
"""
import os
import numpy as np
from torch import nn
from transformers import EvalPrediction, T5Config, T5ForConditionalGeneration
from data.collation import get_text_semantic_token_collater
def compute_custom_metrics(eval_prediction: EvalPrediction):
    """Compute masked top-1, top-5 and top-10 token accuracy.

    Args:
        eval_prediction: object whose ``predictions[0]`` holds the decoder
            logits and whose ``label_ids`` holds the target token ids.
            Per the original note, logits are expected to be shaped
            (n_batch, n_semantic, n_tokens) and ``predictions[1]`` carries
            encoder outputs; only ``predictions[0]`` is read here.

    Returns:
        A dict with the keys ``"top_1_accuracy"``, ``"top_5_accuracy"`` and
        ``"top_10_accuracy"``. Each value is the number of non-ignored
        positions whose label is among the k highest-scoring token ids,
        divided by the number of non-ignored positions. When every label
        is -100 the denominator is zero and the values are NaN.
    """
    logits = eval_prediction.predictions[0]
    labels = eval_prediction.label_ids

    # Positions labelled -100 are excluded from every metric.
    ignored = labels == -100
    n_valid = np.sum(~ignored)

    top_1 = np.argmax(logits, axis=-1) == labels
    top_1[ignored] = False
    accuracies = {"top_1_accuracy": np.sum(top_1) / n_valid}

    # Sort the vocabulary axis once; argsort is ascending, so the k best
    # token ids are the last k entries along that axis.
    ranked = np.argsort(logits, axis=-1)
    targets = np.expand_dims(labels, axis=-1)
    for k in (5, 10):
        in_top_k = np.any(ranked[..., -k:] == targets, axis=-1)
        in_top_k[ignored] = False
        accuracies[f"top_{k}_accuracy"] = np.sum(in_top_k) / n_valid

    return accuracies
class T2S(nn.Module):
    """Text-to-semantic model wrapping a ``T5ForConditionalGeneration``.

    The T5 configuration is chosen from a named size preset, and the
    vocabulary size is read from the text/semantic token collater.
    """

    def __init__(self, hp):
        """Build the collater, the T5 config and the T5 model.

        Args:
            hp: hyper-parameter object; only ``hp.model_size`` is read. It
                must be one of the preset names accepted by
                ``_define_model_config``.
        """
        super().__init__()
        tokens_path = "ckpt/unique_text_tokens.k2symbols"
        self.text_tokens_file = tokens_path
        self.collater = get_text_semantic_token_collater(tokens_path)
        self.model_size = hp.model_size
        # The vocabulary size is the number of entries in the collater's
        # idx2token table.
        self.vocab_size = len(self.collater.idx2token)
        self.config = self._define_model_config(self.model_size)
        print(f"{self.config = }")
        self.t2s = T5ForConditionalGeneration(self.config)

    def _define_model_config(self, model_size):
        """Return the ``T5Config`` for the preset named ``model_size``.

        Args:
            model_size: one of "test", "tiny", "t5small", "large", "Large"
                (case-sensitive; "large" and "Large" are different presets).

        Returns:
            A ``T5Config`` with the preset's dimensions,
            ``decoder_start_token_id=0``, ``eos_token_id=2`` and
            ``vocab_size=self.vocab_size``.

        Raises:
            ValueError: if ``model_size`` matches none of the presets.
        """
        presets = {
            # NOTE(review): previously annotated "n_params = 16M", which
            # looks copied from "tiny" given these dimensions — confirm.
            "test": dict(
                d_ff=16,
                d_model=8,
                d_kv=32,
                num_heads=1,
                num_decoder_layers=1,
                num_layers=1,
            ),
            # n_params = 16M
            "tiny": dict(
                d_ff=1024,
                d_model=256,
                d_kv=32,
                num_heads=4,
                num_decoder_layers=4,
                num_layers=4,
            ),
            # n_params = 60M
            "t5small": dict(
                d_ff=2048,
                d_model=512,
                d_kv=64,
                num_heads=8,
                num_decoder_layers=6,
                num_layers=6,
            ),
            # n_params = 100M
            "large": dict(
                d_ff=2048,
                d_model=512,
                d_kv=64,
                num_heads=8,
                num_decoder_layers=14,
                num_layers=14,
            ),
            # n_params = 114M
            "Large": dict(
                d_ff=4096,
                d_model=512,
                d_kv=64,
                num_heads=8,
                num_decoder_layers=6,
                num_layers=10,
            ),
        }

        # Match with == in declaration order (not by hashing) so that any
        # value, including an unhashable one, falls through to ValueError.
        dims = next(
            (
                preset
                for name, preset in presets.items()
                if model_size == name
            ),
            None,
        )
        if dims is None:
            raise ValueError(f"unknown {model_size}")

        return T5Config(
            decoder_start_token_id=0,
            eos_token_id=2,
            vocab_size=self.vocab_size,
            **dims,
        )