GPT-K / model /utils.py
cwkuo
trim checkpoint model weights
febd802
import os
from pprint import pprint
from tqdm import tqdm
import torch
import torch.distributed as dist
import timm.models.hub as timm_hub
def drop_sequence_mask(N, S, device, p=0.1, training=True):
    """Build a random 0/1 keep-mask over ``S`` positions for each of ``N`` rows.

    Args:
        N: number of rows in the mask.
        S: number of positions per row.
        device: device on which the mask is created.
        p: probability that a position is dropped (set to 0) in training mode.
        training: when False nothing is dropped and an all-ones mask is returned.

    Returns:
        A ``(N, S)`` tensor of dtype ``torch.long`` where 1 marks a kept
        position. In training mode every row keeps at least one position.
    """
    if not training:
        # Allocate directly on the target device instead of building on the
        # CPU and copying it over.
        return torch.ones((N, S), dtype=torch.long, device=device)

    keep = torch.rand((N, S), device=device) > p
    # Force one randomly chosen position per row to be kept, so no row is
    # fully masked out. This guarantees the "at least one kept token"
    # invariant by construction; an explicit tensor-valued assert is not
    # used because it forces a host/device sync on every call (and asserts
    # are stripped under ``python -O``). The index tensors are drawn exactly
    # as before so seeded runs stay reproducible.
    keep[torch.arange(N), torch.randint(S, (N,))] = True
    return keep.long()
def cat_pad(x, cat_dim, pad_dim, pad_val=0):
    """Right-pad tensors along ``pad_dim`` to a common size, then concatenate them.

    Each tensor is extended at the end of ``pad_dim`` with ``pad_val`` up to
    the largest size any input has along that dimension. The results are
    then joined with ``torch.cat`` along ``cat_dim``. Apart from ``cat_dim``
    (and ``pad_dim``, which is equalized here), shapes must already match.

    Args:
        x: non-empty iterable of tensors (list, tuple, generator, ...).
            The caller's container is not modified.
        cat_dim: dimension to concatenate along.
        pad_dim: dimension to pad along.
        pad_val: fill value used for the padding.

    Returns:
        The concatenated tensor.
    """
    # Materialize once so generators work and the caller's list is untouched.
    tensors = list(x)
    l_max = max(t.shape[pad_dim] for t in tensors)

    padded = []
    for t in tensors:
        l_diff = l_max - t.shape[pad_dim]
        if l_diff > 0:
            shape = list(t.shape)
            shape[pad_dim] = l_diff
            # Padding matches the tensor's dtype/device so torch.cat accepts it.
            pad = torch.full(shape, pad_val, dtype=t.dtype, device=t.device)
            t = torch.cat([t, pad], dim=pad_dim)
        padded.append(t)
    return torch.cat(padded, dim=cat_dim)
def disabled_train(self, mode=True):
    """No-op stand-in for a module's ``train`` method.

    Assign this over ``model.train`` to pin the model in whatever train/eval
    mode it is currently in: the requested ``mode`` is ignored and the object
    is handed back unchanged, so chained calls keep working.
    """
    del mode  # intentionally ignored
    return self
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized.

    ``dist.is_initialized`` is only consulted once ``dist.is_available()``
    has returned True (short-circuit ``and``).
    """
    return dist.is_available() and dist.is_initialized()
def get_rank():
    """Return ``dist.get_rank()`` when distributed is set up, otherwise 0."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Return True when this process has rank 0 (always true outside distributed runs)."""
    rank = get_rank()
    return rank == 0
def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.

    Args:
        url: URL of the file to fetch.
        check_hash: forwarded to ``timm_hub.download_cached_file``.
        progress: forwarded to ``timm_hub.download_cached_file``.

    Returns:
        The local path ``<timm_hub.get_cache_dir()>/<basename of the URL path>``.
    """
    # Use urllib directly rather than relying on ``torch.hub`` happening to
    # re-export ``urlparse``.
    from urllib.parse import urlparse

    def get_cached_file_path():
        # a hack to sync the file path across processes: every rank derives
        # the path locally; this assumes it matches where timm stores the file
        parts = urlparse(url)
        filename = os.path.basename(parts.path)
        return os.path.join(timm_hub.get_cache_dir(), filename)

    try:
        if is_main_process():
            timm_hub.download_cached_file(url, check_hash, progress)
    finally:
        # Reach the barrier even if the main-rank download raised, so the
        # waiting ranks are released instead of blocking until the collective
        # times out. The original exception still propagates on the main rank.
        if is_dist_avail_and_initialized():
            dist.barrier()

    return get_cached_file_path()
def trim_ckpt(ckpt_input, ckpt_output, extra_keys=()):
    """Write a reduced checkpoint containing only selected parameters.

    Loads ``ckpt_input`` onto the CPU and keeps every entry of
    ``ckpt["module"]`` whose name contains any of the kept substrings. It
    drops the first two dot-separated components of each kept name, prints
    the kept names, and saves the resulting ``{name: value}`` dict to
    ``ckpt_output``.

    Args:
        ckpt_input: source accepted by ``torch.load``; the loaded object must
            contain a ``"module"`` mapping of parameter names to values.
        ckpt_output: destination accepted by ``torch.save``.
        extra_keys: additional name substrings to keep on top of the built-in
            ones. Any iterable of strings, or a single string.
    """
    if isinstance(extra_keys, str):
        # A bare string would otherwise be split into single characters.
        extra_keys = (extra_keys,)
    kept_keys = ('llm_proj', 'knwl', 'qformer', 'ln_vision', 'query_tokens') + tuple(extra_keys)

    # NOTE(review): torch.load unpickles the file; only use on trusted checkpoints.
    ckpt = torch.load(ckpt_input, map_location="cpu")
    # The two leading name components are stripped -- presumably wrapper
    # prefixes added during training; TODO confirm against the checkpoint layout.
    ckpt = {
        ".".join(n.split(".")[2:]): v
        for n, v in tqdm(ckpt["module"].items(), dynamic_ncols=True)
        if any(k in n for k in kept_keys)
    }
    print("Kept params:")
    pprint(list(ckpt.keys()))
    torch.save(ckpt, ckpt_output)