finnstrom3693 commited on
Commit
d7cbcc5
1 Parent(s): 6aeb9de

Update modeling-dev2.py

Browse files
Files changed (1) hide show
  1. modeling-dev2.py +14 -8
modeling-dev2.py CHANGED
@@ -126,25 +126,31 @@ class MiniSunModel(tf.keras.Model):
126
  inputs, labels = data
127
  input_ids = inputs['input_ids']
128
  attention_mask = inputs['attention_mask']
129
-
130
  with tf.GradientTape() as tape:
131
  logits, _ = self(inputs, training=True)
132
  loss = self.compute_loss(labels, logits)
133
-
134
  gradients = tape.gradient(loss, self.trainable_variables)
135
-
136
  # Gradient clipping for stability
137
  gradients = [tf.clip_by_value(g, -1.0, 1.0) for g in gradients]
138
  self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
139
-
 
140
  logits_for_metrics = tf.argmax(logits, axis=-1)
141
- logits_for_metrics = tf.reshape(logits_for_metrics, [-1])
142
- labels_for_metrics = tf.reshape(labels, [-1])
143
-
144
  for metric in self.metrics:
145
  metric.update_state(labels_for_metrics, logits_for_metrics)
 
 
 
 
 
 
146
 
147
- return {m.name: m.result() for m in self.metrics}
148
 
149
  def create_model(config):
150
  model = MiniSunModel(config)
 
126
  inputs, labels = data
127
  input_ids = inputs['input_ids']
128
  attention_mask = inputs['attention_mask']
129
+
130
  with tf.GradientTape() as tape:
131
  logits, _ = self(inputs, training=True)
132
  loss = self.compute_loss(labels, logits)
133
+
134
  gradients = tape.gradient(loss, self.trainable_variables)
135
+
136
  # Gradient clipping for stability
137
  gradients = [tf.clip_by_value(g, -1.0, 1.0) for g in gradients]
138
  self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
139
+
140
+ # Compute predictions and metrics
141
  logits_for_metrics = tf.argmax(logits, axis=-1)
142
+ labels_for_metrics = tf.reshape(labels, [-1]) # Flatten labels
143
+ logits_for_metrics = tf.reshape(logits_for_metrics, [-1]) # Flatten predictions
144
+
145
  for metric in self.metrics:
146
  metric.update_state(labels_for_metrics, logits_for_metrics)
147
+
148
+ # Return loss and metrics
149
+ results = {m.name: m.result() for m in self.metrics}
150
+ results['loss'] = loss
151
+
152
+ return results
153
 
 
154
 
155
  def create_model(config):
156
  model = MiniSunModel(config)