File size: 8,895 Bytes
6aeb9de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import tensorflow as tf
from tensorflow.keras import layers, activations, initializers, regularizers
import numpy as np
# Define RMSNorm
class RMSNorm(tf.keras.layers.Layer):
    """Root-mean-square normalization over the last axis.

    Each vector along the last axis is divided by the square root of its mean
    squared value. The layer has no trainable weights (no learnable gain).

    Args:
        epsilon: Small constant added to the mean square before taking the
            square root, keeping both the division and its gradient finite.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def call(self, inputs):
        """Return ``inputs`` rescaled by the inverse RMS of its last axis."""
        mean_square = tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True)
        # epsilon sits inside the root: sqrt has an unbounded derivative at
        # zero, so an all-zero vector would otherwise yield NaN gradients.
        return inputs / tf.sqrt(mean_square + self.epsilon)

    def get_config(self):
        """Return the layer config, including ``epsilon``, for serialization."""
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
class MiniSunConfig:
    """Plain container for MiniSun model and training hyperparameters.

    Every constructor argument is stored unchanged as an instance attribute of
    the same name, so ``vars(config)`` can be passed straight back as keyword
    arguments to rebuild an equal configuration.
    """

    def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
                 num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4, total_steps=2500,
                 warmup_ratio=0.5, restart_period=500, l1_reg=0.0, l2_reg=0.01):
        attributes = vars(self)
        # Model dimensions.
        attributes.update(
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
        )
        # Optimisation and learning-rate schedule settings.
        attributes.update(
            dropout_rate=dropout_rate,
            weight_decay=weight_decay,
            learning_rate=learning_rate,
            total_steps=total_steps,
            warmup_ratio=warmup_ratio,
            restart_period=restart_period,
        )
        # Kernel regularisation strengths (L1 and L2).
        attributes.update(l1_reg=l1_reg, l2_reg=l2_reg)
@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    """Transformer language model built from a ``MiniSunConfig``.

    The forward pass sums token embeddings with learned absolute position
    embeddings, runs a stack of self-attention + feed-forward blocks with
    residual connections and RMSNorm, applies a final RMSNorm and projects to
    vocabulary logits.

    Args:
        config: ``MiniSunConfig`` carrying the model and training settings.
    """

    # Label-smoothing factor applied by ``compute_loss``.
    _LABEL_SMOOTHING = 0.1

    def __init__(self, config):
        super(MiniSunModel, self).__init__()
        self.config = config
        # Token embeddings plus learned absolute position embeddings.
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)
        # Final normalization and vocabulary head.
        self.layer_norm = RMSNorm(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal(),
                                    kernel_regularizer=regularizers.l2(config.l2_reg))
        # Not used by call(); kept so existing attribute access keeps working.
        self.layer_dropout = tf.keras.layers.Dropout(config.dropout_rate)
        # Sub-layers are created here rather than in build() so Keras tracks
        # them as soon as the model exists.
        self.decoder_blocks = [self._build_decoder_block() for _ in range(config.num_hidden_layers)]

    def _build_decoder_block(self):
        """Return one block as ``[attention, norm, ffn_in, ffn_out, dropout]``."""
        # NOTE(review): key_dim is the per-head size; passing hidden_size makes
        # every head hidden_size wide. hidden_size // num_attention_heads may
        # have been intended, but changing it alters weight shapes - confirm.
        return [
            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads,
                                      key_dim=self.config.hidden_size,
                                      kernel_initializer=initializers.he_normal(),
                                      kernel_regularizer=regularizers.l2(self.config.l2_reg)),
            RMSNorm(epsilon=1e-6),  # RMSNorm instead of LayerNormalization
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal(),
                         kernel_regularizer=regularizers.l1_l2(self.config.l1_reg, self.config.l2_reg)),
            layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal(),
                         kernel_regularizer=regularizers.l1_l2(self.config.l1_reg, self.config.l2_reg)),
            layers.Dropout(self.config.dropout_rate),
        ]

    def call(self, inputs, attention_mask=None, training=False):
        """Run the forward pass.

        Args:
            inputs: Mapping with key ``'input_ids'`` and, optionally,
                ``'attention_mask'``.
            attention_mask: Optional mask indexed as ``[batch, seq_len]``. When
                omitted, ``inputs['attention_mask']`` is used if present.
            training: Whether dropout is active.

        Returns:
            Tuple ``(logits, softmax_output)``; the second item is the softmax
            of the first over the last axis.
        """
        input_ids = inputs['input_ids']
        if attention_mask is None:
            # fit()/predict() only pass `inputs`, so pick the mask up from the
            # feature dict instead of silently attending to every position.
            attention_mask = inputs.get('attention_mask')

        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # Reshape the mask to [batch_size, 1, 1, seq_len] so it broadcasts
        # over heads and query positions.
        if attention_mask is not None:
            attention_mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32)

        # NOTE(review): no causal mask is applied, so every position can attend
        # to later positions. If this model is trained for next-token
        # prediction that leaks the targets - confirm the intended objective.
        hidden_states = embeddings
        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
            # The same norm and dropout layers serve both sub-layers of a block.
            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask, training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = norm(attn_output + hidden_states)  # Add & RMSNorm
            # Feed-forward layers
            ffn_output = ffn2(ffn1(hidden_states))
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = norm(ffn_output + hidden_states)  # Add & RMSNorm

        # Final normalization and vocabulary projection.
        hidden_states = self.layer_norm(hidden_states)
        logits = self.lm_head(hidden_states)
        softmax_output = tf.nn.softmax(logits, axis=-1)
        return logits, softmax_output

    def get_config(self):
        """Return the constructor config as ``{'config': {...}}``."""
        # Copy so callers cannot mutate the live configuration through it.
        return {'config': dict(vars(self.config))}

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from the mapping produced by ``get_config``."""
        return cls(MiniSunConfig(**config['config']))

    def compute_loss(self, labels, logits):
        """Return the label-smoothed cross-entropy for each label position.

        NOTE(review): this shares its name with Keras' built-in
        ``Model.compute_loss`` hook but takes different arguments; the default
        evaluation step may not be able to call it - verify before relying on
        ``evaluate()``.

        Args:
            labels: Integer class ids.
            logits: Unnormalized scores with the class axis last.

        Returns:
            Loss tensor with the class axis reduced away.

        Raises:
            ValueError: If ``labels`` or ``logits`` is ``None``.
        """
        if labels is None or logits is None:
            raise ValueError("Labels and logits cannot be None.")
        # sparse_categorical_crossentropy has no `label_smoothing` argument, so
        # the smoothing is applied by hand: with the target distribution
        # (1 - s) * one_hot + s / K, the cross-entropy splits into the usual
        # sparse term plus s times the mean negative log-probability.
        smoothing = self._LABEL_SMOOTHING
        hard_loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
        uniform_loss = -tf.reduce_mean(tf.nn.log_softmax(logits, axis=-1), axis=-1)
        return (1.0 - smoothing) * hard_loss + smoothing * uniform_loss

    def train_step(self, data):
        """Run one optimisation step on ``(inputs, labels)``.

        Args:
            data: Tuple ``(inputs, labels)`` where ``inputs`` is the mapping
                accepted by ``call``.

        Returns:
            Dict mapping metric names to their current results, plus ``'loss'``.
        """
        inputs, labels = data
        with tf.GradientTape() as tape:
            # call() picks the attention mask up from `inputs` when present.
            logits, _ = self(inputs, training=True)
            # Reduce to a scalar explicitly instead of letting the tape sum the
            # per-token losses, so the gradient scale does not grow with
            # batch size * sequence length before clipping.
            loss = tf.reduce_mean(self.compute_loss(labels, logits))
            # Kernel regularizers only take effect if their penalties are
            # added to the optimised loss.
            for penalty in self.losses:
                loss += tf.cast(penalty, loss.dtype)
        # NOTE(review): no loss scaling / gradient unscaling happens here;
        # confirm that is intended if a LossScaleOptimizer is in use.
        trainable = self.trainable_variables
        gradients = tape.gradient(loss, trainable)
        # Gradient clipping for stability; variables without a gradient are
        # skipped because clip_by_value cannot take None.
        clipped = [
            (tf.clip_by_value(grad, -1.0, 1.0), var)
            for grad, var in zip(gradients, trainable)
            if grad is not None
        ]
        self.optimizer.apply_gradients(clipped)

        for metric in self.metrics:
            if metric.name == "loss":
                # A loss tracker averages scalar values, not (labels, preds).
                metric.update_state(loss)
            else:
                # Labels and logits keep their shapes so that 'accuracy'
                # resolves to sparse categorical accuracy.
                metric.update_state(labels, logits)
        results = {m.name: m.result() for m in self.metrics}
        results.setdefault("loss", loss)
        return results
def create_model(config):
    """Build and compile a ``MiniSunModel`` from ``config``.

    The optimizer is AdamW, using ``config.learning_rate`` and
    ``config.weight_decay``, wrapped in a ``LossScaleOptimizer``. Compilation
    runs inside the scope of the current distribution strategy.

    Args:
        config: Configuration passed to ``MiniSunModel``.

    Returns:
        The compiled model.
    """
    model = MiniSunModel(config)

    # NOTE(review): this function sets no mixed-precision policy; confirm that
    # wrapping in LossScaleOptimizer is still wanted with the custom train_step.
    adamw = tf.keras.optimizers.AdamW(
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(adamw)

    with tf.distribute.get_strategy().scope():
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])
    return model
def cosine_annealing_with_warmup(step, config):
    """Learning rate schedule with linear warm-up and cosine annealing.

    The rate rises linearly from 0 to ``config.learning_rate`` over the first
    ``int(config.total_steps * config.warmup_ratio)`` steps, then follows half
    a cosine down to 0 at ``config.total_steps`` and stays there afterwards.

    Args:
        step: Zero-based position in the schedule.
        config: Object exposing ``learning_rate``, ``total_steps`` and
            ``warmup_ratio``.

    Returns:
        The learning rate as a Python ``float``.
    """
    peak_lr = config.learning_rate
    warmup_steps = int(config.total_steps * config.warmup_ratio)
    if step < warmup_steps:
        return float(peak_lr * (step / warmup_steps))

    decay_steps = config.total_steps - warmup_steps
    if decay_steps <= 0:
        # No annealing phase is configured: hold the peak rate instead of
        # dividing by zero.
        return float(peak_lr)

    # Clamp so the rate stays at 0 past total_steps; the cosine is periodic
    # and would otherwise climb back up.
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    # NumPy rather than TensorFlow so the result is a plain float, which is
    # what a Keras LearningRateScheduler schedule is expected to return.
    return float(0.5 * peak_lr * (1.0 + np.cos(np.pi * progress)))
def cosine_annealing_with_restarts(step, config):
    """Learning rate schedule with warm-up and cosine annealing with restarts.

    The schedule repeats every ``config.restart_period`` steps. Within each
    cycle the rate rises linearly from 0 to ``config.learning_rate`` over the
    first ``int(config.restart_period * config.warmup_ratio)`` steps and then
    follows half a cosine back towards 0 before the next restart.

    Args:
        step: Zero-based position in the schedule.
        config: Object exposing ``learning_rate``, ``restart_period`` and
            ``warmup_ratio``.

    Returns:
        The learning rate as a Python ``float``.
    """
    period = config.restart_period
    # Warm-up is sized relative to one cycle. Sizing it from total_steps can
    # exceed the cycle length (1250 vs 500 with the default config), which
    # makes the cosine phase unreachable and keeps the rate below its peak.
    warmup_steps = int(period * config.warmup_ratio)
    cycle_step = step % period
    if cycle_step < warmup_steps:
        return float(config.learning_rate * (cycle_step / warmup_steps))

    # Reaching this point implies warmup_steps <= cycle_step < period, so the
    # divisor below is at least 1.
    progress = (cycle_step - warmup_steps) / (period - warmup_steps)
    # NumPy rather than TensorFlow so the result is a plain float.
    return float(0.5 * config.learning_rate * (1.0 + np.cos(np.pi * progress)))
# Configuration
config = MiniSunConfig(l1_reg=1e-5, l2_reg=3e-4)
# Initialize model with improvements
model = create_model(config)
# Create LearningRateScheduler callbacks.
# NOTE: Keras invokes a LearningRateScheduler schedule once per epoch and
# passes the epoch index, so the schedules below advance per epoch, not per
# batch.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: cosine_annealing_with_warmup(epoch, config))
lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: cosine_annealing_with_restarts(epoch, config))