# Copyright (c) 2024, EleutherAI # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Adapted from https://github.com/awaelchli/pytorch-lightning-snippets/blob/master/checkpoint/peek.py import code import os import re from argparse import ArgumentParser, Namespace from collections.abc import Mapping, Sequence from pathlib import Path import torch class COLORS: BLUE = "\033[94m" CYAN = "\033[96m" GREEN = "\033[92m" RED = "\033[31m" YELLOW = "\033[33m" MAGENTA = "\033[35m" WHITE = "\033[37m" UNDERLINE = "\033[4m" END = "\033[0m" PRIMITIVE_TYPES = (int, float, bool, str, type) def natural_sort(l): convert = lambda text: int(text) if text.isdigit() else text.lower() alphanum_key = lambda key: [convert(c) for c in re.split("([0-9]+)", str(key))] return sorted(l, key=alphanum_key) def sizeof_fmt(num, suffix="B"): for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: if abs(num) < 1024.0: return "%3.1f%s%s" % (num, unit, suffix) num /= 1024.0 return "%.1f%s%s" % (num, "Yi", suffix) def pretty_print(contents: dict): """Prints a nice summary of the top-level contents in a checkpoint dictionary.""" col_size = max(len(str(k)) for k in contents) for k, v in sorted(contents.items()): key_length = len(str(k)) line = " " * (col_size - key_length) line += f"{k}: {COLORS.BLUE}{type(v).__name__}{COLORS.END}" if isinstance(v, dict): pretty_print(v) elif isinstance(v, PRIMITIVE_TYPES): line += f" = " line += f"{COLORS.CYAN}{repr(v)}{COLORS.END}" elif isinstance(v, Sequence): line += ", " line += f"{COLORS.CYAN}len={len(v)}{COLORS.END}" elif isinstance(v, torch.Tensor): if v.ndimension() in (0, 1) and v.numel() == 1: line += f" = " line += f"{COLORS.CYAN}{v.item()}{COLORS.END}" else: line += ", " line += f"{COLORS.CYAN}shape={list(v.shape)}{COLORS.END}" line += ", " line += f"{COLORS.CYAN}dtype={v.dtype}{COLORS.END}" line += ( ", " + f"{COLORS.CYAN}size={sizeof_fmt(v.nelement() * v.element_size())}{COLORS.END}" ) print(line) def common_entries(*dcts): if not dcts: return for i in set(dcts[0]).intersection(*dcts[1:]): yield (i,) + tuple(d[i] for d in dcts) def pretty_print_double(contents1: dict, contents2: dict, args): """Prints a nice summary of the top-level contents in a checkpoint dictionary.""" col_size = max( max(len(str(k)) for k in contents1), max(len(str(k)) for k in contents2) ) common_keys = list(contents1.keys() & contents2.keys()) uncommon_keys_1 = [i for i in contents2.keys() if i not in common_keys] uncommon_keys_2 = [i for i in contents1.keys() if i not in common_keys] diffs_found = False if uncommon_keys_1 + uncommon_keys_2: diffs_found = True if uncommon_keys_1: print( f"{COLORS.RED}{len(uncommon_keys_1)} key(s) found in ckpt 1 that isn't present in ckpt 2:{COLORS.END} \n\t{COLORS.BLUE}{' '.join(uncommon_keys_1)}{COLORS.END}" ) if uncommon_keys_2: print( f"{COLORS.RED}{len(uncommon_keys_2)} key(s) found in ckpt 2 that isn't present in ckpt 1:{COLORS.END} \n\t{COLORS.BLUE}{' '.join(uncommon_keys_2)}{COLORS.END}" ) for k, v1, v2 in sorted(common_entries(contents1, contents2)): key_length = len(str(k)) line = " " * (col_size - key_length) if type(v1) != type(v2): print( f"{COLORS.RED}{k} is a different type between ckpt1 and ckpt2: ({type(v1).__name__} vs. {type(v2).__name__}){COLORS.END}" ) continue else: prefix = f"{k}: {COLORS.BLUE}{type(v1).__name__} | {type(v2).__name__}{COLORS.END}" if isinstance(v1, dict): pretty_print_double(v1, v2, args) elif isinstance(v1, PRIMITIVE_TYPES): if repr(v1) != repr(v2): c = COLORS.RED line += f" = " line += f"{c}{repr(v1)} | {repr(v2)}{COLORS.END}" else: c = COLORS.CYAN if not args.diff: line += f" = " line += f"{c}{repr(v1)} | {repr(v2)}{COLORS.END}" elif isinstance(v1, Sequence): if len(v1) != len(v2): c = COLORS.RED line += ", " line += f"{c}len={len(v1)} | len={len(v2)}{COLORS.END}" else: c = COLORS.CYAN if not args.diff: line += ", " line += f"{c}len={len(v1)} | len={len(v2)}{COLORS.END}" elif isinstance(v1, torch.Tensor): if v1.ndimension() != v2.ndimension(): c = COLORS.RED else: c = COLORS.CYAN if (v1.ndimension() in (0, 1) and v1.numel() == 1) and ( v2.ndimension() in (0, 1) and v2.numel() == 1 ): if not args.diff: line += f" = " line += f"{c}{v1.item()} | {c}{v2.item()}{COLORS.END}" else: if list(v1.shape) != list(v2.shape): c = COLORS.RED line += ", " line += f"{c}shape={list(v1.shape)} | shape={list(v2.shape)}{COLORS.END}" else: c = COLORS.CYAN if not args.diff: line += ", " line += f"{c}shape={list(v1.shape)} | shape={list(v2.shape)}{COLORS.END}" if v1.dtype != v2.dtype: c = COLORS.RED line += f"{c}dtype={v1.dtype} | dtype={v2.dtype}{COLORS.END}" else: c = COLORS.CYAN if not args.diff: line += ", " line += f"{c}dtype={v1.dtype} | dtype={v2.dtype}{COLORS.END}" if list(v1.shape) == list(v2.shape): if torch.allclose(v1, v2): if not args.diff: line += f", {COLORS.CYAN}VALUES EQUAL{COLORS.END}" else: line += f", {COLORS.RED}VALUES DIFFER{COLORS.END}" if line.replace(" ", "") != "": line = prefix + line print(line) diffs_found = True if args.diff and not diffs_found: pass else: if not args.diff: print("\n") return diffs_found def get_attribute(obj: object, name: str) -> object: if isinstance(obj, Mapping): return obj[name] if isinstance(obj, Namespace): return obj.name return getattr(object, name) def get_files(pth): if os.path.isdir(pth): files = list(Path(pth).glob("*.pt")) + list(Path(pth).glob("*.ckpt")) elif os.path.isfile(pth): assert pth.endswith(".pt") or pth.endswith(".ckpt") files = [Path(pth)] else: raise ValueError("Dir / File not found.") return natural_sort(files) def peek(args: Namespace): files = get_files(args.dir) for file in files: file = Path(file).absolute() print(f"{COLORS.GREEN}{file.name}:{COLORS.END}") ckpt = torch.load(file, map_location=torch.device("cpu")) selection = dict() attribute_names = args.attributes or list(ckpt.keys()) for name in attribute_names: parts = name.split("/") current = ckpt for part in parts: current = get_attribute(current, part) selection.update({name: current}) pretty_print(selection) print("\n") if args.interactive: code.interact( banner="Entering interactive shell. You can access the checkpoint contents through the local variable 'checkpoint'.", local={"checkpoint": ckpt, "torch": torch}, ) def get_shared_fnames(files_1, files_2): names_1 = [Path(i).name for i in files_1] names_1_parent = Path(files_1[0]).parent names_2 = [Path(i).name for i in files_2] names_2_parent = Path(files_2[0]).parent shared_names = list(set.intersection(*map(set, [names_1, names_2]))) return [names_1_parent / i for i in shared_names], [ names_2_parent / i for i in shared_names ] def get_selection(filename, args): ckpt = torch.load(filename, map_location=torch.device("cpu")) selection = dict() attribute_names = args.attributes or list(ckpt.keys()) for name in attribute_names: parts = name.split("/") current = ckpt for part in parts: current = get_attribute(current, part) selection.update({name: current}) return selection def compare(args: Namespace): dirs = [i.strip() for i in args.dir.split(",")] assert len(dirs) == 2, "Only works with 2 directories / files" files_1 = get_files(dirs[0]) files_2 = get_files(dirs[1]) files_1, files_2 = get_shared_fnames(files_1, files_2) for file1, file2 in zip(files_1, files_2): file1 = Path(file1).absolute() file2 = Path(file2).absolute() print(f"COMPARING {COLORS.GREEN}{file1.name} & {file2.name}:{COLORS.END}") selection_1 = get_selection(file1, args) selection_2 = get_selection(file2, args) diffs_found = pretty_print_double(selection_1, selection_2, args) if args.diff and diffs_found: print( f"{COLORS.RED}THE ABOVE DIFFS WERE FOUND IN {file1.name} & {file2.name} ^{COLORS.END}\n" ) if args.interactive: code.interact( banner="Entering interactive shell. You can access the checkpoint contents through the local variable 'selection_1' / 'selection_2'.\nPress Ctrl-D to exit.", local={ "selection_1": selection_1, "selection_2": selection_2, "torch": torch, }, ) def main(): parser = ArgumentParser() parser.add_argument( "dir", type=str, help="The checkpoint dir to inspect. Must be either: \ - a directory containing pickle binaries saved with 'torch.save' ending in .pt or .ckpt \ - a single path to a .pt or .ckpt file \ - two comma separated directories - in which case the script will *compare* the two checkpoints", ) parser.add_argument( "--attributes", nargs="*", help="Name of one or several attributes to query. To access an attribute within a nested structure, use '/' as separator.", default=None, ) parser.add_argument( "--interactive", "-i", action="store_true", help="Drops into interactive shell after printing the summary.", ) parser.add_argument( "--compare", "-c", action="store_true", help="If true, script will compare two directories separated by commas", ) parser.add_argument( "--diff", "-d", action="store_true", help="In compare mode, only print diffs" ) args = parser.parse_args() if args.compare: compare(args) else: peek(args) if __name__ == "__main__": main()