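"""Experiment configuration for CLIP-style training runs.

Defines a base `Config` dataclass with open_clip-style training defaults, plus
helpers to parse webdataset shard ranges and to look up a named Config
subclass across the run_configs*.py modules in this directory.
"""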
import os
import sys

from dataclasses import dataclass

# Make the bundled training package importable before pulling from it.
sys.path.append("src")

from training.params import get_default_params


@dataclass
class Config:
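    # The attributes below are intentionally left un-annotated, so @dataclass
    # treats them as overridable class attributes rather than init fields;
    # __post_init__ still runs when a (sub)class is instantiated.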
    tokenizer_context_length = 256
    eval_freq = 1
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
    distributed_engine = "ddp"
    train_data = None
    val_data = None
    train_num_samples = None
    val_num_samples = None
    dataset_type = "auto"
    dataset_resampled = False
    csv_separator = "\t"
    csv_img_key = "filepath"
    csv_caption_key = "title"
    imagenet_val = "/datasets01/imagenet_full_size/061417/val"
    imagenet_v2 = None
    logs = "./logs/"
    log_local = False
    name = None
    workers = 8
    batch_size = 64
    epochs = 32
    lr = None
    beta1 = None
    beta2 = None
    eps = None
    wd = 0.2
    warmup = 2000  # 10000
    use_bn_sync = False
    skip_scheduler = False
    save_frequency = 10
    save_most_recent = True  # False
    zeroshot_frequency = 1
    val_frequency = 1
    resume = None
    precision = "amp"
    clip_model = "CLIP"
    model = "RN50"
    pretrained = ''
    pretrained_image = False
    lock_image = False
    lock_image_unlocked_groups = 0
    lock_image_freeze_bn_stats = False
    grad_checkpointing = False
    local_loss = False
    gather_with_grad = False
    force_quick_gelu = False
    torchscript = False
    trace = False
    dist_url = "env://"
    dist_backend = "nccl"
    report_to = ""
    wandb_notes = ''
    debug = False
    copy_codebase = False
    horovod = False
    ddp_static_graph = False
    no_set_device_rank = False
    seed = 0
    norm_gradient_clip = None
    train_data_upsampling_factors = None

    def __post_init__(self):
        # `args` alias keeps the body reading like the argparse-style
        # namespace used elsewhere in the training code.
        args = self
        args.name = self.__class__.__name__
        args.output_dir = os.path.join(args.logs, args.name)

        # Fill model-dependent optimizer defaults (e.g. lr, beta1, beta2, eps)
        # for any hyperparameter left as None above.
        for name, val in get_default_params(args.model).items():
            if getattr(args, name) is None:
                setattr(args, name, val)

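# Illustrative only: experiment configs subclass Config and override the
# attributes they care about (the class name doubles as the run name):
#
# class rn50_biomed_baseline(Config):  # hypothetical experiment name
#     batch_size = 128
#     epochs = 10
#     report_to = "wandb"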

def parse_start_end(shards):
    """Return the (start, end) integers from a brace-expanded shard spec."""
    start, end = os.path.basename(shards).split("{")[1].split("}")[0].split("..")
    return int(start), int(end)
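
# Usage sketch (the path is illustrative only): a webdataset shard spec like
# "/data/shards-{000000..000099}.tar" yields the inclusive numeric range:
#   parse_start_end("/data/shards-{000000..000099}.tar")  # -> (0, 99)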


def search_config(config_name):
    """Find and instantiate the Config subclass named `config_name`.

    Scans every run_configs*.py module next to this file; the first module
    defining the name wins.
    """
    import importlib
    project_dir = os.path.dirname(os.path.abspath(__file__))
    # Ensure sibling run_configs*.py files are importable by bare module name.
    if project_dir not in sys.path:
        sys.path.insert(0, project_dir)
    all_configs = {}
    for code in os.listdir(project_dir):
        if code.endswith(".py") and code.startswith("run_configs"):
            module = importlib.import_module(code[:-3])
            for _config_name in dir(module):
                # Skip the base Config class, dunders, and run_config* helpers.
                if _config_name == "Config" or _config_name.startswith("__") or _config_name.startswith("run_config"):
                    continue
                if _config_name not in all_configs:
                    all_configs[_config_name] = module
    if config_name not in all_configs:
        raise KeyError(f"config {config_name!r} not found in any run_configs*.py under {project_dir}")
    print(f"launching {config_name} from {all_configs[config_name].__file__}")
    return getattr(all_configs[config_name], config_name)()
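

if __name__ == "__main__":
    # Minimal launch sketch: pass the name of a Config subclass defined in one
    # of the run_configs*.py modules, e.g. `python run_configs.py my_experiment`
    # (both the file name and `my_experiment` here are illustrative).
    cfg = search_config(sys.argv[1])
    print(f"resolved config: {cfg.name} -> output_dir={cfg.output_dir}")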