import os
import shutil


def reorganize_checkpoints_step1(root_dir, layers):
    """Stage 1 of the checkpoint reorganization.

    For every layer directory under ``root_dir``, build a sibling
    ``<layer>_checkpoints`` tree.  Each ``trainer_*`` subdirectory gets its
    ``config.json`` *copied* and its ``checkpoints`` directory *moved* into
    the new tree, preserving the trainer directory name.

    Missing layer dirs, configs, or checkpoint dirs are reported with a
    warning and skipped rather than raising.
    """
    for layer_name in layers:
        source_layer = os.path.join(root_dir, layer_name)
        if not os.path.exists(source_layer):
            print(f"Warning: {source_layer} does not exist. Skipping.")
            continue

        # Destination tree: <root>/<layer>_checkpoints/
        destination_root = os.path.join(root_dir, f"{layer_name}_checkpoints")
        os.makedirs(destination_root, exist_ok=True)

        for entry in os.listdir(source_layer):
            if not entry.startswith('trainer_'):
                continue

            source_trainer = os.path.join(source_layer, entry)
            destination_trainer = os.path.join(destination_root, entry)
            os.makedirs(destination_trainer, exist_ok=True)

            # config.json is copied (the original stays behind for stage 2)
            config_file = os.path.join(source_trainer, 'config.json')
            if os.path.exists(config_file):
                shutil.copy2(config_file, destination_trainer)
            else:
                print(f"Warning: config.json not found in {source_trainer}")

            # The checkpoints directory is moved wholesale into the new tree.
            checkpoints_dir = os.path.join(source_trainer, 'checkpoints')
            if os.path.exists(checkpoints_dir):
                shutil.move(checkpoints_dir, destination_trainer)
            else:
                print(f"Warning: checkpoints directory not found in {source_trainer}")

    print("Step 1 of reorganization complete.")


root_directory = "/workspace/sae_eval/dictionary_learning/dictionaries/gemma-2-2b/gemma-2-2b_sweep_topk_ctx128_ef8_0824"
layers_to_process = ["resid_post_layer_3", "resid_post_layer_7", "resid_post_layer_11", "resid_post_layer_15", "resid_post_layer_19"]
reorganize_checkpoints_step1(root_directory, layers_to_process)


import os
import shutil
import re
import json


def reorganize_checkpoints_step2(root_dir, checkpoint_dirs):
    """Stage 2 of the checkpoint reorganization.

    Flattens each ``trainer_*/checkpoints/ae_<step>.pt`` file into its own
    ``trainer_*_step_<step>`` directory containing ``ae.pt`` plus a copy of
    the trainer's ``config.json`` with ``config['trainer']['steps']``
    overwritten to that checkpoint's step.  Afterwards the now-empty
    ``checkpoints`` directory, the original ``config.json``, and the emptied
    trainer directory are removed.

    Raises if a directory that should be empty after the moves is not —
    a safety check against silently deleting unprocessed files.
    """
    for checkpoint_dir in checkpoint_dirs:
        checkpoint_path = os.path.join(root_dir, checkpoint_dir)
        if not os.path.exists(checkpoint_path):
            print(f"Warning: {checkpoint_path} does not exist. Skipping.")
            continue

        # Iterate through trainer directories
        for trainer in os.listdir(checkpoint_path):
            if trainer.startswith('trainer_'):
                trainer_path = os.path.join(checkpoint_path, trainer)
                config_path = os.path.join(trainer_path, 'config.json')
                checkpoints_path = os.path.join(trainer_path, 'checkpoints')

                if not os.path.exists(checkpoints_path):
                    print(f"Warning: checkpoints directory not found in {trainer_path}")
                    continue

                # Process each checkpoint file named ae_<step>.pt
                for checkpoint in os.listdir(checkpoints_path):
                    match = re.match(r'ae_(\d+)\.pt', checkpoint)
                    if match:
                        step = match.group(1)
                        new_checkpoint_dir = os.path.join(checkpoint_path, f'{trainer}_step_{step}')
                        os.makedirs(new_checkpoint_dir, exist_ok=True)

                        # Copy config.json with the step count patched in.
                        if os.path.exists(config_path):
                            with open(config_path, 'r') as f:
                                config = json.load(f)
                            # FIX: store the step as an int, not the regex
                            # capture string, so the field keeps the numeric
                            # type the training config used.
                            config['trainer']['steps'] = int(step)
                            new_config_path = os.path.join(new_checkpoint_dir, 'config.json')
                            with open(new_config_path, 'w') as f:
                                json.dump(config, f, indent=2)
                        else:
                            # FIX: removed unreachable print() that followed
                            # this raise in the original.
                            raise Exception(f"Config.json not found for {trainer}")

                        # Move and rename the checkpoint file to ae.pt
                        old_checkpoint_path = os.path.join(checkpoints_path, checkpoint)
                        new_checkpoint_path = os.path.join(new_checkpoint_dir, 'ae.pt')
                        shutil.move(old_checkpoint_path, new_checkpoint_path)

                # Remove the original checkpoints directory only if empty;
                # leftover files mean something was not an ae_<step>.pt.
                if not os.listdir(checkpoints_path):
                    os.rmdir(checkpoints_path)
                else:
                    raise Exception(f"Checkpoints directory {checkpoints_path} is not empty.")

                # Remove the now-redundant original config.json
                if os.path.exists(config_path):
                    os.remove(config_path)
                else:
                    print(f"Warning: config.json not found for {trainer}")

                # Remove the emptied trainer directory
                if not os.listdir(trainer_path):
                    os.rmdir(trainer_path)
                else:
                    raise Exception(f"Trainer directory {trainer_path} is not empty.")

    print("Step 2 of reorganization complete.")


checkpoint_dirs_to_process = []
for layer in layers_to_process:
    checkpoint_dirs_to_process.append(f"{layer}_checkpoints")

reorganize_checkpoints_step2(root_directory, checkpoint_dirs_to_process)


import torch


def compare_pytorch_models(file1, file2):
    """Assert that two saved PyTorch files hold numerically identical
    parameters (within atol=1e-7); raises AssertionError on any mismatch.

    Accepts either saved state dicts or whole models (from which
    ``state_dict()`` is taken).
    """
    # NOTE(security): torch.load unpickles arbitrary objects — only run this
    # on checkpoint files you trust. Consider weights_only=True when the
    # files are known to be plain state dicts.
    model1 = torch.load(file1, map_location=torch.device('cpu'))
    model2 = torch.load(file2, map_location=torch.device('cpu'))

    # If the loaded objects are not dictionaries, assume they are models
    # and extract their state dictionaries.
    if not isinstance(model1, dict):
        model1 = model1.state_dict()
    if not isinstance(model2, dict):
        model2 = model2.state_dict()

    # The two files must contain exactly the same parameter names.
    assert set(model1.keys()) == set(model2.keys()), "Models have different keys"

    # Compare each parameter tensor element-wise.
    for key in model1.keys():
        print(key)
        assert torch.allclose(model1[key], model2[key], atol=1e-7), f"Mismatch in parameter {key}"

    print("Models are identical within the specified tolerance.")


# Usage example (you can run this in your Jupyter notebook):
compare_pytorch_models('ae_4882.pt', 'ae_4882_converted.pt')