import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from safetensors.torch import save_file, load_file
from pathlib import Path
import time

def count_parameters_layerwise(model):
    """Print a per-layer breakdown of trainable parameters and return the total count."""
    total_params = 0
    layer_params = {}
    
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
            
        param_count = parameter.numel()
        layer_params[name] = param_count
        total_params += param_count
    
    print(f"\nModel Parameter Summary:")
    print("-" * 60)
    for name, count in layer_params.items():
        print(f"{name}: {count:,} parameters")
    print("-" * 60)
    print(f"Total Trainable Parameters: {total_params:,}\n")
    
    return total_params

def save_checkpoint(model, filename="checkpoint.safetensors"):
    # Unwrap the torch.compile wrapper so state-dict keys match the plain model.
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod

    try:
        save_file(model.state_dict(), filename)
    except Exception:
        # Fall back to a plain .pt checkpoint (mirrors the fallback in load_checkpoint).
        torch.save(model.state_dict(), filename.replace('.safetensors', '.pt'))

def load_checkpoint(model, filename="checkpoint.safetensors"):
    # Unwrap the torch.compile wrapper so state-dict keys match the plain model.
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod

    try:
        model_state = load_file(filename)
    except Exception:
        # Fall back to the .pt checkpoint written by the torch.save path above.
        model_state = torch.load(filename.replace('.safetensors', '.pt'), weights_only=True)
    model.load_state_dict(model_state)
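
# Illustrative sketch (not executed): the checkpoint helpers above are intended to be
# safe to call on a torch.compile-wrapped model, since they unwrap `_orig_mod` first.
# "MyModel" is a placeholder for your own nn.Module, "run1.safetensors" an arbitrary path.
#
#   model = torch.compile(MyModel())
#   ... train ...
#   save_checkpoint(model, "run1.safetensors")   # keys saved without the compile wrapper
#   fresh = MyModel()
#   load_checkpoint(fresh, "run1.safetensors")   # loads cleanly into an uncompiled model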
        
class TBLogger:
    def __init__(self, log_dir='logs/current_run', flush_secs=10, enable_grad_logging=True):
        Path(log_dir).mkdir(parents=True, exist_ok=True)
        self.writer = SummaryWriter(log_dir, flush_secs=flush_secs)
        self.enable_grad_logging = enable_grad_logging
        self.start_time = time.time()
        
    def log(self, metrics, step=None, model=None, prefix='', grad_checking=False):
        """Log a dict of scalars/tensors (sequences become histograms); optionally log gradient norms of `model`."""
        for name, value in metrics.items():
            full_name = f"{prefix}{name}" if prefix else name
            
            if isinstance(value, (int, float)):
                self.writer.add_scalar(full_name, value, step)
            elif isinstance(value, torch.Tensor):
                self.writer.add_scalar(full_name, value.item(), step)
            elif isinstance(value, (list, tuple)) and len(value) > 0:
                if all(isinstance(x, (int, float)) for x in value):
                    self.writer.add_histogram(full_name, torch.tensor(value), step)
                
        if self.enable_grad_logging and model is not None:
            self._log_gradients(model, step, grad_checking)
            
    def _log_gradients(self, model, step, grad_checking):
        total_norm = 0.0
        found_valid_grad = False
        for name, param in model.named_parameters():
            if param.grad is None:
                continue

            if grad_checking:
                # Optionally skip layers whose gradients contain nan/inf.
                if torch.isnan(param.grad).any():
                    print(f"Warning: Found nan in gradients for layer: {name}")
                    continue
                if torch.isinf(param.grad).any():
                    print(f"Warning: Found inf in gradients for layer: {name}")
                    continue

            param_norm = param.grad.detach().norm(2)
            self.writer.add_scalar(f"gradients/{name}_norm", param_norm, step)
            total_norm += param_norm.item() ** 2
            found_valid_grad = True

        # Only report the global norm if at least one layer contributed a usable gradient.
        if found_valid_grad:
            total_norm = total_norm ** 0.5
            self.writer.add_scalar("gradients/total_norm", total_norm, step)
    
    def close(self):
        self.writer.close()
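

if __name__ == "__main__":
    # Minimal usage sketch / smoke test, not part of the library surface.
    # The tiny nn.Sequential model, the log directory "logs/demo_run", and the
    # checkpoint name "demo_checkpoint.safetensors" are arbitrary examples.
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    count_parameters_layerwise(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    logger = TBLogger(log_dir="logs/demo_run")

    for step in range(3):
        x = torch.randn(4, 8)
        target = torch.randn(4, 1)
        loss = F.mse_loss(model(x), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Scalars go to TensorBoard directly; gradient norms are logged because a model is passed.
        logger.log({"train/loss": loss}, step=step, model=model, grad_checking=True)

    logger.close()

    # Round-trip the weights through the checkpoint helpers.
    save_checkpoint(model, "demo_checkpoint.safetensors")
    load_checkpoint(model, "demo_checkpoint.safetensors")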