finnstrom3693 committed
Commit c42ab1b
1 Parent(s): 1e9934f

Create modeling4.py

modeling4.py ADDED
@@ -0,0 +1,172 @@
+ # @title Model Architecture
+ import tensorflow as tf
+ import numpy as np  # needed for np.pi in the learning-rate schedules below
+ from tensorflow.keras import layers, activations, initializers
+
+ class MiniSunConfig:
+     def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
+                  num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
+                  dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4, total_steps=2500, warmup_steps=0.2):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.dropout_rate = dropout_rate
+         self.weight_decay = weight_decay
+         self.learning_rate = learning_rate
+         self.total_steps = total_steps
+         self.warmup_steps = warmup_steps  # fraction of total_steps spent warming up
+
+ @tf.keras.utils.register_keras_serializable()
+ class MiniSunModel(tf.keras.Model):
+     def __init__(self, config):
+         super(MiniSunModel, self).__init__()
+         self.config = config
+
+         # Embedding layers for token and position
+         self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
+         self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)
+
+         # Transformer decoder blocks
+         self.decoder_blocks = [self._build_decoder_block() for _ in range(config.num_hidden_layers)]
+
+         # Final normalization and head
+         self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
+         self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())
+
+     def _build_decoder_block(self):
+         # Decoder block consisting of multi-head attention and feed-forward layers
+         return [
+             layers.MultiHeadAttention(num_heads=self.config.num_attention_heads, key_dim=self.config.hidden_size,
+                                       kernel_initializer=initializers.he_normal()),
+             layers.LayerNormalization(epsilon=1e-6),
+             layers.Dense(self.config.intermediate_size, activation=activations.elu,
+                          kernel_initializer=initializers.he_normal()),
+             layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal()),
+             layers.Dropout(self.config.dropout_rate)
+         ]
+
+     def call(self, inputs, attention_mask=None, training=False):
+         input_ids = inputs['input_ids']
+         position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)
+
+         # Token and position embeddings
+         embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)
+
+         # Adjust attention mask to shape [batch_size, 1, 1, seq_len] so it broadcasts over heads
+         if attention_mask is not None:
+             attention_mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32)
+
+         # Apply decoder blocks (use_causal_mask keeps attention autoregressive for token generation)
+         hidden_states = embeddings
+         for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
+             attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask,
+                               use_causal_mask=True, training=training)
+             attn_output = dropout(attn_output, training=training)
+             hidden_states = norm(attn_output + hidden_states)  # Add & Norm
+
+             # Feed-forward layers
+             ffn_output = ffn1(hidden_states)
+             ffn_output = ffn2(ffn_output)
+             ffn_output = dropout(ffn_output, training=training)
+             hidden_states = norm(ffn_output + hidden_states)  # Add & Norm
+
+         # Final layer normalization
+         hidden_states = self.layer_norm(hidden_states)
+
+         # LM Head for token generation
+         logits = self.lm_head(hidden_states)
+         return logits
+
+     def get_config(self):
+         # Return the configuration of the model
+         return {
+             'config': self.config.__dict__
+         }
+
+     @classmethod
+     def from_config(cls, config):
+         # Create an instance of the model from the config
+         return cls(MiniSunConfig(**config['config']))
+
+     def compute_loss(self, labels, logits):
+         """Computes the mean sparse categorical cross-entropy between labels and logits."""
+         # Ensure labels and logits are not None
+         if labels is None or logits is None:
+             raise ValueError("Labels and logits cannot be None.")
+         per_token_loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
+         return tf.reduce_mean(per_token_loss)
+
+     def train_step(self, data):
+         inputs, labels = data
+         attention_mask = inputs['attention_mask']
+
+         with tf.GradientTape() as tape:
+             # Forward pass with the padding mask so padded positions are ignored
+             logits = self(inputs, attention_mask=attention_mask, training=True)
+             loss = self.compute_loss(labels, logits)
+
+         gradients = tape.gradient(loss, self.trainable_variables)
+         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+
+         # Flatten predictions and labels for the accuracy metric
+         predictions_for_metrics = tf.reshape(tf.argmax(logits, axis=-1), [-1])
+         labels_for_metrics = tf.reshape(labels, [-1])
+
+         for metric in self.metrics:
+             if metric.name == 'loss':
+                 metric.update_state(loss)
+             else:
+                 metric.update_state(labels_for_metrics, predictions_for_metrics)
+
+         results = {m.name: m.result() for m in self.metrics}
+         results.setdefault('loss', loss)
+         return results
+
+
+ def create_model(config):
+     model = MiniSunModel(config)
+
+     # Optimizer with weight decay
+     optimizer = tf.keras.optimizers.AdamW(learning_rate=config.learning_rate, weight_decay=config.weight_decay)
+
+     model.compile(
+         optimizer=optimizer,
+         loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+         metrics=['accuracy']
+     )
+     return model
+
+ def cosine_annealing_with_warmup(step, config):
+     """Learning rate schedule with warm-up and cosine annealing."""
+     warmup_steps = int(config.total_steps * config.warmup_steps)  # config.warmup_steps is a fraction of total_steps
+     if step < warmup_steps:
+         return config.learning_rate * (step / warmup_steps)
+     else:
+         # Calculate the cosine decay (np.cos returns a plain float for the scheduler callback)
+         cos_step = step - warmup_steps
+         total_cos_steps = config.total_steps - warmup_steps
+         return 0.5 * config.learning_rate * (1 + np.cos(np.pi * cos_step / total_cos_steps))
+
+ def cosine_annealing_with_restarts(step, config, restart_period, cycle_num):
+     """Learning rate schedule with warm-up and cosine annealing with restarts."""
+     warmup_steps = int(config.total_steps * config.warmup_steps)
+
+     # Determine the current cycle based on step and restart_period (cycle_num is accepted but not used below)
+     current_cycle = step // restart_period
+
+     # Calculate the effective step within the current cycle
+     effective_step = step % restart_period
+
+     if effective_step < warmup_steps:
+         return config.learning_rate * (effective_step / warmup_steps)
+     else:
+         # Calculate the cosine decay within the current cycle
+         cos_step = effective_step - warmup_steps
+         total_cos_steps = restart_period - warmup_steps
+         return 0.5 * config.learning_rate * (1 + np.cos(np.pi * cos_step / total_cos_steps))
+
+
+ # Configuration
+ config = MiniSunConfig()
+
+ # Initialize model with He initialization
+ model = create_model(config)
+
+ # Create a LearningRateScheduler callback
+ # Note: LearningRateScheduler invokes the schedule once per epoch, so the value passed in is the epoch index.
+ lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_warmup(step, config))
+ #lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_restarts(step, config, restart_period=1000, cycle_num=1))
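The file stops after building the callback, so here is a minimal usage sketch, assuming batches shaped the way train_step expects, i.e. ({'input_ids': ..., 'attention_mask': ...}, labels). The random dataset below is purely illustrative and stands in for real tokenized text.

# Illustrative toy dataset; real training data would come from a tokenizer.
batch_size, seq_len = 8, 128
num_samples = 256

input_ids = tf.random.uniform((num_samples, seq_len), maxval=config.vocab_size, dtype=tf.int32)
attention_mask = tf.ones_like(input_ids)       # no padding in this toy example
labels = tf.roll(input_ids, shift=-1, axis=1)  # toy next-token targets

dataset = tf.data.Dataset.from_tensor_slices(
    ({'input_ids': input_ids, 'attention_mask': attention_mask}, labels)
).batch(batch_size)

# LearningRateScheduler adjusts the rate once per epoch, so warm-up here is counted in epochs.
model.fit(dataset, epochs=3, callbacks=[lr_scheduler])

With the default config (total_steps=2500, warmup_steps=0.2), the warm-up phase covers the first 500 scheduler calls, after which the rate decays from 1e-4 towards 0 following the cosine curve.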