# File size: 6,524 bytes (revision 1e9934f)
# @title Model Architecture
import tensorflow as tf
from tensorflow.keras import layers, activations, initializers
class MiniSunConfig:
    """Plain container for the MiniSun model and optimizer hyperparameters.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name, in the order the parameters are declared.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        max_position_embeddings: int = 1024,
        hidden_size: int = 512,
        num_attention_heads: int = 8,
        intermediate_size: int = 2048,
        num_hidden_layers: int = 8,
        dropout_rate: float = 0.1,
        weight_decay: float = 0.01,
        learning_rate: float = 1e-4,
    ):
        """Record the given hyperparameters on the instance.

        Args:
            vocab_size: Number of entries in the token vocabulary.
            max_position_embeddings: Number of rows in the position table.
            hidden_size: Width of the embeddings and hidden states.
            num_attention_heads: Number of attention heads per block.
            intermediate_size: Width of the first feed-forward layer.
            num_hidden_layers: Number of stacked blocks.
            dropout_rate: Dropout rate used inside each block.
            weight_decay: Weight decay handed to the optimizer.
            learning_rate: Learning rate handed to the optimizer.
        """
        hyperparameters = {
            "vocab_size": vocab_size,
            "max_position_embeddings": max_position_embeddings,
            "hidden_size": hidden_size,
            "num_attention_heads": num_attention_heads,
            "intermediate_size": intermediate_size,
            "num_hidden_layers": num_hidden_layers,
            "dropout_rate": dropout_rate,
            "weight_decay": weight_decay,
            "learning_rate": learning_rate,
        }
        # Insertion order of the dict above fixes the order of self.__dict__.
        for field_name, field_value in hyperparameters.items():
            setattr(self, field_name, field_value)
@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    """Transformer language model built from a ``MiniSunConfig``.

    The forward pass sums token and position embeddings, runs them through
    ``config.num_hidden_layers`` attention + feed-forward blocks, applies a
    final layer normalization and projects to ``config.vocab_size`` logits.

    NOTE(review): only a padding mask is ever passed to attention; no causal
    mask is applied, so each position can attend to every non-masked
    position. Confirm this is intended for the training objective in use.
    """

    def __init__(self, config):
        """Create all sub-layers.

        Args:
            config: Object exposing the ``MiniSunConfig`` attributes.
        """
        super().__init__()
        self.config = config
        # Embedding layers for token and position
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        # Transformer blocks; each entry is the layer list from _build_decoder_block.
        self.decoder_blocks = [
            self._build_decoder_block() for _ in range(config.num_hidden_layers)
        ]
        # Final normalization and head
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.lm_head = layers.Dense(
            config.vocab_size, kernel_initializer=initializers.he_normal()
        )

    def _build_decoder_block(self):
        """Return the layers of one block as ``[mha, norm, ffn1, ffn2, dropout]``.

        ``call`` unpacks the list in exactly this order.

        NOTE(review): ``key_dim`` is set to the full ``hidden_size`` rather
        than ``hidden_size // num_attention_heads``. It is left as-is because
        changing it would change the attention weight shapes; confirm intent.
        """
        return [
            layers.MultiHeadAttention(
                num_heads=self.config.num_attention_heads,
                key_dim=self.config.hidden_size,
                kernel_initializer=initializers.he_normal(),
            ),
            layers.LayerNormalization(epsilon=1e-6),
            layers.Dense(
                self.config.intermediate_size,
                activation=activations.elu,
                kernel_initializer=initializers.he_normal(),
            ),
            layers.Dense(
                self.config.hidden_size, kernel_initializer=initializers.he_normal()
            ),
            layers.Dropout(self.config.dropout_rate),
        ]

    def call(self, inputs, attention_mask=None, training=False):
        """Run the forward pass and return vocabulary logits.

        Args:
            inputs: Mapping containing ``'input_ids'`` and, optionally,
                ``'attention_mask'``.
            attention_mask: Optional 2-D mask (it is indexed as
                ``[:, newaxis, newaxis, :]``). When omitted, the
                ``'attention_mask'`` entry of ``inputs`` is used if present.
            training: Forwarded to the attention and dropout layers.

        Returns:
            Logits whose last dimension is ``config.vocab_size``.
        """
        input_ids = inputs['input_ids']
        # Keras' fit/evaluate/predict call the model with `inputs` only, so an
        # explicit keyword mask is never supplied there. Fall back to the mask
        # carried in the input dict; otherwise padding would never be masked.
        if attention_mask is None and 'attention_mask' in inputs:
            attention_mask = inputs['attention_mask']

        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)
        # Token and position embeddings (position rows broadcast over the batch)
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # Adjust attention mask to shape [batch_size, 1, 1, seq_len]
        if attention_mask is not None:
            attention_mask = tf.cast(
                attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32
            )

        hidden_states = embeddings
        # NOTE(review): the same `norm` and `dropout` layer instances serve both
        # the attention and the feed-forward sub-layer of a block, so their
        # LayerNorm weights are shared. Kept to preserve the weight layout.
        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
            attn_output = mha(
                hidden_states,
                hidden_states,
                attention_mask=attention_mask,
                training=training,
            )
            attn_output = dropout(attn_output, training=training)
            hidden_states = norm(attn_output + hidden_states)  # Add & Norm

            # Feed-forward layers
            ffn_output = ffn2(ffn1(hidden_states))
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = norm(ffn_output + hidden_states)  # Add & Norm

        # Final layer normalization, then the LM head
        hidden_states = self.layer_norm(hidden_states)
        return self.lm_head(hidden_states)

    def get_config(self):
        """Return the hyperparameters as ``{'config': {...}}``.

        A copy is returned so callers cannot mutate the model's own config
        object through the result.
        """
        return {'config': dict(vars(self.config))}

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from the dictionary produced by ``get_config``."""
        return cls(MiniSunConfig(**config['config']))

    def compute_loss(self, labels, logits):
        """Return the unreduced sparse categorical cross-entropy.

        Args:
            labels: Integer class ids.
            logits: Unnormalized scores; ``from_logits=True`` is used.

        Returns:
            The loss for each label position (no reduction is applied here).

        Raises:
            ValueError: If ``labels`` or ``logits`` is ``None``.

        NOTE(review): Keras' own ``Model.compute_loss`` may use a different
        signature in the installed version; confirm that the default
        evaluation path is compatible with this override.
        """
        if labels is None or logits is None:
            raise ValueError("Labels and logits cannot be None.")
        return tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits, from_logits=True
        )

    def train_step(self, data):
        """Run one optimization step.

        Args:
            data: Tuple ``(inputs, labels)``; ``inputs`` is the mapping
                accepted by ``call``.

        Returns:
            Dict of metric name to value, always including ``'loss'``.
        """
        inputs, labels = data

        with tf.GradientTape() as tape:
            # `call` picks up inputs['attention_mask'] itself when present.
            logits = self(inputs, training=True)
            # Reduce the per-token losses to a scalar mean so the gradient
            # scale does not grow with batch size or sequence length and the
            # value can be reported.
            loss = tf.reduce_mean(self.compute_loss(labels, logits))

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Flatten predicted token ids and labels to [batch_size * sequence_length]
        predicted_ids = tf.reshape(tf.argmax(logits, axis=-1), [-1])
        flat_labels = tf.reshape(labels, [-1])

        for metric in self.metrics:
            if metric.name == 'loss':
                # A loss tracker (if Keras exposes one) must receive the loss,
                # not (labels, predictions).
                metric.update_state(loss)
            else:
                metric.update_state(flat_labels, predicted_ids)

        logs = {m.name: m.result() for m in self.metrics}
        # Guarantee the loss is reported even when no loss tracker exists.
        logs.setdefault('loss', loss)
        return logs
def create_model(config):
    """Build a ``MiniSunModel`` from ``config`` and compile it.

    The model is compiled with an AdamW optimizer configured from
    ``config.learning_rate`` and ``config.weight_decay``, a sparse
    categorical cross-entropy loss on logits, and an ``'accuracy'`` metric.

    Args:
        config: Hyperparameter object passed to ``MiniSunModel`` and read for
            the optimizer settings.

    Returns:
        The compiled model.
    """
    model = MiniSunModel(config)
    # Gather the compile arguments first; the optimizer applies weight decay.
    compile_kwargs = {
        'optimizer': tf.keras.optimizers.AdamW(
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
        ),
        'loss': tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        'metrics': ['accuracy'],
    }
    model.compile(**compile_kwargs)
    return model
# Configuration: every MiniSunConfig default is used.
config = MiniSunConfig()
# Build and compile the model (Dense and attention kernels use He-normal init).
model = create_model(config)