import os
from pprint import pprint
from tqdm import tqdm
import torch
import torch.distributed as dist
import timm.models.hub as timm_hub


def drop_sequence_mask(N, S, device, p=0.1, training=True):
    """Build an (N, S) 0/1 token-keep mask that drops each token with probability
    `p` during training, while guaranteeing at least one kept token per sequence;
    in eval mode every token is kept."""
    if training:
        mask = torch.rand((N, S), device=device)
        mask = mask > p
        mask[torch.arange(N), torch.randint(S, (N, ))] = True  # keep at least one token
        mask = mask.long()
        assert torch.all(torch.sum(mask, dim=1) > 0)
    else:
        mask = torch.ones((N, S), dtype=torch.long).to(device)
    
    return mask

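# Usage sketch (illustrative values only, not part of the original module):
# draw a drop mask for a batch of 4 sequences of length 6 and check its guarantees.
#
#   mask = drop_sequence_mask(4, 6, torch.device("cpu"), p=0.3, training=True)
#   assert mask.shape == (4, 6)
#   assert mask.sum(dim=1).min().item() >= 1  # every sequence keeps at least one token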

def cat_pad(x, cat_dim, pad_dim, pad_val=0):
    """Pad every tensor in the list `x` with `pad_val` along `pad_dim` up to the
    longest length, then concatenate the padded tensors along `cat_dim`."""
    l_max = max([xi.shape[pad_dim] for xi in x])
    for i, xi in enumerate(x):
        l_diff = l_max - xi.shape[pad_dim]
        if l_diff > 0:
            shape = list(xi.shape)
            shape[pad_dim] = l_diff
            p = torch.full(shape, pad_val, dtype=xi.dtype, device=xi.device)
            xi = torch.cat([xi, p], dim=pad_dim)
        x[i] = xi
    x = torch.cat(x, dim=cat_dim)

    return x

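# Usage sketch (shapes are illustrative, not from the original code): pad two
# feature tensors to a common sequence length, then stack them along the batch dim.
#
#   a = torch.randn(1, 5, 8)   # (batch, seq_len, dim)
#   b = torch.randn(1, 3, 8)
#   out = cat_pad([a, b], cat_dim=0, pad_dim=1)  # -> (2, 5, 8); b is zero-padded to length 5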

def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

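# Usage sketch (`frozen_encoder` is a placeholder name; pattern assumed from the
# docstring above): keep a frozen sub-module in eval mode even if a parent module
# later calls .train().
#
#   frozen_encoder = frozen_encoder.eval()
#   frozen_encoder.train = disabled_train  # subsequent train()/train(False) calls are no-ops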

def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()

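# Usage sketch (URL is a placeholder): download the weights once on the main
# process, then load them on every rank from the shared timm cache directory.
#
#   cached = download_cached_file("https://example.com/model.pth", check_hash=False)
#   state_dict = torch.load(cached, map_location="cpu")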

def trim_ckpt(ckpt_input, ckpt_output, extra_keys=()):
    """Load a checkpoint whose state dict sits under `ckpt["module"]`, keep only the
    parameters whose names contain one of `kept_keys`, strip the first two name
    components, and save the reduced state dict to `ckpt_output`."""
    kept_keys = ('llm_proj', 'knwl', 'qformer', 'ln_vision', 'query_tokens') + extra_keys

    ckpt = torch.load(ckpt_input, map_location="cpu")
    ckpt = {
        ".".join(n.split(".")[2:]): v
        for n, v in tqdm(ckpt["module"].items(), dynamic_ncols=True)
        if any([k in n for k in kept_keys])
    }
    print("Kept params:")
    pprint(list(ckpt.keys()))

    torch.save(ckpt, ckpt_output)
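

# Usage sketch (file paths and the extra key are placeholders, not from the
# original code): shrink a full training checkpoint down to the Q-Former and
# projection weights before releasing it.
#
#   trim_ckpt("checkpoints/full_ckpt.pt", "checkpoints/trimmed_ckpt.pt",
#             extra_keys=("logit_scale",))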