{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 1 of reorganization complete.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "def reorganize_checkpoints_step1(root_dir, layers):\n",
    "    for layer in layers:\n",
    "        layer_path = os.path.join(root_dir, layer)\n",
    "        if not os.path.exists(layer_path):\n",
    "            print(f\"Warning: {layer_path} does not exist. Skipping.\")\n",
    "            continue\n",
    "\n",
    "        # Create the new results directory\n",
    "        results_dir = f\"{layer}_checkpoints\"\n",
    "        results_path = os.path.join(root_dir, results_dir)\n",
    "        os.makedirs(results_path, exist_ok=True)\n",
    "\n",
    "\n",
    "        # Iterate through trainer directories\n",
    "        for trainer in os.listdir(layer_path):\n",
    "            if trainer.startswith('trainer_'):\n",
    "                trainer_path = os.path.join(layer_path, trainer)\n",
    "                config_path = os.path.join(trainer_path, 'config.json')\n",
    "                checkpoints_path = os.path.join(trainer_path, 'checkpoints')\n",
    "\n",
    "                # Create trainer directory in results\n",
    "                trainer_results_path = os.path.join(results_path, trainer)\n",
    "                os.makedirs(trainer_results_path, exist_ok=True)\n",
    "\n",
    "                # Copy config.json if it exists\n",
    "                if os.path.exists(config_path):\n",
    "                    shutil.copy2(config_path, trainer_results_path)\n",
    "                else:\n",
    "                    print(f\"Warning: config.json not found in {trainer_path}\")\n",
    "\n",
    "                # Move checkpoints directory if it exists\n",
    "                if os.path.exists(checkpoints_path):\n",
    "                    shutil.move(checkpoints_path, trainer_results_path)\n",
    "                else:\n",
    "                    print(f\"Warning: checkpoints directory not found in {trainer_path}\")\n",
    "\n",
    "    print(\"Step 1 of reorganization complete.\")\n",
    "\n",
    "\n",
    "\n",
    "root_directory = \"/workspace/sae_eval/dictionary_learning/dictionaries/gemma-2-2b/gemma-2-2b_sweep_topk_ctx128_ef8_0824\"\n",
    "layers_to_process = [\"resid_post_layer_3\", \"resid_post_layer_7\", \"resid_post_layer_11\", \"resid_post_layer_15\", \"resid_post_layer_19\"]\n",
    "reorganize_checkpoints_step1(root_directory, layers_to_process)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 2 of reorganization complete.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import shutil\n",
    "import re\n",
    "import json\n",
    "\n",
    "def reorganize_checkpoints_step2(root_dir, checkpoint_dirs):\n",
    "    for checkpoint_dir in checkpoint_dirs:\n",
    "        checkpoint_path = os.path.join(root_dir, checkpoint_dir)\n",
    "        if not os.path.exists(checkpoint_path):\n",
    "            print(f\"Warning: {checkpoint_path} does not exist. Skipping.\")\n",
    "            continue\n",
    "\n",
    "        # Iterate through trainer directories\n",
    "        for trainer in os.listdir(checkpoint_path):\n",
    "            if trainer.startswith('trainer_'):\n",
    "                trainer_path = os.path.join(checkpoint_path, trainer)\n",
    "                config_path = os.path.join(trainer_path, 'config.json')\n",
    "                checkpoints_path = os.path.join(trainer_path, 'checkpoints')\n",
    "\n",
    "                if not os.path.exists(checkpoints_path):\n",
    "                    print(f\"Warning: checkpoints directory not found in {trainer_path}\")\n",
    "                    continue\n",
    "\n",
    "                # Process each checkpoint\n",
    "                for checkpoint in os.listdir(checkpoints_path):\n",
    "                    match = re.match(r'ae_(\\d+)\\.pt', checkpoint)\n",
    "                    if match:\n",
    "                        step = match.group(1)\n",
    "                        new_checkpoint_dir = os.path.join(checkpoint_path, f'{trainer}_step_{step}')\n",
    "                        os.makedirs(new_checkpoint_dir, exist_ok=True)\n",
    "\n",
    "                        # Copy config.json\n",
    "                        if os.path.exists(config_path):\n",
    "                            with open(config_path, 'r') as f:\n",
    "                                config = json.load(f)\n",
    "                            config['trainer']['steps'] = step\n",
    "                            new_config_path = os.path.join(new_checkpoint_dir, 'config.json')\n",
    "                            with open(new_config_path, 'w') as f:\n",
    "                                json.dump(config, f, indent=2)\n",
    "                        else:\n",
    "                            raise Exception(f\"Config.json not found for {trainer}\")\n",
    "                            print(f\"Warning: config.json not found for {trainer}\")\n",
    "\n",
    "                        # Move and rename checkpoint file\n",
    "                        old_checkpoint_path = os.path.join(checkpoints_path, checkpoint)\n",
    "                        new_checkpoint_path = os.path.join(new_checkpoint_dir, 'ae.pt')\n",
    "                        shutil.move(old_checkpoint_path, new_checkpoint_path)\n",
    "\n",
    "                # Remove the original checkpoints directory if it's empty\n",
    "                if not os.listdir(checkpoints_path):\n",
    "                    os.rmdir(checkpoints_path)\n",
    "                else:\n",
    "                    raise Exception(f\"Checkpoints directory {checkpoints_path} is not empty.\")\n",
    "\n",
    "                # Remove the config.json file\n",
    "                if os.path.exists(config_path):\n",
    "                    os.remove(config_path)\n",
    "                else:\n",
    "                    print(f\"Warning: config.json not found for {trainer}\")\n",
    "\n",
    "                # Remove the trainer directory\n",
    "                if not os.listdir(trainer_path):\n",
    "                    os.rmdir(trainer_path)\n",
    "                else:\n",
    "                    raise Exception(f\"Trainer directory {trainer_path} is not empty.\")\n",
    "\n",
    "    print(\"Step 2 of reorganization complete.\")\n",
    "\n",
    "checkpoint_dirs_to_process = []\n",
    "for layer in layers_to_process:\n",
    "    checkpoint_dirs_to_process.append(f\"{layer}_checkpoints\")\n",
    "\n",
    "reorganize_checkpoints_step2(root_directory, checkpoint_dirs_to_process)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def compare_pytorch_models(file1, file2):\n",
    "    # Load the models\n",
    "    model1 = torch.load(file1, map_location=torch.device('cpu'))\n",
    "    model2 = torch.load(file2, map_location=torch.device('cpu'))\n",
    "    \n",
    "    # If the loaded objects are not dictionaries, assume they are the state dictionaries\n",
    "    if not isinstance(model1, dict):\n",
    "        model1 = model1.state_dict()\n",
    "    if not isinstance(model2, dict):\n",
    "        model2 = model2.state_dict()\n",
    "    \n",
    "    # Check if the models have the same keys\n",
    "    assert set(model1.keys()) == set(model2.keys()), \"Models have different keys\"\n",
    "    \n",
    "    # Compare each parameter\n",
    "    for key in model1.keys():\n",
    "        print(key)\n",
    "        assert torch.allclose(model1[key], model2[key], atol=1e-7), f\"Mismatch in parameter {key}\"\n",
    "    \n",
    "    print(\"Models are identical within the specified tolerance.\")\n",
    "\n",
    "# Usage example (you can run this in your Jupyter notebook):\n",
    "compare_pytorch_models('ae_4882.pt', 'ae_4882_converted.pt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}