finnstrom3693 committed on
Commit
6aeb9de
1 Parent(s): ad1f612

using rms norm

Files changed (1)
  1. modeling-dev2.py +194 -0
modeling-dev2.py ADDED
@@ -0,0 +1,194 @@
import tensorflow as tf
from tensorflow.keras import layers, activations, initializers, regularizers
import numpy as np

# Define RMSNorm
class RMSNorm(tf.keras.layers.Layer):
    def __init__(self, epsilon=1e-6):
        super(RMSNorm, self).__init__()
        self.epsilon = epsilon

    def call(self, inputs):
        # Calculate the RMS over the last axis and normalize the input
        rms = tf.sqrt(tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True))
        return inputs / (rms + self.epsilon)

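# Illustrative sanity check (uncomment to run): after RMSNorm, the root-mean-
# square of the activations along the last axis should be ~1. Shapes arbitrary.
# _x = tf.random.normal([2, 4, 8])
# _y = RMSNorm()(_x)
# print(tf.reduce_mean(tf.sqrt(tf.reduce_mean(tf.square(_y), axis=-1))))  # ~1.0
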
class MiniSunConfig:
    def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
                 num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4, total_steps=2500,
                 warmup_ratio=0.5, restart_period=500, l1_reg=0.0, l2_reg=0.01):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate
        self.total_steps = total_steps
        self.warmup_ratio = warmup_ratio
        self.restart_period = restart_period
        self.l1_reg = l1_reg  # L1 regularization strength
        self.l2_reg = l2_reg  # L2 regularization strength

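# Example (illustrative): override only the hyperparameters you need; all other
# fields keep the defaults above.
# _tiny_config = MiniSunConfig(hidden_size=256, num_hidden_layers=4, total_steps=1000)
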
@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    def __init__(self, config):
        super(MiniSunModel, self).__init__()
        self.config = config

        # Embedding layers for tokens and learned absolute positions
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)

        # Decoder blocks are created lazily in build()
        self.decoder_blocks = []

        # Final normalization and head
        self.layer_norm = RMSNorm(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal(),
                                    kernel_regularizer=regularizers.l2(config.l2_reg))

    def build(self, input_shape):
        # Create transformer decoder blocks based on the model configuration
        self.decoder_blocks = [self._build_decoder_block() for _ in range(self.config.num_hidden_layers)]
        super(MiniSunModel, self).build(input_shape)

    def _build_decoder_block(self):
        # Decoder block with multi-head attention and feed-forward layers, using
        # RMSNorm and regularization. key_dim is the per-head dimension, so the
        # hidden size is split evenly across the attention heads.
        return [
            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads,
                                      key_dim=self.config.hidden_size // self.config.num_attention_heads,
                                      kernel_initializer=initializers.he_normal(),
                                      kernel_regularizer=regularizers.l2(self.config.l2_reg)),
            RMSNorm(epsilon=1e-6),  # RMSNorm instead of LayerNormalization
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal(),
                         kernel_regularizer=regularizers.l1_l2(self.config.l1_reg, self.config.l2_reg)),
            layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal(),
                         kernel_regularizer=regularizers.l1_l2(self.config.l1_reg, self.config.l2_reg)),
            layers.Dropout(self.config.dropout_rate)
        ]

    def call(self, inputs, attention_mask=None, training=False):
        input_ids = inputs['input_ids']
        if attention_mask is None:
            attention_mask = inputs.get('attention_mask')
        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)

        # Token embeddings plus learned absolute position embeddings
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # Broadcast the padding mask to shape [batch_size, 1, 1, seq_len]
        if attention_mask is not None:
            attention_mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32)

        # Apply the decoder blocks (post-norm residual layout)
        hidden_states = embeddings
        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
            # Causal self-attention: each position may only attend to the past
            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask,
                              use_causal_mask=True, training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = norm(attn_output + hidden_states)  # Add & RMSNorm

            # Feed-forward layers
            ffn_output = ffn1(hidden_states)
            ffn_output = ffn2(ffn_output)
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = norm(ffn_output + hidden_states)  # Add & RMSNorm

        # Final layer normalization
        hidden_states = self.layer_norm(hidden_states)

        # LM head for token generation
        logits = self.lm_head(hidden_states)

        # Softmax over the vocabulary for confidence scores
        softmax_output = tf.nn.softmax(logits, axis=-1)

        return logits, softmax_output

    def get_config(self):
        return {'config': self.config.__dict__}

    @classmethod
    def from_config(cls, config):
        return cls(MiniSunConfig(**config['config']))

    # Named compute_lm_loss to avoid shadowing Keras' built-in Model.compute_loss
    def compute_lm_loss(self, labels, logits):
        if labels is None or logits is None:
            raise ValueError("Labels and logits cannot be None.")
        # sparse_categorical_crossentropy has no label_smoothing argument, so the
        # labels are one-hot encoded and categorical_crossentropy is used instead
        one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32), depth=self.config.vocab_size)
        loss = tf.keras.losses.categorical_crossentropy(
            one_hot_labels, logits, from_logits=True, label_smoothing=0.1)
        return tf.reduce_mean(loss)

    def train_step(self, data):
        inputs, labels = data

        with tf.GradientTape() as tape:
            logits, _ = self(inputs, training=True)
            loss = self.compute_lm_loss(labels, logits)
            # The optimizer is a LossScaleOptimizer (see create_model), so the
            # loss is scaled before, and the gradients unscaled after, the tape
            scaled_loss = self.optimizer.get_scaled_loss(loss)

        scaled_gradients = tape.gradient(scaled_loss, self.trainable_variables)
        gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)

        # Gradient clipping for stability
        gradients = [tf.clip_by_value(g, -1.0, 1.0) if g is not None else g
                     for g in gradients]
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Keras resolves 'accuracy' to the sparse categorical variant from the
        # compiled loss, so labels and raw logits can be passed directly
        self.compiled_metrics.update_state(labels, logits)
        results = {m.name: m.result() for m in self.metrics}
        results['loss'] = loss
        return results

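# Illustrative smoke test (uncomment to run): a forward pass on dummy token ids
# should yield logits of shape (batch, seq_len, vocab_size). Values arbitrary.
# _cfg = MiniSunConfig()
# _m = MiniSunModel(_cfg)
# _ids = tf.zeros([1, 16], dtype=tf.int32)
# _logits, _probs = _m({'input_ids': _ids, 'attention_mask': tf.ones([1, 16])})
# print(_logits.shape)  # (1, 16, 30522)
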
def create_model(config):
    # Build and compile under the current distribution strategy so that the
    # variables are created in the right scope
    strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model = MiniSunModel(config)

        # AdamW with weight decay, wrapped in a LossScaleOptimizer for mixed
        # precision training
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
            tf.keras.optimizers.AdamW(learning_rate=config.learning_rate,
                                      weight_decay=config.weight_decay)
        )
        model.compile(optimizer=optimizer,
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
    return model

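# Note: LossScaleOptimizer only pays off when compute actually runs in float16.
# To enable mixed precision end to end, the global policy would have to be set
# before building the model (optional, hardware permitting):
# tf.keras.mixed_precision.set_global_policy('mixed_float16')
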
def cosine_annealing_with_warmup(step, config):
    """Learning rate schedule with warm-up followed by cosine annealing."""
    warmup_steps = max(1, int(config.total_steps * config.warmup_ratio))
    if step < warmup_steps:
        return config.learning_rate * (step / warmup_steps)
    else:
        cos_step = step - warmup_steps
        total_cos_steps = max(1, config.total_steps - warmup_steps)
        # np.cos keeps the result a plain Python float, which the
        # LearningRateScheduler callback requires
        return 0.5 * config.learning_rate * (1 + np.cos(np.pi * cos_step / total_cos_steps))


def cosine_annealing_with_restarts(step, config):
    """Learning rate schedule with warm-up and cosine annealing that restarts
    every `restart_period` steps."""
    # Warm-up is measured within each cycle; deriving it from total_steps would
    # make it longer than the cycle itself under the default configuration
    warmup_steps = max(1, int(config.restart_period * config.warmup_ratio))
    effective_step = step % config.restart_period  # position within the current cycle

    if effective_step < warmup_steps:
        return config.learning_rate * (effective_step / warmup_steps)
    else:
        cos_step = effective_step - warmup_steps
        total_cos_steps = max(1, config.restart_period - warmup_steps)
        return 0.5 * config.learning_rate * (1 + np.cos(np.pi * cos_step / total_cos_steps))

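# Illustrative check of the schedule shape (uncomment to run): the rate should
# ramp up linearly through warm-up, then decay along a cosine. Steps arbitrary.
# _cfg = MiniSunConfig()
# for _s in (0, 625, 1250, 1875, 2499):
#     print(_s, cosine_annealing_with_warmup(_s, _cfg))
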
# Configuration
config = MiniSunConfig(l1_reg=1e-5, l2_reg=3e-4)

# Initialize model with improvements
model = create_model(config)

# Create LearningRateScheduler callbacks. Note that Keras invokes the schedule
# once per epoch, so the `step` argument actually counts epochs; per-batch
# scheduling would need a custom callback or a LearningRateSchedule object.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_warmup(step, config))
lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_restarts(step, config))
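
# Example training call (sketch): assumes a `train_dataset` yielding
# ({'input_ids': ..., 'attention_mask': ...}, labels) batches, which is not
# defined in this file. Pick one of the two schedulers, not both.
# model.fit(train_dataset, epochs=3, callbacks=[lr_scheduler])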