finnstrom3693
commited on
Commit
•
5be2e91
1
Parent(s):
ccd2034
add adaptive tpu and gpu training
Browse files- modeling5.py +172 -0
modeling5.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @title Model Architecture
|
2 |
+
import tensorflow as tf
|
3 |
+
from tensorflow.keras import layers, activations, initializers
|
4 |
+
|
5 |
+
class MiniSunConfig:
|
6 |
+
def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
|
7 |
+
num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
|
8 |
+
dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4, total_steps=2500,
|
9 |
+
warmup_steps=500, restart_period=500):
|
10 |
+
self.vocab_size = vocab_size
|
11 |
+
self.max_position_embeddings = max_position_embeddings
|
12 |
+
self.hidden_size = hidden_size
|
13 |
+
self.num_attention_heads = num_attention_heads
|
14 |
+
self.intermediate_size = intermediate_size
|
15 |
+
self.num_hidden_layers = num_hidden_layers
|
16 |
+
self.dropout_rate = dropout_rate
|
17 |
+
self.weight_decay = weight_decay
|
18 |
+
self.learning_rate = learning_rate
|
19 |
+
self.total_steps = total_steps
|
20 |
+
self.warmup_steps = warmup_steps
|
21 |
+
self.restart_period = restart_period
|
22 |
+
|
23 |
+
|
24 |
+
@tf.keras.utils.register_keras_serializable()
|
25 |
+
class MiniSunModel(tf.keras.Model):
|
26 |
+
def __init__(self, config):
|
27 |
+
super(MiniSunModel, self).__init__()
|
28 |
+
self.config = config
|
29 |
+
|
30 |
+
# Embedding layers for token and position
|
31 |
+
self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
|
32 |
+
self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)
|
33 |
+
|
34 |
+
# Initialize an empty list for decoder blocks
|
35 |
+
self.decoder_blocks = []
|
36 |
+
|
37 |
+
# Final normalization and head
|
38 |
+
self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
|
39 |
+
self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())
|
40 |
+
|
41 |
+
def build(self, input_shape):
|
42 |
+
# Create transformer decoder blocks based on the model configuration
|
43 |
+
self.decoder_blocks = [self._build_decoder_block() for _ in range(self.config.num_hidden_layers)]
|
44 |
+
# Call the superclass's build method
|
45 |
+
super(MiniSunModel, self).build(input_shape)
|
46 |
+
|
47 |
+
def _build_decoder_block(self):
|
48 |
+
# Decoder block consisting of multi-head attention and feed-forward layers
|
49 |
+
return [
|
50 |
+
layers.MultiHeadAttention(num_heads=self.config.num_attention_heads, key_dim=self.config.hidden_size,
|
51 |
+
kernel_initializer=initializers.he_normal()),
|
52 |
+
layers.LayerNormalization(epsilon=1e-6),
|
53 |
+
layers.Dense(self.config.intermediate_size, activation=activations.elu,
|
54 |
+
kernel_initializer=initializers.he_normal()),
|
55 |
+
layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal()),
|
56 |
+
layers.Dropout(self.config.dropout_rate)
|
57 |
+
]
|
58 |
+
|
59 |
+
def call(self, inputs, attention_mask=None, training=False):
|
60 |
+
input_ids = inputs['input_ids']
|
61 |
+
position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)
|
62 |
+
|
63 |
+
# Token and position embeddings
|
64 |
+
embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)
|
65 |
+
|
66 |
+
# Adjust attention mask to correct shape [batch_size, 1, 1, seq_len]
|
67 |
+
if attention_mask is not None:
|
68 |
+
attention_mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32)
|
69 |
+
|
70 |
+
# Apply decoder blocks
|
71 |
+
hidden_states = embeddings
|
72 |
+
for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
|
73 |
+
attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask, training=training)
|
74 |
+
attn_output = dropout(attn_output, training=training)
|
75 |
+
hidden_states = norm(attn_output + hidden_states) # Add & Norm
|
76 |
+
|
77 |
+
# Feed-forward layers
|
78 |
+
ffn_output = ffn1(hidden_states)
|
79 |
+
ffn_output = ffn2(ffn_output)
|
80 |
+
ffn_output = dropout(ffn_output, training=training)
|
81 |
+
hidden_states = norm(ffn_output + hidden_states) # Add & Norm
|
82 |
+
|
83 |
+
# Final layer normalization
|
84 |
+
hidden_states = self.layer_norm(hidden_states)
|
85 |
+
|
86 |
+
# LM Head for token generation
|
87 |
+
logits = self.lm_head(hidden_states)
|
88 |
+
return logits
|
89 |
+
|
90 |
+
def get_config(self):
|
91 |
+
# Return the configuration of the model
|
92 |
+
return {
|
93 |
+
'config': self.config.__dict__
|
94 |
+
}
|
95 |
+
|
96 |
+
@classmethod
|
97 |
+
def from_config(cls, config):
|
98 |
+
# Create an instance of the model from the config
|
99 |
+
return cls(MiniSunConfig(**config['config']))
|
100 |
+
|
101 |
+
def compute_loss(self, labels, logits):
|
102 |
+
"""Computes the loss between labels and logits."""
|
103 |
+
if labels is None or logits is None:
|
104 |
+
raise ValueError("Labels and logits cannot be None.")
|
105 |
+
return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
|
106 |
+
|
107 |
+
def train_step(self, data):
|
108 |
+
inputs, labels = data
|
109 |
+
input_ids = inputs['input_ids']
|
110 |
+
attention_mask = inputs['attention_mask']
|
111 |
+
|
112 |
+
with tf.GradientTape() as tape:
|
113 |
+
logits = self(inputs, training=True)
|
114 |
+
loss = self.compute_loss(labels, logits)
|
115 |
+
|
116 |
+
gradients = tape.gradient(loss, self.trainable_variables)
|
117 |
+
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
|
118 |
+
|
119 |
+
logits_for_metrics = tf.argmax(logits, axis=-1)
|
120 |
+
logits_for_metrics = tf.reshape(logits_for_metrics, [-1])
|
121 |
+
labels_for_metrics = tf.reshape(labels, [-1])
|
122 |
+
|
123 |
+
for metric in self.metrics:
|
124 |
+
metric.update_state(labels_for_metrics, logits_for_metrics)
|
125 |
+
|
126 |
+
return {m.name: m.result() for m in self.metrics}
|
127 |
+
|
128 |
+
def create_model(config):
|
129 |
+
model = MiniSunModel(config)
|
130 |
+
|
131 |
+
# Optimizer with weight decay
|
132 |
+
optimizer = tf.keras.optimizers.AdamW(learning_rate=config.learning_rate, weight_decay=config.weight_decay)
|
133 |
+
strategy = tf.distribute.get_strategy()
|
134 |
+
with strategy.scope():
|
135 |
+
model.compile(optimizer=optimizer,
|
136 |
+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
|
137 |
+
metrics=['accuracy'])
|
138 |
+
return model
|
139 |
+
|
140 |
+
def cosine_annealing_with_warmup(step, config):
|
141 |
+
"""Learning rate schedule with warm-up and cosine annealing."""
|
142 |
+
warmup_steps = int(config.total_steps * config.warmup_steps)
|
143 |
+
if step < warmup_steps:
|
144 |
+
return config.learning_rate * (step / warmup_steps)
|
145 |
+
else:
|
146 |
+
cos_step = step - warmup_steps
|
147 |
+
total_cos_steps = config.total_steps - warmup_steps
|
148 |
+
return 0.5 * config.learning_rate * (1 + tf.cos(tf.constant(np.pi) * cos_step / total_cos_steps))
|
149 |
+
|
150 |
+
def cosine_annealing_with_restarts(step, config):
|
151 |
+
"""Learning rate schedule with warm-up and cosine annealing with restarts."""
|
152 |
+
warmup_steps = int(config.total_steps * config.warmup_steps)
|
153 |
+
|
154 |
+
current_cycle = step // config.restart_period
|
155 |
+
effective_step = step % config.restart_period
|
156 |
+
|
157 |
+
if effective_step < warmup_steps:
|
158 |
+
return config.learning_rate * (effective_step / warmup_steps)
|
159 |
+
else:
|
160 |
+
cos_step = effective_step - warmup_steps
|
161 |
+
total_cos_steps = restart_period - warmup_steps
|
162 |
+
return 0.5 * config.learning_rate * (1 + tf.cos(tf.constant(np.pi) * cos_step / total_cos_steps))
|
163 |
+
|
164 |
+
# Configuration
|
165 |
+
config = MiniSunConfig()
|
166 |
+
|
167 |
+
# Initialize model with He initialization
|
168 |
+
model = create_model(config)
|
169 |
+
|
170 |
+
# Create a LearningRateScheduler callback
|
171 |
+
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_warmup(step, config))
|
172 |
+
lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_restarts(step, config))
|