File size: 7,667 Bytes
5be2e91 5875b95 5be2e91 0bf0be0 5be2e91 0bf0be0 5be2e91 29085b2 3181a38 29085b2 3181a38 5be2e91 0bf0be0 5be2e91 0bf0be0 5be2e91 5875b95 5be2e91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
# @title Model Architecture
import tensorflow as tf
from tensorflow.keras import layers, activations, initializers
import numpy as np
class MiniSunConfig:
    """Hyper-parameter container for the MiniSun model and its LR schedules.

    Every constructor argument is stored unchanged as an attribute of the
    same name, in declaration order, so ``MiniSunConfig(**vars(cfg))``
    rebuilds an equivalent configuration.

    Args:
        vocab_size: Number of rows in the token embedding / width of the LM head.
        max_position_embeddings: Number of rows in the position embedding.
        hidden_size: Width of the embeddings and of each block's output.
        num_attention_heads: Heads per multi-head attention layer.
        intermediate_size: Width of the first feed-forward layer in a block.
        num_hidden_layers: Number of decoder blocks.
        dropout_rate: Rate passed to each block's dropout layer.
        weight_decay: Weight decay passed to the optimizer.
        learning_rate: Peak learning rate used by the optimizer and schedules.
        total_steps: Schedule length used by the warm-up/cosine schedule.
        warmup_ratio: Fraction of the schedule spent in linear warm-up.
        restart_period: Cycle length used by the schedule with restarts.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        max_position_embeddings: int = 1024,
        hidden_size: int = 512,
        num_attention_heads: int = 8,
        intermediate_size: int = 2048,
        num_hidden_layers: int = 8,
        dropout_rate: float = 0.1,
        weight_decay: float = 0.01,
        learning_rate: float = 1e-4,
        total_steps: int = 2500,
        warmup_ratio: float = 0.5,
        restart_period: int = 500,
    ) -> None:
        # Assignment order is kept stable because the instance ``__dict__``
        # is what gets serialized for the model configuration.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate
        self.total_steps = total_steps
        self.warmup_ratio = warmup_ratio
        self.restart_period = restart_period

    def __repr__(self) -> str:
        """Return a constructor-style representation listing every field."""
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"
@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    """Transformer language model assembled from a ``MiniSunConfig``.

    The model sums token and position embeddings, runs them through
    ``config.num_hidden_layers`` attention + feed-forward blocks, applies a
    final layer normalization and projects to vocabulary logits.

    Args:
        config: Object exposing the ``MiniSunConfig`` attributes.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        # Embedding layers for token and position
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)
        # Populated in build(); each entry is [mha, norm, ffn1, ffn2, dropout]
        self.decoder_blocks = []
        # Final normalization and head
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())

    def build(self, input_shape):
        """Create the decoder blocks, then defer to the Keras ``build``."""
        # Guarded so that a repeated build() call does not replace existing
        # blocks (and their weights) with freshly initialized ones.
        if not self.decoder_blocks:
            self.decoder_blocks = [
                self._build_decoder_block()
                for _ in range(self.config.num_hidden_layers)
            ]
        super().build(input_shape)

    def _build_decoder_block(self):
        """Return the layers of one block as ``[mha, norm, ffn1, ffn2, dropout]``."""
        he_normal = initializers.he_normal
        return [
            layers.MultiHeadAttention(
                num_heads=self.config.num_attention_heads,
                key_dim=self.config.hidden_size,
                kernel_initializer=he_normal(),
            ),
            layers.LayerNormalization(epsilon=1e-6),
            layers.Dense(
                self.config.intermediate_size,
                activation=activations.elu,
                kernel_initializer=he_normal(),
            ),
            layers.Dense(self.config.hidden_size, kernel_initializer=he_normal()),
            layers.Dropout(self.config.dropout_rate),
        ]

    def call(self, inputs, attention_mask=None, training=False):
        """Run the forward pass.

        Args:
            inputs: Mapping with an ``'input_ids'`` entry of token ids.
            attention_mask: Optional mask; it is indexed as
                ``[:, newaxis, newaxis, :]``, so a rank-2 tensor is expected
                (presumably ``[batch, seq_len]`` with 1 = attend — TODO confirm).
            training: Forwarded to the attention and dropout layers.

        Returns:
            Tuple ``(logits, softmax_output)`` where ``softmax_output`` is the
            softmax of ``logits`` over the last axis.
        """
        input_ids = inputs['input_ids']
        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)
        # Position embeddings carry no batch axis and broadcast over the batch.
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # Insert two singleton axes so the mask broadcasts against the
        # attention scores.
        if attention_mask is not None:
            attention_mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32)

        # NOTE(review): no causal mask is applied, so every position can
        # attend to later positions — confirm this is intended for the
        # training objective in use.
        # NOTE(review): each block reuses the same LayerNormalization and
        # Dropout instances for both sub-layers; kept as-is so the set of
        # weights created by the model does not change.
        hidden_states = embeddings
        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask, training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = norm(attn_output + hidden_states)  # Add & Norm

            ffn_output = ffn2(ffn1(hidden_states))
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = norm(ffn_output + hidden_states)  # Add & Norm

        hidden_states = self.layer_norm(hidden_states)
        logits = self.lm_head(hidden_states)
        softmax_output = tf.nn.softmax(logits, axis=-1)
        return logits, softmax_output

    def get_config(self):
        """Return ``{'config': {...}}`` holding a copy of the configuration attributes."""
        # Copy so that callers mutating the result cannot alter self.config.
        return {'config': dict(vars(self.config))}

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from the mapping produced by ``get_config``."""
        return cls(MiniSunConfig(**config['config']))

    def compute_loss(self, labels, logits):
        """Return the sparse categorical cross-entropy of ``logits`` against ``labels``.

        The result is not reduced; ``train_step`` averages it.

        NOTE(review): this shares its name with ``tf.keras.Model.compute_loss``
        but takes different parameters — confirm no built-in Keras code path
        (e.g. evaluation) needs to call it with the Keras signature.

        Raises:
            ValueError: If ``labels`` or ``logits`` is ``None``.
        """
        if labels is None or logits is None:
            raise ValueError("Labels and logits cannot be None.")
        return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

    def train_step(self, data):
        """Run one optimization step on an ``(inputs, labels)`` batch.

        Args:
            data: Tuple ``(inputs, labels)``; ``inputs`` is a mapping with
                ``'input_ids'`` and, optionally, ``'attention_mask'``.

        Returns:
            Dict mapping metric names to their current results; a ``'loss'``
            entry is always present.
        """
        inputs, labels = data
        attention_mask = inputs.get('attention_mask')

        with tf.GradientTape() as tape:
            # call() returns (logits, softmax); only the logits feed the loss.
            logits, _ = self(inputs, attention_mask=attention_mask, training=True)
            # Reduce to a scalar so gradients correspond to the mean loss.
            loss = tf.reduce_mean(self.compute_loss(labels, logits))

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        for metric in self.metrics:
            if metric.name == "loss":
                # The loss tracker averages loss values, not (y, y_pred) pairs.
                metric.update_state(loss)
            else:
                metric.update_state(labels, logits)

        results = {m.name: m.result() for m in self.metrics}
        # Report the batch loss even when no loss-tracking metric is exposed.
        results.setdefault("loss", loss)
        return results
def create_model(config):
    """Instantiate a ``MiniSunModel`` for ``config`` and compile it.

    The model is compiled with an AdamW optimizer (learning rate and weight
    decay taken from ``config``), a from-logits sparse categorical
    cross-entropy loss and the ``'accuracy'`` metric. Compilation happens
    inside the scope of the strategy returned by
    ``tf.distribute.get_strategy()``.

    Returns:
        The compiled model.
    """
    net = MiniSunModel(config)

    # Optimizer with weight decay
    adamw = tf.keras.optimizers.AdamW(
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    with tf.distribute.get_strategy().scope():
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        net.compile(optimizer=adamw, loss=cross_entropy, metrics=['accuracy'])

    return net
def cosine_annealing_with_warmup(step, config):
    """Return the learning rate at ``step``: linear warm-up, then cosine decay.

    The first ``int(config.total_steps * config.warmup_ratio)`` steps ramp
    linearly from 0 up to ``config.learning_rate``; the remaining steps
    follow half a cosine wave down to 0. Past ``config.total_steps`` the
    rate stays at 0 instead of following the cosine back up.

    Args:
        step: Current step index.
        config: Object exposing ``learning_rate``, ``total_steps`` and
            ``warmup_ratio``.

    Returns:
        The learning rate as a Python ``float``.
    """
    peak_lr = config.learning_rate
    warmup_steps = int(config.total_steps * config.warmup_ratio)

    if step < warmup_steps:
        return float(peak_lr * (step / warmup_steps))

    decay_steps = config.total_steps - warmup_steps
    if decay_steps <= 0:
        # Warm-up spans the whole schedule: hold the peak rate rather than
        # dividing by zero.
        return float(peak_lr)

    # Clamp so the cosine never wraps around after the schedule has ended.
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return float(0.5 * peak_lr * (1.0 + np.cos(np.pi * progress)))
def cosine_annealing_with_restarts(step, config):
    """Return the learning rate at ``step`` for a warm-up/cosine schedule that restarts.

    The schedule repeats every ``config.restart_period`` steps. Within each
    cycle the first ``int(config.restart_period * config.warmup_ratio)``
    steps ramp linearly from 0 to ``config.learning_rate`` and the rest of
    the cycle follows half a cosine wave back towards 0.

    Args:
        step: Current step index.
        config: Object exposing ``learning_rate``, ``warmup_ratio`` and
            ``restart_period``.

    Returns:
        The learning rate as a Python ``float``.
    """
    peak_lr = config.learning_rate
    period = config.restart_period
    # Warm-up is sized relative to one cycle. Sizing it from total_steps
    # (1250 with the default config) exceeds the default 500-step period,
    # which made the cosine branch unreachable.
    warmup_steps = int(period * config.warmup_ratio)
    cycle_step = step % period

    if cycle_step < warmup_steps:
        return float(peak_lr * (cycle_step / warmup_steps))

    # Reaching this point implies warmup_steps <= cycle_step < period, so
    # the decay phase is at least one step long.
    progress = (cycle_step - warmup_steps) / (period - warmup_steps)
    return float(0.5 * peak_lr * (1.0 + np.cos(np.pi * progress)))
# Default hyper-parameters shared by the model and both schedules
config = MiniSunConfig()

# Compiled model (dense kernels use He-normal initialization)
model = create_model(config)


# NOTE(review): Keras documents LearningRateScheduler as calling its schedule
# once per epoch with the epoch index, while both schedules below are
# parameterized in steps (config.total_steps / config.restart_period) —
# confirm that feeding the epoch index in as `step` is intended.
def _warmup_schedule(step):
    """Learning rate from the warm-up + cosine schedule for the module config."""
    return cosine_annealing_with_warmup(step, config)


def _restart_schedule(step):
    """Learning rate from the restarting schedule for the module config."""
    return cosine_annealing_with_restarts(step, config)


# LearningRateScheduler callbacks wrapping the two schedules
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(_warmup_schedule)
lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(_restart_schedule)