|
|
|
import tensorflow as tf |
|
from tensorflow.keras import layers, activations, initializers |
|
|
|
class MiniSunConfig:
    """Plain container for the hyper-parameters of ``MiniSunModel``.

    Every constructor argument is stored unchanged as an attribute of the same
    name, in signature order. ``MiniSunModel.get_config`` serializes this
    object's ``__dict__``, so the attribute set is also the serialized form.

    Args:
        vocab_size: Number of token ids (rows of the token embedding and width
            of the output logits).
        max_position_embeddings: Rows of the learned position embedding, i.e.
            the longest sequence the model can embed.
        hidden_size: Width of embeddings and hidden states.
        num_attention_heads: Heads per multi-head attention layer.
        intermediate_size: Width of the first feed-forward layer in each block.
        num_hidden_layers: Number of transformer blocks.
        dropout_rate: Dropout probability used inside each block.
        weight_decay: AdamW weight decay used by ``create_model``.
        learning_rate: AdamW learning rate used by ``create_model``.
    """

    def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
                 num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4):
        # Tuple targets are bound left to right, so attribute insertion order
        # (and hence the order of keys in ``__dict__``) follows the signature.

        # Vocabulary and sequence geometry.
        self.vocab_size, self.max_position_embeddings = vocab_size, max_position_embeddings
        # Transformer width, heads, feed-forward width and depth.
        self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads
        self.intermediate_size, self.num_hidden_layers = intermediate_size, num_hidden_layers
        # Regularization and optimizer settings.
        self.dropout_rate = dropout_rate
        self.weight_decay, self.learning_rate = weight_decay, learning_rate
|
|
|
@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    """Transformer language model built from post-LayerNorm blocks.

    Token embeddings plus learned absolute position embeddings feed
    ``config.num_hidden_layers`` blocks. Each block runs self-attention, then a
    two-layer ELU feed-forward network, each followed by a residual add and
    LayerNorm. A final LayerNorm and a dense head then produce
    ``config.vocab_size`` logits per position.

    ``call`` expects a mapping with an integer ``'input_ids'`` tensor whose last
    axis is the sequence axis (presumably ``(batch, seq_len)`` -- TODO confirm
    with the data pipeline). It optionally takes an ``'attention_mask'`` tensor
    of the same shape, in which non-zero entries mark tokens that may be
    attended to.

    Args:
        config: A ``MiniSunConfig`` (any object with the same attributes works).
        use_causal_mask: If True (the default), position ``i`` may only attend
            to positions ``<= i``, as a decoder trained for next-token
            prediction requires. Pass False for fully bidirectional attention,
            which is what earlier versions of this class always did.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, config, use_causal_mask=True, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.use_causal_mask = use_causal_mask

        # Layer creation order is unchanged so auto-generated layer/weight names
        # (and therefore existing checkpoints) stay compatible.
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)

        self.decoder_blocks = [self._build_decoder_block() for _ in range(config.num_hidden_layers)]

        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())

    def _build_decoder_block(self):
        """Create the layers of one block.

        Returns:
            ``[attention, layer_norm, ffn_in, ffn_out, dropout]``, which ``call``
            unpacks positionally, so the order must not change.
        """
        # NOTE(review): Keras' ``key_dim`` is the size of *each* head, so every head
        # projects to the full ``hidden_size`` (num_heads times wider than the usual
        # ``hidden_size // num_heads``). Left as-is because it fixes weight shapes
        # of existing checkpoints.
        # NOTE(review): the single LayerNormalization below is shared by both residual
        # sub-layers in ``call``; standard blocks use two. Also kept for checkpoint
        # compatibility.
        return [
            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads,
                                      key_dim=self.config.hidden_size,
                                      kernel_initializer=initializers.he_normal()),
            layers.LayerNormalization(epsilon=1e-6),
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal()),
            layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal()),
            layers.Dropout(self.config.dropout_rate),
        ]

    def call(self, inputs, attention_mask=None, training=False):
        """Run the model and return vocabulary logits.

        Args:
            inputs: Mapping with ``'input_ids'`` and optionally ``'attention_mask'``.
            attention_mask: Optional padding mask. When given, it takes precedence
                over ``inputs['attention_mask']``. Non-zero / True entries mark key
                positions that may be attended to.
            training: Enables dropout when True.

        Returns:
            Logits whose last axis has size ``config.vocab_size``.
        """
        input_ids = inputs['input_ids']
        # ``fit``/``evaluate``/``predict`` only pass ``inputs``, so the mask stored in it
        # must be picked up here; otherwise padding tokens are never masked.
        if attention_mask is None and 'attention_mask' in inputs:
            attention_mask = inputs['attention_mask']

        # NOTE(review): sequences longer than config.max_position_embeddings index past
        # the position table.
        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)
        hidden_states = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        if attention_mask is not None:
            # Key-padding mask reshaped to (batch, 1, key_len). This broadcasts over every
            # query position to the (B, T, S) shape MultiHeadAttention documents, and can
            # be combined with the causal mask.
            attention_mask = tf.cast(attention_mask[:, tf.newaxis, :], dtype=tf.bool)

        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
            attn_output = mha(hidden_states, hidden_states,
                              attention_mask=attention_mask,
                              use_causal_mask=self.use_causal_mask,
                              training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = norm(attn_output + hidden_states)

            ffn_output = ffn2(ffn1(hidden_states))
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = norm(ffn_output + hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return self.lm_head(hidden_states)

    def get_config(self):
        """Return a serializable config from which ``from_config`` rebuilds the model."""
        return {
            # Copy so callers mutating the returned dict cannot alter ``self.config``.
            'config': dict(self.config.__dict__),
            'use_causal_mask': self.use_causal_mask,
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from the dict produced by ``get_config``.

        Configs saved before ``use_causal_mask`` existed get the causal default.
        """
        return cls(MiniSunConfig(**config['config']),
                   use_causal_mask=config.get('use_causal_mask', True))

    def train_step(self, data):
        """Run one optimization step on a batch and return the current metric values.

        Args:
            data: ``(inputs, labels)`` or ``(inputs, labels, sample_weight)`` as yielded
                to ``fit``. ``inputs`` is the mapping consumed by ``call``, including
                its optional ``'attention_mask'``.

        Returns:
            Dict mapping each metric name to its current result.
        """
        inputs, labels, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

        with tf.GradientTape() as tape:
            logits = self(inputs, training=True)
            # Keywords are required: the signature is
            # compute_loss(x=None, y=None, y_pred=None, sample_weight=None).
            loss = self.compute_loss(x=inputs, y=labels, y_pred=logits,
                                     sample_weight=sample_weight)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # The "loss" tracker averages scalar loss values; every other metric compares
        # labels with predictions.
        for metric in self.metrics:
            if metric.name == 'loss':
                metric.update_state(loss)
            else:
                metric.update_state(labels, logits, sample_weight=sample_weight)

        return {m.name: m.result() for m in self.metrics}
|
|
|
def create_model(config):
    """Build a ``MiniSunModel`` and compile it for training.

    The model is compiled with AdamW (learning rate and weight decay read from
    ``config``), sparse categorical cross-entropy on raw logits, and an
    ``'accuracy'`` metric.

    Args:
        config: A ``MiniSunConfig``. This function reads ``learning_rate`` and
            ``weight_decay``; the model reads the remaining fields.

    Returns:
        The compiled ``MiniSunModel``.
    """
    model = MiniSunModel(config)

    compile_kwargs = dict(
        optimizer=tf.keras.optimizers.AdamW(
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
        ),
        # The LM head has no softmax, so the loss must treat its output as logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    model.compile(**compile_kwargs)
    return model
|
|
|
|
|
# Module-level default configuration and a compiled model built from it.
# NOTE(review): this runs at import time; anything importing this module pays for
# building the model.
config = MiniSunConfig()
model = create_model(config=config)