muzairkhattak
first commit for the demo
37b3db0
raw
history blame
2.95 kB
import os
import inspect
from collections import OrderedDict
from dataclasses import dataclass
import sys
sys.path.append("src")
from training.params import get_default_params
@dataclass
class Config:
    """Base collection of run settings; its class attributes are the defaults.

    None of the attributes below carries a type annotation, so ``@dataclass``
    registers no fields for them: they stay ordinary class attributes and the
    generated ``__init__`` accepts no arguments. That ``__init__`` still calls
    ``__post_init__``, which fills in the derived values.
    """

    # NOTE(review): the blank-line grouping below is by attribute name only.

    tokenizer_context_length = 256
    eval_freq = 1
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
    distributed_engine = "ddp"

    train_data = None
    val_data = None
    train_num_samples = None
    val_num_samples = None
    dataset_type = "auto"
    dataset_resampled = False
    csv_separator = "\t"
    csv_img_key = "filepath"
    csv_caption_key = "title"
    imagenet_val = "/datasets01/imagenet_full_size/061417/val"
    imagenet_v2 = None

    logs = "./logs/"
    log_local = False
    name = None  # overwritten with the class name in __post_init__

    workers = 8
    batch_size = 64
    epochs = 32
    # lr / beta1 / beta2 / eps: any left as None is filled in __post_init__
    # from get_default_params(model), provided that mapping has the key.
    lr = None
    beta1 = None
    beta2 = None
    eps = None
    wd = 0.2
    warmup = 2000  # 10000
    use_bn_sync = False
    skip_scheduler = False

    save_frequency = 10
    save_most_recent = True  # False
    zeroshot_frequency = 1
    val_frequency = 1
    resume = None
    precision = "amp"

    clip_model = "CLIP"
    model = "RN50"
    pretrained = ''
    pretrained_image = False
    lock_image = False
    lock_image_unlocked_groups = 0
    lock_image_freeze_bn_stats = False
    grad_checkpointing = False
    local_loss = False
    gather_with_grad = False
    force_quick_gelu = False
    torchscript = False
    trace = False

    dist_url = "env://"
    dist_backend = "nccl"
    report_to = ""
    wandb_notes = ''
    debug = False
    copy_codebase = False
    horovod = False
    ddp_static_graph = False
    no_set_device_rank = False

    seed = 0
    norm_gradient_clip = None
    train_data_upsampling_factors = None

    def __post_init__(self):
        """Derive ``name`` and ``output_dir`` and fill model-specific defaults.

        ``name`` becomes the concrete class name, ``output_dir`` is ``logs``
        joined with that name, and every key returned by
        ``get_default_params(self.model)`` whose current value on this object
        is ``None`` is replaced by the returned value.
        """
        self.name = type(self).__name__
        self.output_dir = os.path.join(self.logs, self.name)

        model_defaults = get_default_params(self.model)
        for key, default in model_defaults.items():
            # Only fill gaps; an explicitly set (non-None) value wins.
            if getattr(self, key) is None:
                setattr(self, key, default)
def parse_start_end(shards):
    """Return the two integers of a ``{start..end}`` range in a shard path.

    Only the final path component of ``shards`` is inspected. The text between
    the first ``{`` and the following ``}`` is split on ``..`` and both parts
    are converted with ``int``.

    Raises:
        IndexError: if the file name contains no ``{``.
        ValueError: if the braces do not hold exactly two ``..``-separated
            parts, or a part is not an integer literal.
    """
    filename = os.path.basename(shards)
    after_open_brace = filename.split("{")[1]
    brace_content = after_open_brace.split("}")[0]
    first, last = brace_content.split("..")
    return int(first), int(last)
def search_config(config_name):
    """Instantiate the object called ``config_name`` from a ``run_configs*.py`` module.

    Every ``run_configs*.py`` file in the directory containing this file is
    imported by module name, and each of its attribute names is recorded,
    except ``Config``, names starting with ``__`` and names starting with
    ``run_config``. When several modules expose the same name, the module that
    comes first in sorted file-name order wins.

    Args:
        config_name: Attribute name to look up; the attribute is called with
            no arguments to build the result.

    Returns:
        Whatever calling the matched attribute returns.

    Raises:
        KeyError: if no scanned module exposes ``config_name``.
    """
    import importlib

    # abspath first: dirname() of a bare relative __file__ is "", and
    # os.listdir("") raises FileNotFoundError.
    project_dir = os.path.dirname(os.path.abspath(__file__))
    skipped_prefixes = ("__", "run_config")

    all_configs = {}
    # os.listdir() order is arbitrary; sorting makes "first module wins"
    # reproducible across machines and runs.
    for code in sorted(os.listdir(project_dir)):
        if not (code.endswith(".py") and code.startswith("run_configs")):
            continue
        # NOTE(review): importing by bare module name relies on project_dir
        # being importable (e.g. on sys.path) — confirm for non-script use.
        module = importlib.import_module(code[:-3])
        for _config_name in dir(module):
            if _config_name == "Config" or _config_name.startswith(skipped_prefixes):
                continue
            # Keep the first module seen for a given name.
            all_configs.setdefault(_config_name, module)

    if config_name not in all_configs:
        raise KeyError(
            f"config {config_name!r} not found in any run_configs*.py "
            f"module under {project_dir!r}"
        )

    owner = all_configs[config_name]
    print(f"launching {config_name} from {owner.__file__}")
    return getattr(owner, config_name)()