finnstrom3693 committed on
Commit
5be2e91
1 Parent(s): ccd2034

add adaptive tpu and gpu training

Files changed (1)
  1. modeling5.py +172 -0
modeling5.py ADDED
@@ -0,0 +1,172 @@
+# @title Model Architecture
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras import layers, activations, initializers
+
+class MiniSunConfig:
+    def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
+                 num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
+                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4, total_steps=2500,
+                 warmup_steps=500, restart_period=500):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.dropout_rate = dropout_rate
+        self.weight_decay = weight_decay
+        self.learning_rate = learning_rate
+        self.total_steps = total_steps
+        self.warmup_steps = warmup_steps
+        self.restart_period = restart_period
+
+
+@tf.keras.utils.register_keras_serializable()
+class MiniSunModel(tf.keras.Model):
+    def __init__(self, config):
+        super(MiniSunModel, self).__init__()
+        self.config = config
+
+        # Embedding layers for token and position
+        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
+        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)
+
+        # Initialize an empty list for decoder blocks; they are created in build()
+        self.decoder_blocks = []
+
+        # Final normalization and head
+        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
+        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())
+
+    def build(self, input_shape):
+        # Create transformer decoder blocks based on the model configuration
+        self.decoder_blocks = [self._build_decoder_block() for _ in range(self.config.num_hidden_layers)]
+        # Call the superclass's build method
+        super(MiniSunModel, self).build(input_shape)
+
+    def _build_decoder_block(self):
+        # Decoder block consisting of multi-head attention and feed-forward layers
+        return [
+            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads, key_dim=self.config.hidden_size,
+                                      kernel_initializer=initializers.he_normal()),
+            layers.LayerNormalization(epsilon=1e-6),
+            layers.Dense(self.config.intermediate_size, activation=activations.elu,
+                         kernel_initializer=initializers.he_normal()),
+            layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal()),
+            layers.Dropout(self.config.dropout_rate)
+        ]
+
+    def call(self, inputs, attention_mask=None, training=False):
+        input_ids = inputs['input_ids']
+        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)
+
+        # Token and position embeddings
+        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)
+
+        # Adjust attention mask to the expected shape [batch_size, 1, 1, seq_len]
+        if attention_mask is not None:
+            attention_mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32)
+
+        # Apply decoder blocks
+        hidden_states = embeddings
+        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
+            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask, training=training)
+            attn_output = dropout(attn_output, training=training)
+            hidden_states = norm(attn_output + hidden_states)  # Add & Norm
+
+            # Feed-forward layers
+            ffn_output = ffn1(hidden_states)
+            ffn_output = ffn2(ffn_output)
+            ffn_output = dropout(ffn_output, training=training)
+            hidden_states = norm(ffn_output + hidden_states)  # Add & Norm
+
+        # Final layer normalization
+        hidden_states = self.layer_norm(hidden_states)
+
+        # LM head for token prediction
+        logits = self.lm_head(hidden_states)
+        return logits
+
+    def get_config(self):
+        # Return the configuration of the model
+        return {
+            'config': self.config.__dict__
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        # Create an instance of the model from the config
+        return cls(MiniSunConfig(**config['config']))
+
+    def compute_loss(self, labels, logits):
+        """Computes the loss between labels and logits."""
+        if labels is None or logits is None:
+            raise ValueError("Labels and logits cannot be None.")
+        return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
+
+    def train_step(self, data):
+        inputs, labels = data
+        attention_mask = inputs['attention_mask']
+
+        with tf.GradientTape() as tape:
+            logits = self(inputs, attention_mask=attention_mask, training=True)
+            loss = self.compute_loss(labels, logits)
+
+        gradients = tape.gradient(loss, self.trainable_variables)
+        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+
+        # Flatten predicted token ids and labels for the metrics
+        predictions = tf.reshape(tf.argmax(logits, axis=-1), [-1])
+        labels_for_metrics = tf.reshape(labels, [-1])
+
+        for metric in self.metrics:
+            metric.update_state(labels_for_metrics, predictions)
+
+        return {m.name: m.result() for m in self.metrics}
+
+def create_model(config):
+    model = MiniSunModel(config)
+
+    # Optimizer with weight decay
+    optimizer = tf.keras.optimizers.AdamW(learning_rate=config.learning_rate, weight_decay=config.weight_decay)
+    strategy = tf.distribute.get_strategy()
+    with strategy.scope():
+        model.compile(optimizer=optimizer,
+                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+                      metrics=['accuracy'])
+    return model
+
+def cosine_annealing_with_warmup(step, config):
+    """Learning rate schedule with warm-up and cosine annealing."""
+    warmup_steps = config.warmup_steps
+    if step < warmup_steps:
+        # Linear warm-up from 0 to the base learning rate
+        return config.learning_rate * (step / warmup_steps)
+    else:
+        # Cosine decay over the remaining steps
+        cos_step = step - warmup_steps
+        total_cos_steps = config.total_steps - warmup_steps
+        return 0.5 * config.learning_rate * (1 + tf.cos(tf.constant(np.pi) * cos_step / total_cos_steps))
+
+def cosine_annealing_with_restarts(step, config):
+    """Learning rate schedule with warm-up and cosine annealing with restarts."""
+    warmup_steps = config.warmup_steps
+
+    # Position within the current restart cycle
+    effective_step = step % config.restart_period
+
+    if effective_step < warmup_steps:
+        return config.learning_rate * (effective_step / warmup_steps)
+    else:
+        cos_step = effective_step - warmup_steps
+        total_cos_steps = config.restart_period - warmup_steps
+        return 0.5 * config.learning_rate * (1 + tf.cos(tf.constant(np.pi) * cos_step / total_cos_steps))
+
+# Configuration
+config = MiniSunConfig()
+
+# Initialize model with He initialization
+model = create_model(config)
+
+# Create LearningRateScheduler callbacks
+lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_warmup(step, config))
+lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_restarts(step, config))
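
The commit message mentions adaptive TPU and GPU training, but the file itself only calls tf.distribute.get_strategy(), which returns whatever strategy is already active. Below is a minimal sketch of one way strategy selection could be made adaptive (TPU if one is reachable, MirroredStrategy for multiple GPUs, otherwise the default) and how the callbacks might be wired into training; the helper name get_adaptive_strategy, the dataset train_dataset, and the epoch count are assumptions for illustration, not part of this commit.

def get_adaptive_strategy():
    # Prefer a TPU if one can be resolved, fall back to GPUs, then to the default strategy.
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()  # raises ValueError if no TPU is found
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.TPUStrategy(resolver)
    except ValueError:
        if len(tf.config.list_physical_devices('GPU')) > 1:
            return tf.distribute.MirroredStrategy()   # data-parallel training across local GPUs
        return tf.distribute.get_strategy()           # single GPU or CPU

strategy = get_adaptive_strategy()
with strategy.scope():
    model = create_model(config)  # create_model above compiles under the active strategy scope

# Hypothetical usage: train_dataset is assumed to yield ({'input_ids': ..., 'attention_mask': ...}, labels)
# model.fit(train_dataset, epochs=3, callbacks=[lr_scheduler_with_restarts])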