|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Snippets and constants used a lot in image-text configs.""" |
|
|
|
import ml_collections |
|
|
|
|
|
|
|
inits = { |
|
|
|
|
|
'bert_base': ('base', 'gs://vit_models/lit/bert/uncased_L-12_H-768_A-12'), |
|
'bert_large': ('large', 'gs://vit_models/lit/bert/uncased_L-uncased_L-24_H-1024_A-16'), |
|
|
|
|
|
'B/32': ('B/32', 'gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz'), |
|
'B/16': ('B/16', 'gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz'), |
|
'L/16': ('L/16', 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz'), |
|
} |
|
|
|
|
|
|
|
def _square875(sz): |
|
return f'resize({int(sz/0.875)})|central_crop({sz})|value_range(-1,1)' |
|
|
|
|
|
def _aspect75(sz): |
|
return f'resize_small({int(sz/0.75)})|central_crop({sz})|value_range(-1,1)' |
|
|
|
|
|
def _drop_no_real_label(f): |
|
return len(f['real_label']) > 0 |
|
|
|
|
|
def _drop_no_imagenet(f): |
|
return len(f['labels_imagenet']) > 0 |
|
|
|
|
|
DISCLF_DATASET_OVERRIDES = { |
|
'imagenet2012': {'class_names': 'clip', 'split': 'validation'}, |
|
'imagenet2012_minival': { |
|
'dataset_name': 'imagenet2012', |
|
'class_names': 'clip', |
|
'split': 'train[99%:]', |
|
}, |
|
'imagenet2012_real': { |
|
'split': 'validation', |
|
'class_names': 'clip', |
|
'class_names_dataset_name': 'imagenet2012', |
|
'pp_img': lambda sz: ( |
|
_square875(sz) + '|pad_to_shape(inkey="real_label", outkey="label", shape=[10], pad_value=-1)|keep("label", "image")'), |
|
'pre_filter_fn': _drop_no_real_label, |
|
}, |
|
'imagenet_v2': {'class_names': 'clip'}, |
|
'imagenet_a': { |
|
'class_names': 'clip', |
|
'pp_img': lambda sz: _aspect75(sz) + '|map("i1k_i1ka")', |
|
}, |
|
'imagenet_r': { |
|
'class_names': 'clip', |
|
'pp_img': lambda sz: _square875(sz) + '|map("i1k_i1kr")', |
|
}, |
|
} |
|
|
|
|
|
def get_disclf(sz, *, pp_txt=None, dataset_names=('imagenet2012',), **kw): |
|
"""Returns config for discriminative_classifier of specified datasets.""" |
|
config = ml_collections.ConfigDict(dict( |
|
dataset_names=list(dataset_names), |
|
type='proj.image_text.discriminative_classifier', |
|
prefix='z/0shot/', |
|
pp_img=_square875(sz), |
|
dataset_overrides={}, |
|
cache_final=True, |
|
**kw, |
|
)) |
|
if pp_txt: |
|
config.pp_txt = pp_txt |
|
for name in dataset_names: |
|
if name in DISCLF_DATASET_OVERRIDES: |
|
config.dataset_overrides[name] = {**DISCLF_DATASET_OVERRIDES[name]} |
|
d = config.dataset_overrides[name] |
|
if 'pp_img' in d and callable(d['pp_img']): |
|
with d.ignore_type(): |
|
d['pp_img'] = d['pp_img'](sz) |
|
return config |
|
|
|
|
|
def get_coco( |
|
*, |
|
pp_img='resize(224)|value_range(-1, 1)', |
|
pp_txt='tokenize(max_len=16, inkey="texts", eos="sticky", pad_value=1)', |
|
prefix='z/retr/coco_', |
|
**kw): |
|
"""Returns config for mscoco retrieval zero-shot. |
|
|
|
Args: |
|
pp_img: Pre-processing string for "image" feature. |
|
pp_txt: Pre-processing string for texts (expected to tokenize "texts" to |
|
"labels"). |
|
prefix: Prefix to use for metrics. |
|
**kw: Other config settings, most notably log_{steps,percent,...}. |
|
|
|
Returns: |
|
`ConfigDict` that can be used as a retrieval evaluator configuration. |
|
""" |
|
return ml_collections.ConfigDict({ |
|
'type': 'proj.image_text.retrieval', |
|
'pp_txt': pp_txt, |
|
'pp_img': pp_img, |
|
'prefix': prefix, |
|
'dataset': 'coco_captions', |
|
'txt_name': ('captions', 'text'), |
|
**kw, |
|
}) |
|
|