finnstrom3693 committed on
Commit 78ae00b
1 Parent(s): 3181a38

test improvement

Files changed (1)
  1. modeling-dev.py +183 -0
modeling-dev.py ADDED
@@ -0,0 +1,183 @@
import tensorflow as tf
from tensorflow.keras import layers, activations, initializers, regularizers
import numpy as np


class MiniSunConfig:
    def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
                 num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4, total_steps=2500,
                 warmup_ratio=0.5, restart_period=500, l1_reg=0.0, l2_reg=0.01):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate
        self.total_steps = total_steps
        self.warmup_ratio = warmup_ratio
        self.restart_period = restart_period
        self.l1_reg = l1_reg  # L1 regularization strength
        self.l2_reg = l2_reg  # L2 regularization strength

@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    def __init__(self, config):
        super(MiniSunModel, self).__init__()
        self.config = config

        # Token embeddings plus learned absolute position embeddings
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)

        # Decoder blocks are created lazily in build()
        self.decoder_blocks = []

        # Final normalization and LM head
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal(),
                                    kernel_regularizer=regularizers.l2(config.l2_reg))

        # Extra dropout layer reserved for layer-level regularization (not applied in call() yet)
        self.layer_dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def build(self, input_shape):
        # Create transformer decoder blocks based on the model configuration
        self.decoder_blocks = [self._build_decoder_block() for _ in range(self.config.num_hidden_layers)]
        super(MiniSunModel, self).build(input_shape)

    def _build_decoder_block(self):
        # Decoder block: multi-head self-attention plus a two-layer feed-forward network,
        # each wrapped with residual add + LayerNormalization and L1/L2 regularization
        return [
            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads, key_dim=self.config.hidden_size,
                                      kernel_initializer=initializers.he_normal(),
                                      kernel_regularizer=regularizers.l2(self.config.l2_reg)),
            layers.LayerNormalization(epsilon=1e-6),
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal(),
                         kernel_regularizer=regularizers.l1_l2(self.config.l1_reg, self.config.l2_reg)),
            layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal(),
                         kernel_regularizer=regularizers.l1_l2(self.config.l1_reg, self.config.l2_reg)),
            layers.Dropout(self.config.dropout_rate)
        ]

    def call(self, inputs, attention_mask=None, training=False):
        input_ids = inputs['input_ids']
        # Fall back to the mask packed into the input dict if none is passed explicitly
        if attention_mask is None:
            attention_mask = inputs.get('attention_mask')
        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)

        # Token embeddings plus learned absolute position embeddings
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # Broadcast the padding mask to shape [batch_size, 1, 1, seq_len]
        if attention_mask is not None:
            attention_mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32)

        # Apply the stack of decoder blocks (post-norm residual blocks with dropout)
        hidden_states = embeddings
        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
            # Causal self-attention: each position may only attend to earlier positions
            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask,
                              use_causal_mask=True, training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = norm(attn_output + hidden_states)  # Add & Norm

            # Feed-forward layers
            ffn_output = ffn1(hidden_states)
            ffn_output = ffn2(ffn_output)
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = norm(ffn_output + hidden_states)  # Add & Norm

        # Final layer normalization
        hidden_states = self.layer_norm(hidden_states)

        # LM head for next-token logits
        logits = self.lm_head(hidden_states)

        # Softmax over the vocabulary for confidence scores
        softmax_output = tf.nn.softmax(logits, axis=-1)

        return logits, softmax_output

    def get_config(self):
        return {'config': self.config.__dict__}

    @classmethod
    def from_config(cls, config):
        return cls(MiniSunConfig(**config['config']))

    def compute_lm_loss(self, labels, logits):
        if labels is None or logits is None:
            raise ValueError("Labels and logits cannot be None.")
        # Label smoothing requires dense targets, so convert the sparse labels to one-hot
        # before applying categorical cross-entropy
        one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32), depth=self.config.vocab_size)
        return tf.keras.losses.categorical_crossentropy(one_hot_labels, logits,
                                                        from_logits=True, label_smoothing=0.1)

    def train_step(self, data):
        inputs, labels = data
        attention_mask = inputs.get('attention_mask')

        with tf.GradientTape() as tape:
            logits, _ = self(inputs, attention_mask=attention_mask, training=True)
            # Reduce the per-token loss to a scalar before taking gradients
            loss = tf.reduce_mean(self.compute_lm_loss(labels, logits))

        gradients = tape.gradient(loss, self.trainable_variables)

        # Gradient clipping for stability
        gradients = [tf.clip_by_value(g, -1.0, 1.0) for g in gradients]
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update the loss tracker and the compiled metrics; integer labels against
        # full logits let 'accuracy' resolve to sparse categorical accuracy
        for metric in self.metrics:
            if metric.name == 'loss':
                metric.update_state(loss)
            else:
                metric.update_state(labels, logits)

        return {m.name: m.result() for m in self.metrics}

def create_model(config):
    # Build and compile the model under the current distribution strategy
    strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model = MiniSunModel(config)

        # AdamW with decoupled weight decay; LossScaleOptimizer adds dynamic loss scaling,
        # which only takes effect once a mixed_float16 policy is enabled
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
            tf.keras.optimizers.AdamW(learning_rate=config.learning_rate, weight_decay=config.weight_decay)
        )
        model.compile(optimizer=optimizer,
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
    return model

def cosine_annealing_with_warmup(step, config):
    """Learning rate schedule with linear warm-up followed by cosine annealing."""
    warmup_steps = int(config.total_steps * config.warmup_ratio)
    if step < warmup_steps:
        return config.learning_rate * (step / warmup_steps)
    else:
        cos_step = step - warmup_steps
        total_cos_steps = config.total_steps - warmup_steps
        return 0.5 * config.learning_rate * (1 + np.cos(np.pi * cos_step / total_cos_steps))

def cosine_annealing_with_restarts(step, config):
    """Learning rate schedule with warm-up and cosine annealing, restarted every cycle."""
    # Warm-up and annealing are applied within each restart cycle
    warmup_steps = int(config.restart_period * config.warmup_ratio)
    effective_step = step % config.restart_period

    if effective_step < warmup_steps:
        return config.learning_rate * (effective_step / warmup_steps)
    else:
        cos_step = effective_step - warmup_steps
        total_cos_steps = config.restart_period - warmup_steps
        return 0.5 * config.learning_rate * (1 + np.cos(np.pi * cos_step / total_cos_steps))

# Configuration
config = MiniSunConfig(l1_reg=1e-5, l2_reg=3e-4)

# Initialize model with improvements
model = create_model(config)

# Create LearningRateScheduler callbacks.
# Note: LearningRateScheduler invokes the schedule once per epoch, so the value passed in
# is the epoch index rather than the optimizer step.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lambda epoch: cosine_annealing_with_warmup(epoch, config))
lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(lambda epoch: cosine_annealing_with_restarts(epoch, config))
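
Below is a minimal usage sketch, not part of the committed file: it assumes a pre-tokenized dataset of `input_ids`/`attention_mask` dictionaries paired with shift-by-one next-token labels, and wires the model together with one of the LearningRateScheduler callbacks defined above. The dummy shapes, the random data, and the `model.fit` call are illustrative assumptions, not the author's actual training setup.

# Hypothetical usage sketch: random token ids stand in for a real tokenized corpus
seq_len, num_examples, batch_size = 128, 32, 8
dummy_ids = np.random.randint(0, config.vocab_size, size=(num_examples, seq_len))
dummy_mask = np.ones_like(dummy_ids)
dummy_labels = np.roll(dummy_ids, -1, axis=1)  # shift-by-one next-token targets

dataset = tf.data.Dataset.from_tensor_slices(
    ({'input_ids': dummy_ids, 'attention_mask': dummy_mask}, dummy_labels)
).batch(batch_size)

# LearningRateScheduler adjusts the learning rate once per epoch
model.fit(dataset, epochs=2, callbacks=[lr_scheduler])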