# File size: 7,426 Bytes
# c42ab1b
# @title Model Architecture
import math

import tensorflow as tf
from tensorflow.keras import layers, activations, initializers
class MiniSunConfig:
    """Hyperparameter bundle for the MiniSun model and its training run.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name, in the order the arguments are declared.

    Note that ``warmup_steps`` is multiplied by ``total_steps`` by the
    learning-rate schedule functions in this file, i.e. it is consumed as a
    fraction of the run rather than as an absolute step count.
    """

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=1024,
        hidden_size=512,
        num_attention_heads=8,
        intermediate_size=2048,
        num_hidden_layers=8,
        dropout_rate=0.1,
        weight_decay=0.01,
        learning_rate=1e-4,
        total_steps=2500,
        warmup_steps=0.2,
    ):
        settings = {
            # Architecture
            'vocab_size': vocab_size,
            'max_position_embeddings': max_position_embeddings,
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'intermediate_size': intermediate_size,
            'num_hidden_layers': num_hidden_layers,
            # Regularisation
            'dropout_rate': dropout_rate,
            'weight_decay': weight_decay,
            # Optimisation schedule
            'learning_rate': learning_rate,
            'total_steps': total_steps,
            'warmup_steps': warmup_steps,
        }
        for attribute, value in settings.items():
            setattr(self, attribute, value)
@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    """Transformer language model: embeddings, attention blocks, vocabulary head.

    The forward pass consumes a dict holding ``'input_ids'`` (and optionally
    ``'attention_mask'``) and returns logits produced by a Dense layer with
    ``config.vocab_size`` units.

    NOTE(review): the attention layers only ever receive the padding mask; no
    causal mask is applied even though the blocks are labelled "decoder".
    Confirm that this is intended before using the model autoregressively.
    """

    def __init__(self, config):
        """Create all sub-layers from ``config`` (a ``MiniSunConfig``-like object)."""
        super().__init__()
        self.config = config
        # Embedding layers for token and position
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)
        # Transformer decoder blocks
        self.decoder_blocks = [self._build_decoder_block() for _ in range(config.num_hidden_layers)]
        # Final normalization and head
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())

    def _build_decoder_block(self):
        """Return the layers of one block as ``[mha, norm, ffn1, ffn2, dropout]``.

        NOTE(review): ``call`` applies the single ``norm`` (and ``dropout``)
        instance after both the attention and the feed-forward sub-layer, so
        the two residual branches share one set of normalisation weights.
        Left unchanged here to keep the variable layout of the model stable.
        """
        return [
            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads,
                                      key_dim=self.config.hidden_size,
                                      kernel_initializer=initializers.he_normal()),
            layers.LayerNormalization(epsilon=1e-6),
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal()),
            layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal()),
            layers.Dropout(self.config.dropout_rate),
        ]

    def call(self, inputs, attention_mask=None, training=False):
        """Run the forward pass and return the logits.

        Args:
            inputs: Dict with an ``'input_ids'`` tensor; may also carry an
                ``'attention_mask'`` entry.
            attention_mask: Optional mask indexed as ``[batch, seq_len]``.
                When omitted, ``inputs['attention_mask']`` is used if present,
                so a mask shipped with the batch is honoured by ``fit``,
                ``evaluate`` and ``predict`` alike.
            training: Forwarded to the attention and dropout layers.

        Returns:
            The output of the vocabulary head for every position.
        """
        input_ids = inputs['input_ids']
        if attention_mask is None and 'attention_mask' in inputs:
            attention_mask = inputs['attention_mask']

        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)
        # Token and position embeddings (position term broadcasts over the batch)
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # Adjust attention mask to shape [batch_size, 1, 1, seq_len] so it
        # broadcasts over heads and query positions.
        if attention_mask is not None:
            attention_mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32)

        # Apply decoder blocks
        hidden_states = embeddings
        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask, training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = norm(attn_output + hidden_states)  # Add & Norm
            # Feed-forward layers
            ffn_output = ffn2(ffn1(hidden_states))
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = norm(ffn_output + hidden_states)  # Add & Norm

        # Final layer normalization, then the LM head
        hidden_states = self.layer_norm(hidden_states)
        return self.lm_head(hidden_states)

    def get_config(self):
        """Return the model configuration as ``{'config': {...}}``.

        A copy of the attribute dict is returned so that callers cannot
        mutate the live configuration object through the result.
        """
        return {'config': dict(self.config.__dict__)}

    @classmethod
    def from_config(cls, config):
        """Create a model from the dict produced by ``get_config``."""
        return cls(MiniSunConfig(**config['config']))

    def compute_loss(self, labels, logits):
        """Return the unreduced sparse categorical cross-entropy.

        NOTE(review): this overrides ``tf.keras.Model.compute_loss`` with a
        different signature, which is why ``train_step`` and ``test_step``
        are both customised below instead of relying on the Keras defaults.

        Raises:
            ValueError: If ``labels`` or ``logits`` is ``None``.
        """
        if labels is None or logits is None:
            raise ValueError("Labels and logits cannot be None.")
        return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

    def _update_metrics(self, labels, logits, loss):
        """Update tracked metrics and return the logs dict for one step.

        The metric named ``'loss'`` is fed the scalar loss; every other
        metric receives ``(labels, logits)`` unmodified so that Keras can
        pick the accuracy variant that matches sparse labels and logits.
        """
        for metric in self.metrics:
            if metric.name == 'loss':
                metric.update_state(loss)
            else:
                metric.update_state(labels, logits)
        logs = {m.name: m.result() for m in self.metrics}
        # Always report a loss, even when no loss tracker is registered.
        logs.setdefault('loss', loss)
        return logs

    def train_step(self, data):
        """Run one optimisation step on ``(inputs, labels)`` and return the logs."""
        inputs, labels = data
        with tf.GradientTape() as tape:
            logits = self(inputs, training=True)
            # Reduce to a scalar so the gradient scale does not grow with
            # the number of tokens in the batch.
            loss = tf.reduce_mean(self.compute_loss(labels, logits))
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return self._update_metrics(labels, logits, loss)

    def test_step(self, data):
        """Evaluate one batch of ``(inputs, labels)`` and return the logs.

        Needed because the default Keras ``test_step`` calls ``compute_loss``
        with Keras' own argument list, which the override above does not accept.
        """
        inputs, labels = data
        logits = self(inputs, training=False)
        loss = tf.reduce_mean(self.compute_loss(labels, logits))
        return self._update_metrics(labels, logits, loss)
def create_model(config):
    """Instantiate a ``MiniSunModel`` for ``config`` and compile it.

    The model is compiled with an AdamW optimizer (learning rate and weight
    decay taken from ``config``), sparse categorical cross-entropy on logits,
    and the ``'accuracy'`` metric.

    Returns:
        The compiled model.
    """
    compiled_model = MiniSunModel(config)
    compile_options = {
        # Optimizer with weight decay
        'optimizer': tf.keras.optimizers.AdamW(
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
        ),
        'loss': tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        'metrics': ['accuracy'],
    }
    compiled_model.compile(**compile_options)
    return compiled_model
def cosine_annealing_with_warmup(step, config):
    """Return the learning rate for ``step``: linear warm-up, then cosine decay.

    Args:
        step: Position in the schedule, as a plain Python number.
        config: Object exposing ``learning_rate`` (the peak rate),
            ``total_steps`` and ``warmup_steps``. ``warmup_steps`` is used as
            a fraction of ``total_steps``, not as an absolute count.

    Returns:
        A Python ``float``. The rate grows linearly from 0 to the peak during
        warm-up, follows half a cosine down to 0 at ``total_steps`` and stays
        at 0 afterwards.
    """
    warmup_steps = int(config.total_steps * config.warmup_steps)
    if step < warmup_steps:
        return float(config.learning_rate * (step / warmup_steps))

    # Guard against a zero-length decay phase (warm-up covering the whole run).
    decay_span = max(config.total_steps - warmup_steps, 1)
    # Clamp so the cosine does not swing back up once total_steps is exceeded.
    progress = min((step - warmup_steps) / decay_span, 1.0)
    return float(0.5 * config.learning_rate * (1.0 + math.cos(math.pi * progress)))
def cosine_annealing_with_restarts(step, config, restart_period, cycle_num):
    """Return the learning rate for ``step`` under warm-up + cosine decay with restarts.

    The schedule repeats every ``restart_period`` steps: within each cycle the
    rate warms up linearly and then decays along half a cosine towards 0.

    Args:
        step: Position in the schedule, as a plain Python number.
        config: Object exposing ``learning_rate`` (the peak rate),
            ``total_steps`` and ``warmup_steps``. ``warmup_steps`` is used as
            a fraction of ``total_steps``, not as an absolute count.
        restart_period: Length of one cycle in steps.
        cycle_num: Currently unused; kept so existing callers keep working.

    Returns:
        A Python ``float`` learning rate.
    """
    warmup_steps = int(config.total_steps * config.warmup_steps)
    # Position inside the current cycle.
    effective_step = step % restart_period
    if effective_step < warmup_steps:
        return float(config.learning_rate * (effective_step / warmup_steps))

    # Reaching this point with a positive period implies
    # warmup_steps <= effective_step < restart_period, so the span is positive.
    decay_span = restart_period - warmup_steps
    progress = (effective_step - warmup_steps) / decay_span
    return float(0.5 * config.learning_rate * (1.0 + math.cos(math.pi * progress)))
# Default hyperparameters for this run.
config = MiniSunConfig()

# Build and compile the model; its Dense and attention kernels use He-normal
# initialization.
model = create_model(config)

# Callback that queries the warm-up + cosine schedule for the learning rate.
# NOTE(review): LearningRateScheduler hands its schedule function the epoch
# index, whereas the schedule is expressed in steps (config.total_steps) --
# confirm that driving it once per epoch is what is intended.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch_index: cosine_annealing_with_warmup(epoch_index, config)
)
# Alternative schedule with restarts (disabled):
# lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_restarts(step, config, restart_period=1000, cycle_num=1))