In [3]:
import os
import shutil

def reorganize_checkpoints_step1(root_dir, layers):
    """Split each layer's trainer checkpoints out into a sibling ``<layer>_checkpoints`` tree.

    For every ``layer`` in ``layers`` that exists under ``root_dir``, a directory
    ``<root_dir>/<layer>_checkpoints`` is created. For every ``trainer_*``
    sub-directory of the layer, a same-named directory is created inside it, then:

    * ``config.json`` is *copied* (the original stays in place), and
    * the ``checkpoints`` directory is *moved* (it no longer exists in the
      source trainer directory afterwards).

    Missing layers, configs or checkpoint directories are reported with a
    printed warning and skipped; nothing is raised for them.

    Args:
        root_dir: Directory that contains the per-layer directories.
        layers: Iterable of layer directory names (relative to ``root_dir``).

    Returns:
        None. Progress and warnings are printed to stdout.
    """
    for layer in layers:
        layer_path = os.path.join(root_dir, layer)
        if not os.path.exists(layer_path):
            print(f"Warning: {layer_path} does not exist. Skipping.")
            continue

        # Create the new results directory
        results_path = os.path.join(root_dir, f"{layer}_checkpoints")
        os.makedirs(results_path, exist_ok=True)

        # Iterate through trainer directories
        for trainer in os.listdir(layer_path):
            trainer_path = os.path.join(layer_path, trainer)
            # Only real directories count as trainers: a stray file such as
            # "trainer_notes.txt" must not produce an empty results directory
            # and a pair of misleading warnings.
            if not trainer.startswith('trainer_') or not os.path.isdir(trainer_path):
                continue

            config_path = os.path.join(trainer_path, 'config.json')
            checkpoints_path = os.path.join(trainer_path, 'checkpoints')

            # Create trainer directory in results
            trainer_results_path = os.path.join(results_path, trainer)
            os.makedirs(trainer_results_path, exist_ok=True)

            # Copy config.json if it exists
            if os.path.exists(config_path):
                shutil.copy2(config_path, trainer_results_path)
            else:
                print(f"Warning: config.json not found in {trainer_path}")

            # Move checkpoints directory if it exists. The destination is an
            # existing directory, so shutil.move places the source *inside* it
            # as <trainer_results_path>/checkpoints.
            if os.path.exists(checkpoints_path):
                shutil.move(checkpoints_path, trainer_results_path)
            else:
                print(f"Warning: checkpoints directory not found in {trainer_path}")

    print("Step 1 of reorganization complete.")



# Sweep directory whose per-layer trainer folders get reorganized, and the
# layer folders (relative to it) to process. Both names are reused by the
# step-2 cell further down.
root_directory = (
    "/workspace/sae_eval/dictionary_learning/dictionaries/gemma-2-2b/gemma-2-2b_sweep_topk_ctx128_ef8_0824"
)
layers_to_process = [
    "resid_post_layer_3",
    "resid_post_layer_7",
    "resid_post_layer_11",
    "resid_post_layer_15",
    "resid_post_layer_19",
]
reorganize_checkpoints_step1(root_dir=root_directory, layers=layers_to_process)

Step 1 of reorganization complete.


In [4]:
import os
import shutil
import re
import json

def reorganize_checkpoints_step2(root_dir, checkpoint_dirs):
    """Flatten ``trainer_*/checkpoints/ae_<step>.pt`` into one directory per checkpoint.

    For every ``trainer_*`` directory inside each ``<root_dir>/<checkpoint_dir>``,
    each file named exactly ``ae_<step>.pt`` in its ``checkpoints`` folder becomes

        <root_dir>/<checkpoint_dir>/<trainer>_step_<step>/ae.pt
        <root_dir>/<checkpoint_dir>/<trainer>_step_<step>/config.json

    where ``config.json`` is the trainer's config with ``config['trainer']['steps']``
    set to the checkpoint's step (as an integer). Once all checkpoints are moved,
    the now-empty ``checkpoints`` folder, the trainer's ``config.json`` and the
    trainer directory itself are removed.

    Args:
        root_dir: Directory that contains the checkpoint directories.
        checkpoint_dirs: Iterable of directory names (relative to ``root_dir``).

    Returns:
        None. Progress and warnings are printed to stdout.

    Raises:
        FileNotFoundError: A trainer has checkpoints to convert but no ``config.json``.
        OSError: The ``checkpoints`` folder or the trainer directory still contains
            entries after processing (e.g. files that are not ``ae_<step>.pt``),
            so it is left in place rather than deleted.
    """
    # fullmatch, not match: "ae_100.pt.bak" or "ae_100.pth" must not be mistaken
    # for a checkpoint and renamed to ae.pt (possibly clobbering the real one).
    checkpoint_pattern = re.compile(r'ae_(\d+)\.pt')

    for checkpoint_dir in checkpoint_dirs:
        checkpoint_path = os.path.join(root_dir, checkpoint_dir)
        if not os.path.exists(checkpoint_path):
            print(f"Warning: {checkpoint_path} does not exist. Skipping.")
            continue

        # Iterate through trainer directories. os.listdir returns a snapshot, so
        # the "<trainer>_step_<n>" directories created below are not revisited.
        for trainer in os.listdir(checkpoint_path):
            if not trainer.startswith('trainer_'):
                continue

            trainer_path = os.path.join(checkpoint_path, trainer)
            config_path = os.path.join(trainer_path, 'config.json')
            checkpoints_path = os.path.join(trainer_path, 'checkpoints')

            if not os.path.exists(checkpoints_path):
                print(f"Warning: checkpoints directory not found in {trainer_path}")
                continue

            # Parsed lazily on the first matching checkpoint and then reused, so a
            # trainer without any checkpoint files does not require a config.
            config = None

            # Process each checkpoint
            for checkpoint in os.listdir(checkpoints_path):
                match = checkpoint_pattern.fullmatch(checkpoint)
                if not match:
                    continue

                if config is None:
                    # Fail before creating any output directory for this checkpoint.
                    if not os.path.exists(config_path):
                        raise FileNotFoundError(f"Config.json not found for {trainer}")
                    with open(config_path, 'r') as f:
                        config = json.load(f)

                step = match.group(1)
                new_checkpoint_dir = os.path.join(checkpoint_path, f'{trainer}_step_{step}')
                os.makedirs(new_checkpoint_dir, exist_ok=True)

                # Write a per-checkpoint config; the step is stored as a number,
                # not as the raw regex string.
                config['trainer']['steps'] = int(step)
                new_config_path = os.path.join(new_checkpoint_dir, 'config.json')
                with open(new_config_path, 'w') as f:
                    json.dump(config, f, indent=2)

                # Move and rename checkpoint file
                old_checkpoint_path = os.path.join(checkpoints_path, checkpoint)
                new_checkpoint_path = os.path.join(new_checkpoint_dir, 'ae.pt')
                shutil.move(old_checkpoint_path, new_checkpoint_path)

            # Remove the original checkpoints directory if it's empty
            if os.listdir(checkpoints_path):
                raise OSError(f"Checkpoints directory {checkpoints_path} is not empty.")
            os.rmdir(checkpoints_path)

            # Remove the config.json file
            if os.path.exists(config_path):
                os.remove(config_path)
            else:
                print(f"Warning: config.json not found for {trainer}")

            # Remove the trainer directory
            if os.listdir(trainer_path):
                raise OSError(f"Trainer directory {trainer_path} is not empty.")
            os.rmdir(trainer_path)

    print("Step 2 of reorganization complete.")

# Step 1 wrote each layer's results to "<layer>_checkpoints"; feed those
# directories to step 2.
checkpoint_dirs_to_process = [f"{layer}_checkpoints" for layer in layers_to_process]

reorganize_checkpoints_step2(root_dir=root_directory, checkpoint_dirs=checkpoint_dirs_to_process)

Step 2 of reorganization complete.


In [None]:
import torch

def compare_pytorch_models(file1, file2, atol=1e-7, rtol=1e-5):
    """Check that two saved PyTorch models hold the same parameters.

    Both files are loaded onto the CPU. A loaded object that is not a ``dict``
    is assumed to expose ``state_dict()`` and is replaced by its state dict.
    The two state dicts must have the same keys, and every pair of tensors must
    have the same shape and be element-wise close (``torch.allclose``).

    Args:
        file1: Path of the first saved model / state dict.
        file2: Path of the second saved model / state dict.
        atol: Absolute tolerance passed to ``torch.allclose``.
        rtol: Relative tolerance passed to ``torch.allclose`` (the default is
            ``torch.allclose``'s own default, which the original call relied on).

    Returns:
        None. Each compared key and a final success message are printed.

    Raises:
        AssertionError: The key sets differ, or a parameter differs in shape
            or in value beyond the tolerance.
    """
    # NOTE(security): torch.load unpickles its input; only use this on files
    # from a trusted source.
    model1 = torch.load(file1, map_location=torch.device('cpu'))
    model2 = torch.load(file2, map_location=torch.device('cpu'))

    # If the loaded objects are not dictionaries, assume they are modules and
    # take their state dictionaries
    if not isinstance(model1, dict):
        model1 = model1.state_dict()
    if not isinstance(model2, dict):
        model2 = model2.state_dict()

    # Raised explicitly (rather than via `assert`) so the checks still run
    # under `python -O`.
    if set(model1.keys()) != set(model2.keys()):
        raise AssertionError("Models have different keys")

    # Compare each parameter
    for key in model1.keys():
        print(key)
        tensor1, tensor2 = model1[key], model2[key]
        # torch.allclose broadcasts its arguments, so e.g. shapes (1,) and (3,)
        # could otherwise be reported as identical.
        if tensor1.shape != tensor2.shape:
            raise AssertionError(
                f"Mismatch in parameter {key}: shape {tuple(tensor1.shape)} vs {tuple(tensor2.shape)}"
            )
        if not torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol):
            raise AssertionError(f"Mismatch in parameter {key}")

    print("Models are identical within the specified tolerance.")

# Example: verify that a converted checkpoint matches the original it was
# produced from (paths are relative to the notebook's working directory).
compare_pytorch_models(file1='ae_4882.pt', file2='ae_4882_converted.pt')