# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Snippets and constants used a lot in image-text configs."""
import ml_collections
# pylint: disable=line-too-long
inits = {
    # Downloaded & extracted from original repo:
    # https://github.com/google-research/bert
    'bert_base': ('base', 'gs://vit_models/lit/bert/uncased_L-12_H-768_A-12'),
    'bert_large': ('large', 'gs://vit_models/lit/bert/uncased_L-24_H-1024_A-16'),
    # Recommended "How to train your ViT..." checkpoints from
    # https://github.com/google-research/vision_transformer#available-vit-models
    'B/32': ('B/32', 'gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz'),
    'B/16': ('B/16', 'gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz'),
    'L/16': ('L/16', 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz'),
}
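# Each entry maps a short name to a (variant, checkpoint) pair. A hypothetical
# consumer might unpack them like this (variable names below are illustrative,
# not taken from any specific config):
#
#   txt_variant, txt_ckpt = inits['bert_base']  # -> ('base', 'gs://...')
#   img_variant, img_ckpt = inits['B/16']       # -> ('B/16', 'gs://...npz')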
# pylint: enable=line-too-long


def _square875(sz):
  """Resize to sz/0.875, center-crop to sz, map values to [-1, 1]."""
  return f'resize({int(sz/0.875)})|central_crop({sz})|value_range(-1,1)'
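# For example, _square875(224) produces
# 'resize(256)|central_crop(224)|value_range(-1,1)', since 224 / 0.875 == 256.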


def _aspect75(sz):
  """Resize short side to sz/0.75, center-crop to sz, map values to [-1, 1]."""
  return f'resize_small({int(sz/0.75)})|central_crop({sz})|value_range(-1,1)'
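# For example, _aspect75(224) produces
# 'resize_small(298)|central_crop(224)|value_range(-1,1)',
# since int(224 / 0.75) == 298.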


def _drop_no_real_label(f):
  """Keeps only examples with at least one ReaL label."""
  return len(f['real_label']) > 0


def _drop_no_imagenet(f):
  """Keeps only examples with at least one ImageNet label."""
  return len(f['labels_imagenet']) > 0


DISCLF_DATASET_OVERRIDES = {
    'imagenet2012': {'class_names': 'clip', 'split': 'validation'},
    'imagenet2012_minival': {
        'dataset_name': 'imagenet2012',
        'class_names': 'clip',
        'split': 'train[99%:]',
    },
    'imagenet2012_real': {
        'split': 'validation',
        'class_names': 'clip',
        'class_names_dataset_name': 'imagenet2012',
        'pp_img': lambda sz: (
            _square875(sz) + '|pad_to_shape(inkey="real_label", outkey="label", shape=[10], pad_value=-1)|keep("label", "image")'),  # pylint: disable=line-too-long
        'pre_filter_fn': _drop_no_real_label,
    },
    'imagenet_v2': {'class_names': 'clip'},
    'imagenet_a': {
        'class_names': 'clip',
        'pp_img': lambda sz: _aspect75(sz) + '|map("i1k_i1ka")',
    },
    'imagenet_r': {
        'class_names': 'clip',
        'pp_img': lambda sz: _square875(sz) + '|map("i1k_i1kr")',
    },
}
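# For example, with sz=224 the 'imagenet_a' override's pp_img lambda (resolved
# inside get_disclf below) yields:
#   'resize_small(298)|central_crop(224)|value_range(-1,1)|map("i1k_i1ka")'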


def get_disclf(sz, *, pp_txt=None, dataset_names=('imagenet2012',), **kw):
  """Returns config for discriminative_classifier of specified datasets."""
  config = ml_collections.ConfigDict(dict(
      dataset_names=list(dataset_names),
      type='proj.image_text.discriminative_classifier',
      prefix='z/0shot/',
      pp_img=_square875(sz),
      dataset_overrides={},
      cache_final=True,
      **kw,
  ))
  if pp_txt:
    config.pp_txt = pp_txt
  for name in dataset_names:
    if name in DISCLF_DATASET_OVERRIDES:
      config.dataset_overrides[name] = {**DISCLF_DATASET_OVERRIDES[name]}
      d = config.dataset_overrides[name]
      if 'pp_img' in d and callable(d['pp_img']):
        # Resolve the size-dependent pp lambda into a concrete string.
        with d.ignore_type():
          d['pp_img'] = d['pp_img'](sz)
  return config
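# Hypothetical usage sketch inside a trainer config (the evaluator name and
# log_steps value are illustrative, not prescribed by this module):
#
#   config.evals = ml_collections.ConfigDict()
#   config.evals.disclf = get_disclf(
#       224, dataset_names=('imagenet2012', 'imagenet_v2'), log_steps=1000)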


def get_coco(
    *,
    pp_img='resize(224)|value_range(-1, 1)',
    pp_txt='tokenize(max_len=16, inkey="texts", eos="sticky", pad_value=1)',
    prefix='z/retr/coco_',
    **kw):
  """Returns config for mscoco retrieval zero-shot.

  Args:
    pp_img: Pre-processing string for the "image" feature.
    pp_txt: Pre-processing string for texts (expected to tokenize "texts" to
      "labels").
    prefix: Prefix to use for metrics.
    **kw: Other config settings, most notably log_{steps,percent,...}.

  Returns:
    `ConfigDict` that can be used as a retrieval evaluator configuration.
  """
  return ml_collections.ConfigDict({
      'type': 'proj.image_text.retrieval',
      'pp_txt': pp_txt,
      'pp_img': pp_img,
      'prefix': prefix,
      'dataset': 'coco_captions',
      'txt_name': ('captions', 'text'),
      **kw,
  })
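# Hypothetical usage sketch (evaluator name and logging cadence are
# illustrative):
#
#   config.evals.retrieval_coco = get_coco(log_percent=0.1)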