File size: 7,426 Bytes
c42ab1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# @title Model Architecture
import math

import tensorflow as tf
from tensorflow.keras import layers, activations, initializers

class MiniSunConfig:
    """Hyper-parameters for MiniSunModel and its optimizer / learning-rate schedule.

    Every constructor argument is stored as an attribute of the same name, so
    ``MiniSunConfig(**vars(cfg))`` rebuilds an equivalent config.

    Args:
        vocab_size: Number of token ids (embedding rows and logits width).
        max_position_embeddings: Size of the learned position-embedding table,
            i.e. the longest sequence the model can embed.
        hidden_size: Width of the embeddings and decoder hidden states.
        num_attention_heads: Heads per multi-head attention layer.
        intermediate_size: Width of the feed-forward expansion layer.
        num_hidden_layers: Number of decoder blocks.
        dropout_rate: Dropout probability used inside each decoder block.
        weight_decay: AdamW weight decay.
        learning_rate: Peak learning rate.
        total_steps: Length of the learning-rate schedule, in steps.
        warmup_steps: Warm-up *fraction* of ``total_steps`` (0.2 means the first
            20% of steps); despite the name this is not a step count.
    """

    def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
                 num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4, total_steps=2500, warmup_steps=0.2):
        # Architecture
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.dropout_rate = dropout_rate
        # Optimization / schedule
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps

    def __repr__(self):
        """Show every hyper-parameter, in constructor order, for debugging and logs."""
        fields = ', '.join(f'{name}={value!r}' for name, value in vars(self).items())
        return f'{type(self).__name__}({fields})'

@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    """Decoder-only Transformer language model built from a ``MiniSunConfig``.

    Token embeddings plus learned position embeddings feed
    ``config.num_hidden_layers`` post-norm decoder blocks. Each block runs
    causal self-attention, then an ELU feed-forward network; each of the two
    sub-layers is followed by dropout, a residual add and its own
    LayerNormalization. A final LayerNormalization and a dense head produce
    vocabulary logits for every position.

    ``train_step``/``test_step`` report a scalar ``loss`` and an ``accuracy``
    metric. When the inputs dict carries an ``attention_mask`` it is used both
    to mask attention keys and as per-token weights for loss and accuracy.
    """

    def __init__(self, config):
        super(MiniSunModel, self).__init__()
        self.config = config

        # Embedding layers for token and position
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)

        # Transformer decoder blocks
        self.decoder_blocks = [self._build_decoder_block() for _ in range(config.num_hidden_layers)]

        # Final normalization and head
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())

        # Metrics reported by train_step/test_step. They are returned from the
        # `metrics` property so Keras resets them at the start of every epoch.
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
        self.accuracy_tracker = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')

    @property
    def metrics(self):
        """Metrics Keras should reset between epochs and report in fit/evaluate logs."""
        return [self.loss_tracker, self.accuracy_tracker]

    def _build_decoder_block(self):
        """Create the layers of one decoder block.

        Returns:
            ``[mha, attn_norm, ffn1, ffn2, dropout, ffn_norm]``. The attention and
            feed-forward sub-layers each get their own LayerNormalization so
            they do not share scale/offset parameters.
        """
        return [
            # NOTE(review): key_dim is the per-head size, so each head projects to
            # hidden_size (num_heads * hidden_size in total); hidden_size // num_heads
            # is the more common choice. Kept as-is.
            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads, key_dim=self.config.hidden_size,
                                      kernel_initializer=initializers.he_normal()),
            layers.LayerNormalization(epsilon=1e-6),
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal()),
            layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal()),
            layers.Dropout(self.config.dropout_rate),
            layers.LayerNormalization(epsilon=1e-6),
        ]

    def call(self, inputs, attention_mask=None, training=False):
        """Compute vocabulary logits for every input position.

        Args:
            inputs: Mapping with ``'input_ids'`` (sequence on the last axis) and,
                optionally, ``'attention_mask'``.
            attention_mask: Optional padding mask over key positions; non-zero
                entries may be attended to. Takes precedence over
                ``inputs['attention_mask']``.
            training: Enables dropout when True.

        Returns:
            Logits with a trailing vocabulary axis of size ``config.vocab_size``.
        """
        input_ids = inputs['input_ids']
        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)

        # Token and position embeddings
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # Fall back to the mask carried in the inputs dict; otherwise self(inputs)
        # in train_step/test_step would silently ignore padding.
        if attention_mask is None and 'attention_mask' in inputs:
            attention_mask = inputs['attention_mask']
        if attention_mask is not None:
            # Assumes [batch, seq] (TODO confirm) -> [batch, 1, seq]: broadcasts over
            # query positions and combines with the causal mask inside the layer.
            attention_mask = tf.cast(attention_mask[:, tf.newaxis, :], tf.bool)

        hidden_states = embeddings
        for mha, attn_norm, ffn1, ffn2, dropout, ffn_norm in self.decoder_blocks:
            # use_causal_mask keeps position t from attending to later positions.
            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask,
                              use_causal_mask=True, training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = attn_norm(attn_output + hidden_states)  # Add & Norm

            # Feed-forward layers
            ffn_output = ffn2(ffn1(hidden_states))
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = ffn_norm(ffn_output + hidden_states)  # Add & Norm

        # Final layer normalization, then LM head for token generation
        hidden_states = self.layer_norm(hidden_states)
        return self.lm_head(hidden_states)

    def get_config(self):
        """Serialize as a copy of the config attributes (callers cannot mutate the live config)."""
        return {
            'config': dict(self.config.__dict__)
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from the dict produced by ``get_config``."""
        return cls(MiniSunConfig(**config['config']))

    def compute_loss(self, labels, logits):
        """Return the per-position sparse categorical cross-entropy (unreduced).

        NOTE: this shadows ``tf.keras.Model.compute_loss`` with a different
        signature. Keras' built-in ``test_step`` calls it with
        ``(x, y, y_pred, sample_weight)``, which is why ``test_step`` is
        overridden below.

        Raises:
            ValueError: If ``labels`` or ``logits`` is None.
        """
        if labels is None or logits is None:
            raise ValueError("Labels and logits cannot be None.")
        return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

    @staticmethod
    def _token_weights(inputs):
        """Return the inputs' attention mask for use as per-token weights, or None.

        NOTE(review): assumes labels are aligned with input positions, so the
        padding mask also marks which labels are real.
        """
        return inputs.get('attention_mask')

    @staticmethod
    def _weighted_mean(per_token_loss, weights):
        """Reduce a per-token loss to a scalar; zero-weight positions are excluded."""
        if weights is None:
            return tf.reduce_mean(per_token_loss)
        weights = tf.cast(weights, per_token_loss.dtype)
        # divide_no_nan: an all-padding batch yields 0 instead of NaN.
        return tf.math.divide_no_nan(tf.reduce_sum(per_token_loss * weights), tf.reduce_sum(weights))

    def _update_metrics(self, labels, logits, loss, weights):
        """Update the loss/accuracy trackers and return their current values."""
        self.loss_tracker.update_state(loss)
        self.accuracy_tracker.update_state(labels, logits, sample_weight=weights)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        """Run one optimization step on an ``(inputs, labels)`` batch."""
        inputs, labels = data
        weights = self._token_weights(inputs)

        with tf.GradientTape() as tape:
            logits = self(inputs, training=True)
            # Reduce to a scalar: tape.gradient would otherwise implicitly sum a
            # non-scalar loss over every position.
            loss = self._weighted_mean(self.compute_loss(labels, logits), weights)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return self._update_metrics(labels, logits, loss, weights)

    def test_step(self, data):
        """Evaluate one ``(inputs, labels)`` batch with the same loss/metrics as training."""
        inputs, labels = data
        weights = self._token_weights(inputs)
        logits = self(inputs, training=False)
        loss = self._weighted_mean(self.compute_loss(labels, logits), weights)
        return self._update_metrics(labels, logits, loss, weights)


def create_model(config):
    """Instantiate a MiniSunModel and compile it for training.

    Args:
        config: A ``MiniSunConfig``. ``learning_rate`` and ``weight_decay``
            configure the AdamW optimizer; the model is built from the same object.

    Returns:
        The compiled model.
    """
    network = MiniSunModel(config)

    # AdamW applies decoupled weight decay.
    adamw = tf.keras.optimizers.AdamW(
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    token_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    network.compile(optimizer=adamw, loss=token_loss, metrics=['accuracy'])
    return network

def cosine_annealing_with_warmup(step, config):
    """Learning rate schedule with linear warm-up followed by cosine annealing.

    The first ``int(config.total_steps * config.warmup_steps)`` steps ramp the
    rate linearly from 0 to ``config.learning_rate``. The remaining steps decay
    it along a half cosine to 0 at ``config.total_steps``. Steps beyond
    ``config.total_steps`` stay at 0 rather than climbing back up the cosine.

    Args:
        step: Zero-based optimizer step index.
        config: Object exposing ``learning_rate``, ``total_steps`` and
            ``warmup_steps`` (the warm-up *fraction* of ``total_steps``).

    Returns:
        The learning rate for ``step`` (a Python float for numeric ``step``).
    """
    warmup_steps = int(config.total_steps * config.warmup_steps)
    if step < warmup_steps:
        return config.learning_rate * (step / warmup_steps)

    decay_steps = config.total_steps - warmup_steps
    if decay_steps <= 0:
        # Warm-up covers the whole schedule: nothing left to decay, schedule is over.
        return 0.0
    # Clamp progress so training past total_steps does not re-warm the rate.
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return 0.5 * config.learning_rate * (1.0 + math.cos(math.pi * progress))

def cosine_annealing_with_restarts(step, config, restart_period, cycle_num):
    """Learning rate schedule with warm-up and cosine annealing, restarted periodically.

    The schedule repeats every ``restart_period`` steps. Within each cycle, the
    first ``int(config.total_steps * config.warmup_steps)`` steps ramp linearly
    from 0 to ``config.learning_rate``, and the rest of the cycle decays along a
    half cosine towards 0. The warm-up length is derived from ``total_steps``,
    not from ``restart_period``. If it is at least as long as the cycle, the
    rate only ever ramps.

    Args:
        step: Zero-based global step index.
        config: Object exposing ``learning_rate``, ``total_steps`` and
            ``warmup_steps`` (warm-up fraction of ``total_steps``).
        restart_period: Cycle length in steps; must be positive.
        cycle_num: Currently unused; kept for call-site compatibility.

    Returns:
        The learning rate for ``step``.

    Raises:
        ValueError: If ``restart_period`` is not positive.
    """
    if restart_period <= 0:
        raise ValueError(f'restart_period must be positive, got {restart_period}')

    warmup_steps = int(config.total_steps * config.warmup_steps)

    # Position inside the current cycle; the schedule restarts every restart_period steps.
    effective_step = step % restart_period

    if effective_step < warmup_steps:
        return config.learning_rate * (effective_step / warmup_steps)

    # Reaching here implies restart_period > warmup_steps, so the divisor is positive.
    cos_step = effective_step - warmup_steps
    total_cos_steps = restart_period - warmup_steps
    return 0.5 * config.learning_rate * (1 + math.cos(math.pi * cos_step / total_cos_steps))


# Configuration
config = MiniSunConfig()

# Build and compile the model (He-normal initializers are set inside MiniSunModel)
model = create_model(config)


class PerStepLearningRateScheduler(tf.keras.callbacks.Callback):
    """Set the optimizer's learning rate from a step-indexed schedule before every batch.

    ``tf.keras.callbacks.LearningRateScheduler`` calls its schedule only once
    per *epoch*, passing the epoch index. A schedule written in optimizer steps
    (warm-up and decay lengths derived from ``config.total_steps``) would then
    advance one "step" per epoch: a rate of 0 for the whole first epoch and
    near 0 for many epochs after that.

    Args:
        schedule: Callable mapping a zero-based global step to a learning rate.
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule
        # Counts training batches across epochs and across repeated fit() calls.
        self.global_step = 0

    def on_train_batch_begin(self, batch, logs=None):
        # float() accepts both Python numbers and scalar eager tensors.
        self.model.optimizer.learning_rate = float(self.schedule(self.global_step))
        self.global_step += 1


# Per-step learning-rate schedule; pass it to model.fit(callbacks=[lr_scheduler]).
lr_scheduler = PerStepLearningRateScheduler(lambda step: cosine_annealing_with_warmup(step, config))
# lr_scheduler_with_restarts = PerStepLearningRateScheduler(
#     lambda step: cosine_annealing_with_restarts(step, config, restart_period=1000, cycle_num=1))