File size: 6,524 Bytes
1e9934f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# @title Model Architecture
import tensorflow as tf
from tensorflow.keras import layers, activations, initializers

class MiniSunConfig:
    """Hyper-parameters for the MiniSun model and its optimizer.

    Every constructor argument is stored verbatim as an instance attribute of
    the same name, in the order listed in the signature.
    """

    def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
                 num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4):
        # Vocabulary size and the longest supported sequence.
        self.vocab_size, self.max_position_embeddings = vocab_size, max_position_embeddings
        # Transformer geometry.
        self.hidden_size, self.num_attention_heads, self.intermediate_size, self.num_hidden_layers = (
            hidden_size, num_attention_heads, intermediate_size, num_hidden_layers)
        # Regularisation and optimisation settings.
        self.dropout_rate = dropout_rate
        self.weight_decay, self.learning_rate = weight_decay, learning_rate

@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    """Decoder-only (causal) Transformer language model.

    ``call`` takes a dict holding ``'input_ids'`` and, optionally,
    ``'attention_mask'`` (1 = real token, 0 = padding). It returns logits over
    the vocabulary for every position. ``fit``/``evaluate`` expect batches of
    ``(inputs, labels)`` and report ``loss`` and ``accuracy``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Embedding layers for token and position
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)

        # Transformer decoder blocks
        self.decoder_blocks = [self._build_decoder_block() for _ in range(config.num_hidden_layers)]

        # Final normalization and head
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())

        # Metrics updated by train_step/test_step. They are listed in `metrics`
        # so Keras can reset them between epochs.
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
        self.accuracy_tracker = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')

    def _build_decoder_block(self):
        """Create the layers of one post-norm decoder block.

        Returns:
            ``[attention, attention_norm, ffn_in, ffn_out, dropout, ffn_norm]``.
            The attention residual and the feed-forward residual each get their
            own LayerNormalization, so they no longer share weights. ``ffn_norm``
            is appended last so the first five entries keep their positions.
        """
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        return [
            # Per-head size follows the usual hidden_size / num_heads split.
            # MultiHeadAttention projects its output back to the query width
            # (hidden_size).
            layers.MultiHeadAttention(num_heads=num_heads, key_dim=max(1, hidden_size // num_heads),
                                      kernel_initializer=initializers.he_normal()),
            layers.LayerNormalization(epsilon=1e-6),
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal()),
            layers.Dense(hidden_size, kernel_initializer=initializers.he_normal()),
            layers.Dropout(self.config.dropout_rate),
            layers.LayerNormalization(epsilon=1e-6),
        ]

    def call(self, inputs, attention_mask=None, training=False):
        """Compute per-position vocabulary logits.

        Args:
            inputs: dict with ``'input_ids'`` and optionally
                ``'attention_mask'``. Assumes both are shaped
                [batch, seq_len]; TODO confirm against the data pipeline.
            attention_mask: optional padding mask (1 = attend, 0 = padding).
                When omitted, ``inputs['attention_mask']`` is used if present.
            training: enables dropout when True.

        Returns:
            Logits with a trailing vocabulary axis.
        """
        input_ids = inputs['input_ids']
        # Previously the mask inside `inputs` was silently ignored.
        if attention_mask is None and hasattr(inputs, 'get'):
            attention_mask = inputs.get('attention_mask')

        seq_len = tf.shape(input_ids)[-1]
        position_ids = tf.range(start=0, limit=seq_len, delta=1)

        # Token and position embeddings (seq_len must not exceed max_position_embeddings)
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # Boolean mask [batch, query, key]. Query i may attend only to keys
        # j <= i (causal, so no peeking at future tokens), and never to padding.
        causal_mask = position_ids[tf.newaxis, :] <= position_ids[:, tf.newaxis]
        if attention_mask is None:
            key_mask = tf.fill(tf.shape(input_ids), True)
        else:
            key_mask = tf.cast(attention_mask, tf.bool)
        attention_mask = tf.logical_and(causal_mask[tf.newaxis, :, :], key_mask[:, tf.newaxis, :])

        # Apply decoder blocks (post-norm: norm(x + sublayer(x)))
        hidden_states = embeddings
        for mha, attn_norm, ffn1, ffn2, dropout, ffn_norm in self.decoder_blocks:
            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask, training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = attn_norm(attn_output + hidden_states)

            ffn_output = ffn2(ffn1(hidden_states))
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = ffn_norm(ffn_output + hidden_states)

        # Final layer normalization and LM head
        hidden_states = self.layer_norm(hidden_states)
        return self.lm_head(hidden_states)

    def get_config(self):
        """Return a serializable config holding a copy of the MiniSunConfig attributes."""
        # Copy so callers cannot mutate the live config through the returned dict.
        return {'config': dict(vars(self.config))}

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from the dict produced by ``get_config``."""
        return cls(MiniSunConfig(**config['config']))

    @property
    def metrics(self):
        """Metrics reported by fit/evaluate (and reset by Keras between epochs)."""
        return [self.loss_tracker, self.accuracy_tracker]

    def compute_loss(self, labels, logits):
        """Return the unreduced sparse categorical cross-entropy per position.

        Args:
            labels: integer target token ids.
            logits: unnormalised scores whose last axis is the vocabulary.

        Returns:
            A loss tensor shaped like ``labels``. Reduction and padding masking
            happen in ``train_step``/``test_step``.

        Raises:
            ValueError: if ``labels`` or ``logits`` is None.
        """
        if labels is None or logits is None:
            raise ValueError("Labels and logits cannot be None.")
        return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

    def _token_weights(self, inputs, labels):
        """Return float32 weights: 1.0 for positions to score, 0.0 for padding.

        Uses ``inputs['attention_mask']`` when present; otherwise every position
        is scored. NOTE(review): assumes the mask is aligned position-for-position
        with ``labels``; confirm against the data pipeline.
        """
        mask = inputs.get('attention_mask') if hasattr(inputs, 'get') else None
        if mask is None:
            return tf.ones_like(labels, dtype=tf.float32)
        return tf.cast(mask, tf.float32)

    def _masked_loss(self, inputs, labels, logits):
        """Return ``(scalar mean loss over scored tokens, token weights)``."""
        per_token_loss = tf.cast(self.compute_loss(labels, logits), tf.float32)
        weights = self._token_weights(inputs, labels)
        # Guard against an all-padding batch (division by zero).
        total_weight = tf.maximum(tf.reduce_sum(weights), 1.0)
        return tf.reduce_sum(per_token_loss * weights) / total_weight, weights

    def _update_metrics(self, loss, labels, logits, weights):
        """Update the loss/accuracy trackers and return their current values."""
        self.loss_tracker.update_state(loss)
        # SparseCategoricalAccuracy takes the argmax of the logits itself.
        self.accuracy_tracker.update_state(labels, logits, sample_weight=weights)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        """Run one optimisation step on a batch ``(inputs, labels)``."""
        inputs, labels = data

        with tf.GradientTape() as tape:
            logits = self(inputs, training=True)
            # Scalar loss: an unreduced tensor would be implicitly summed by
            # tape.gradient, scaling gradients with batch size * seq_len.
            loss, weights = self._masked_loss(inputs, labels, logits)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return self._update_metrics(loss, labels, logits, weights)

    def test_step(self, data):
        """Evaluate one batch ``(inputs, labels)`` without updating weights.

        Overridden because Keras' default test_step calls
        ``compute_loss(x, y, y_pred, sample_weight)``, which this class's
        two-argument ``compute_loss`` does not accept.
        """
        inputs, labels = data
        logits = self(inputs, training=False)
        loss, weights = self._masked_loss(inputs, labels, logits)
        return self._update_metrics(loss, labels, logits, weights)


def create_model(config):
    """Instantiate a MiniSunModel for ``config`` and compile it for training.

    Args:
        config: a MiniSunConfig. Its ``learning_rate`` and ``weight_decay`` are
            passed to the AdamW optimizer.

    Returns:
        The compiled MiniSunModel.
    """
    model = MiniSunModel(config)

    # AdamW applies decoupled weight decay of config.weight_decay.
    adamw = tf.keras.optimizers.AdamW(learning_rate=config.learning_rate,
                                      weight_decay=config.weight_decay)

    # NOTE(review): MiniSunModel.train_step computes its loss itself via
    # compute_loss rather than through this compiled loss object.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer=adamw, loss=loss_fn, metrics=['accuracy'])
    return model

# Default hyper-parameters
config = MiniSunConfig()

# Build and compile the model (He-normal kernel initializers are configured
# on the layers inside MiniSunModel)
model = create_model(config=config)