finnstrom3693 committed
Commit 1e9934f
1 Parent(s): 0e33744

Create modeling3.py

Files changed (1):
  1. modeling3.py +155 -0
modeling3.py ADDED
@@ -0,0 +1,155 @@
+ # @title Model Architecture
+ import tensorflow as tf
+ from tensorflow.keras import layers, activations, initializers
+
+ class MiniSunConfig:
+     def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
+                  num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
+                  dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.dropout_rate = dropout_rate
+         self.weight_decay = weight_decay
+         self.learning_rate = learning_rate
+
+ @tf.keras.utils.register_keras_serializable()
+ class MiniSunModel(tf.keras.Model):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         # Embedding layers for tokens and positions
+         self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
+         self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)
+
+         # Transformer decoder blocks
+         self.decoder_blocks = [self._build_decoder_block() for _ in range(config.num_hidden_layers)]
+
+         # Final normalization and LM head
+         self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
+         self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())
+
+     def _build_decoder_block(self):
+         # Decoder block: multi-head self-attention and a feed-forward sub-layer,
+         # each with its own LayerNormalization for the residual Add & Norm
+         # (sharing one norm between the two residual paths would tie their weights).
+         return [
+             layers.MultiHeadAttention(num_heads=self.config.num_attention_heads,
+                                       # per-head dimension, so the total attention width matches hidden_size
+                                       key_dim=self.config.hidden_size // self.config.num_attention_heads,
+                                       kernel_initializer=initializers.he_normal()),
+             layers.LayerNormalization(epsilon=1e-6),  # after attention
+             layers.LayerNormalization(epsilon=1e-6),  # after feed-forward
+             layers.Dense(self.config.intermediate_size, activation=activations.elu,
+                          kernel_initializer=initializers.he_normal()),
+             layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal()),
+             layers.Dropout(self.config.dropout_rate)
+         ]
+
+     def call(self, inputs, attention_mask=None, training=False):
+         input_ids = inputs['input_ids']
+         position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)
+
+         # Token and position embeddings (position ids broadcast over the batch)
+         embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)
+
+         # Broadcast the padding mask to shape [batch_size, 1, 1, seq_len]
+         if attention_mask is not None:
+             attention_mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32)
+
+         # Apply decoder blocks
+         hidden_states = embeddings
+         for mha, attn_norm, ffn_norm, ffn1, ffn2, dropout in self.decoder_blocks:
+             # Causal self-attention: each position attends only to itself and earlier positions
+             attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask,
+                               use_causal_mask=True, training=training)
+             attn_output = dropout(attn_output, training=training)
+             hidden_states = attn_norm(attn_output + hidden_states)  # Add & Norm
+
+             # Feed-forward sub-layer
+             ffn_output = ffn1(hidden_states)
+             ffn_output = ffn2(ffn_output)
+             ffn_output = dropout(ffn_output, training=training)
+             hidden_states = ffn_norm(ffn_output + hidden_states)  # Add & Norm
+
+         # Final layer normalization
+         hidden_states = self.layer_norm(hidden_states)
+
+         # LM head: per-position logits over the vocabulary for next-token prediction
+         logits = self.lm_head(hidden_states)
+         return logits
+
+     def get_config(self):
+         # Serialize the model via its configuration object
+         return {'config': self.config.__dict__}
+
+     @classmethod
+     def from_config(cls, config):
+         # Re-create the model from its serialized configuration
+         return cls(MiniSunConfig(**config['config']))
+
+     # Named _compute_loss so it does not shadow Keras' built-in Model.compute_loss,
+     # which has a different signature and is called internally during evaluation.
+     def _compute_loss(self, labels, logits):
+         """Mean sparse categorical cross-entropy between labels and logits."""
+         if labels is None or logits is None:
+             raise ValueError("Labels and logits cannot be None.")
+
+         # Reduce the per-token losses to a scalar for the gradient tape
+         return tf.reduce_mean(
+             tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True))
+
+     def train_step(self, data):
+         # Unpack the data (expects a tuple: (inputs, labels))
+         inputs, labels = data
+
+         # inputs is a dictionary with 'input_ids' and, optionally, 'attention_mask'
+         attention_mask = inputs.get('attention_mask')
+
+         with tf.GradientTape() as tape:
+             # Forward pass; the padding mask must be passed through explicitly
+             logits = self(inputs, attention_mask=attention_mask, training=True)
+             loss = self._compute_loss(labels, logits)
+
+         # Compute gradients
+         trainable_vars = self.trainable_variables
+         gradients = tape.gradient(loss, trainable_vars)
+
+         # Update weights with the optimizer
+         self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+
+         # Update compiled metrics: sparse labels [batch, seq_len] are compared
+         # against logits [batch, seq_len, vocab_size] directly, so no argmax
+         # or reshape is needed here.
+         self.compiled_metrics.update_state(labels, logits)
+
+         results = {m.name: m.result() for m in self.metrics}
+         results['loss'] = loss
+         return results
+
+ def create_model(config):
+     model = MiniSunModel(config)
+
+     # AdamW: Adam with decoupled weight decay
+     optimizer = tf.keras.optimizers.AdamW(learning_rate=config.learning_rate,
+                                           weight_decay=config.weight_decay)
+
+     # Compile with sparse categorical cross-entropy on the raw logits
+     model.compile(
+         optimizer=optimizer,
+         loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+         metrics=['accuracy']
+     )
+     return model
+
+ # Configuration
+ config = MiniSunConfig()
+
+ # Initialize the model (Dense and attention kernels use He initialization)
+ model = create_model(config)
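
The commit stops after compiling the model, so nothing above actually exercises it. As a quick smoke test (not part of the commit), the snippet below runs a forward pass on random token IDs and fits one epoch on a single dummy batch; the shapes follow the MiniSunConfig defaults, and the labels are the inputs shifted left by one position, mimicking next-token prediction.

batch_size, seq_len = 2, 16
dummy_inputs = {
    'input_ids': tf.random.uniform((batch_size, seq_len), maxval=config.vocab_size, dtype=tf.int32),
    'attention_mask': tf.ones((batch_size, seq_len), dtype=tf.int32),
}
# Next-token targets: inputs shifted left by one, padded with a dummy 0 at the end
dummy_labels = tf.concat(
    [dummy_inputs['input_ids'][:, 1:], tf.zeros((batch_size, 1), dtype=tf.int32)], axis=-1)

logits = model(dummy_inputs, attention_mask=dummy_inputs['attention_mask'], training=False)
print(logits.shape)  # (2, 16, 30522): one vocabulary distribution per position

dataset = tf.data.Dataset.from_tensor_slices((dummy_inputs, dummy_labels)).batch(batch_size)
model.fit(dataset, epochs=1)

The random IDs are stand-ins for a real tokenized corpus; any tokenizer that emits input_ids/attention_mask pairs would slot in here (the default vocab_size of 30522 matches BERT's WordPiece vocabulary).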
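
The lm_head comment promises next-token generation, but the commit includes no decoding loop. A minimal greedy-decoding sketch could look like the following; greedy_generate is a hypothetical helper, not part of modeling3.py, and it assumes the prompt plus the generated tokens stay within max_position_embeddings.

def greedy_generate(model, input_ids, max_new_tokens=20):
    # Hypothetical helper: append the most likely next token, one step at a time.
    # input_ids: int32 tensor of shape [1, prompt_len]
    for _ in range(max_new_tokens):
        logits = model({'input_ids': input_ids}, training=False)
        # Most likely token at the final position
        next_id = tf.argmax(logits[:, -1, :], axis=-1, output_type=tf.int32)
        input_ids = tf.concat([input_ids, next_id[:, tf.newaxis]], axis=-1)
    return input_ids

prompt = tf.constant([[101, 2023, 2003]])  # arbitrary example token IDs
print(greedy_generate(model, prompt))

Each step re-runs the full forward pass; a production decoder would cache attention keys and values, but that is beyond what this file implements.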