import os
from pprint import pprint
from urllib.parse import urlparse

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub
from tqdm import tqdm


def drop_sequence_mask(N, S, device, p=0.1, training=True):
    """Sample a binary keep-mask of shape (N, S), dropping each token with probability p."""
    if training:
        mask = torch.rand((N, S), device=device) > p
        # Force at least one kept token per row so downstream pooling never
        # operates on an empty sequence.
        mask[torch.arange(N, device=device), torch.randint(S, (N,), device=device)] = True
        mask = mask.long()
        assert torch.all(torch.sum(mask, dim=1) > 0)
    else:
        # No dropping at inference time: keep every token.
        mask = torch.ones((N, S), dtype=torch.long, device=device)
    return mask


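# A minimal usage sketch (illustrative, not part of the original module; the
# tensor names and shapes below are assumptions):
#
#   feats = torch.randn(4, 16, 256)                         # (N, S, D)
#   mask = drop_sequence_mask(4, 16, feats.device, p=0.2)   # (N, S), values in {0, 1}
#   feats = feats * mask.unsqueeze(-1)                      # zero out dropped tokens

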
def cat_pad(x, cat_dim, pad_dim, pad_val=0):
    """Pad every tensor in x to the same length along pad_dim, then concatenate along cat_dim."""
    l_max = max(xi.shape[pad_dim] for xi in x)
    for i, xi in enumerate(x):
        l_diff = l_max - xi.shape[pad_dim]
        if l_diff > 0:
            # Build a padding block matching xi everywhere except along pad_dim.
            shape = list(xi.shape)
            shape[pad_dim] = l_diff
            pad = torch.full(shape, pad_val, dtype=xi.dtype, device=xi.device)
            xi = torch.cat([xi, pad], dim=pad_dim)
        x[i] = xi
    return torch.cat(x, dim=cat_dim)


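# Illustrative call (assumed shapes): batch two variable-length sequences by
# padding along dim 1 and stacking along dim 0.
#
#   a = torch.randn(1, 5, 8)
#   b = torch.randn(1, 3, 8)
#   out = cat_pad([a, b], cat_dim=0, pad_dim=1)  # -> shape (2, 5, 8)

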
def disabled_train(self, mode=True):
    """Overwrite model.train with this function so that train/eval mode
    can no longer be toggled on a frozen module."""
    return self


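# A minimal freezing sketch (assumed usage; `visual_encoder` is a placeholder
# module name): freeze the parameters, switch to eval mode, then pin the mode
# by replacing train() with a bound disabled_train. Binding via __get__ keeps
# both model.train() and model.train(mode) call forms working.
#
#   for param in visual_encoder.parameters():
#       param.requires_grad = False
#   visual_encoder.eval()
#   visual_encoder.train = disabled_train.__get__(visual_encoder)

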
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


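# Typical usage (illustrative): restrict side effects such as logging or
# checkpoint writes to rank 0 so they happen once per job, not once per GPU.
#
#   if is_main_process():
#       print("only rank 0 logs this")

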
def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the download to finish.
    """

    def get_cached_file_path():
        # Derive the cache path deterministically from the URL so every
        # process resolves the same path without any communication.
        parts = urlparse(url)
        filename = os.path.basename(parts.path)
        return os.path.join(timm_hub.get_cache_dir(), filename)

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()


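# Illustrative call (the URL is a placeholder, not a real checkpoint):
#
#   path = download_cached_file("https://example.com/weights/model.pth")
#   state_dict = torch.load(path, map_location="cpu")

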
def trim_ckpt(ckpt_input, ckpt_output, extra_keys=()):
    """Keep only the listed parameter groups from a checkpoint and save the slimmed copy."""
    kept_keys = ('llm_proj', 'knwl', 'qformer', 'ln_vision', 'query_tokens') + extra_keys
    ckpt = torch.load(ckpt_input, map_location="cpu")
    # Parameters live under ckpt["module"] (a wrapped-engine checkpoint layout);
    # drop the first two components of each name to strip the wrapper prefix.
    ckpt = {
        ".".join(n.split(".")[2:]): v
        for n, v in tqdm(ckpt["module"].items(), dynamic_ncols=True)
        if any(k in n for k in kept_keys)
    }
    print("Kept params:")
    pprint(list(ckpt.keys()))
    torch.save(ckpt, ckpt_output)
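

# Illustrative invocation (both paths are placeholders):
#
#   trim_ckpt("checkpoints/full_ckpt.pt", "checkpoints/trimmed_ckpt.pt",
#             extra_keys=("some_extra_module",))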