diff --git a/big_vision/__init__.py b/big_vision/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/__pycache__/__init__.cpython-310.pyc b/big_vision/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfa66cf82b641aa87e35835af5a6af722ca8682a Binary files /dev/null and b/big_vision/__pycache__/__init__.cpython-310.pyc differ diff --git a/big_vision/__pycache__/sharding.cpython-310.pyc b/big_vision/__pycache__/sharding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e86032655dc7aac2f218b0e167c95ef5db86c99a Binary files /dev/null and b/big_vision/__pycache__/sharding.cpython-310.pyc differ diff --git a/big_vision/__pycache__/utils.cpython-310.pyc b/big_vision/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..908193120372b743c0c98453e81115d65ec0b2d1 Binary files /dev/null and b/big_vision/__pycache__/utils.cpython-310.pyc differ diff --git a/big_vision/configs/__init__.py b/big_vision/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/configs/bit_i1k.py b/big_vision/configs/bit_i1k.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd53c318b108bab483923a95e8d3c0df42d709d --- /dev/null +++ b/big_vision/configs/bit_i1k.py @@ -0,0 +1,102 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Pre-training BiT on ILSVRC-2012 as in https://arxiv.org/abs/1912.11370 + +Run training of a BiT-ResNet-50x1 variant, which takes ~32min on v3-128: + +big_vision.train \ + --config big_vision/configs/bit_i1k.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ + --config.model.depth 50 --config.model.width 1 +""" + +# from big_vision.configs.common_fewshot import get_fewshot_lsr +import ml_collections as mlc + + +def get_config(runlocal=False): + """Config for training on ImageNet-1k.""" + config = mlc.ConfigDict() + + config.seed = 0 + config.total_epochs = 90 + config.num_classes = 1000 + config.loss = 'softmax_xent' + + config.input = dict() + config.input.data = dict( + name='imagenet2012', + split='train[:99%]', + ) + config.input.batch_size = 4096 + config.input.cache_raw = True # Needs up to 120GB of RAM! + config.input.shuffle_buffer_size = 250_000 # Per host. + + pp_common = '|onehot(1000, key="{lbl}", key_result="labels")' + pp_common += '|value_range(-1, 1)|keep("image", "labels")' + config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label') + pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + + # Model section + config.model_name = 'bit' + config.model = dict( + depth=50, # You can also pass e.g. [3, 5, 10, 2] + width=1.0, + ) + + # Optimizer section + config.optax_name = 'big_vision.momentum_hp' + config.grad_clip_norm = 1.0 + + # linear scaling rule. 
Don't forget to sweep if sweeping batch_size. + config.wd = (1e-4 / 256) * config.input.batch_size + config.lr = (0.1 / 256) * config.input.batch_size + config.schedule = dict(decay_type='cosine', warmup_steps=1000) + + # Eval section + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + pp_fn=pp_eval.format(lbl='label'), + loss_name=config.loss, + log_steps=1000, # Very fast O(seconds) so it's fine to run it often. + cache='final_data', + ) + config.evals = {} + config.evals.train = get_eval('train[:2%]') + config.evals.minival = get_eval('train[99%:]') + config.evals.val = get_eval('validation') + config.evals.v2 = get_eval('test', dataset='imagenet_v2') + config.evals.real = get_eval('validation', dataset='imagenet2012_real') + config.evals.real.pp_fn = pp_eval.format(lbl='real_label') + + # config.evals.fewshot = get_fewshot_lsr(runlocal=runlocal) + # config.evals.fewshot.log_steps = 1000 + + if runlocal: + config.input.batch_size = 32 + config.input.cache_raw = False + config.input.shuffle_buffer_size = 100 + + local_eval = config.evals.val + config.evals = {'val': local_eval} + config.evals.val.cache = 'none' + + return config \ No newline at end of file diff --git a/big_vision/configs/bit_i21k.py b/big_vision/configs/bit_i21k.py new file mode 100644 index 0000000000000000000000000000000000000000..c42342e9ab8ff513211954efab79dd4309fbe101 --- /dev/null +++ b/big_vision/configs/bit_i21k.py @@ -0,0 +1,85 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""A config for pre-training BiT on ImageNet-21k. + +This config relies on the Imagenet-21k tfds dataset, which is not yet +available publicly in TFDS. We intend to add the dataset to public TFDS soon, +and this config will then be runnable. +""" + +from big_vision.configs.common_fewshot import get_fewshot_lsr +import ml_collections as mlc + + +def get_config(): + """Config for training on imagenet-21k.""" + config = mlc.ConfigDict() + + config.seed = 0 + config.total_epochs = 90 + config.num_classes = 21843 + config.init_head_bias = -10.0 + config.loss = 'sigmoid_xent' + + config.input = dict() + config.input.data = dict( + name='imagenet21k', + split='full[51200:]', + ) + config.input.batch_size = 4096 + config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. + + pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")' + pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}') + pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"') + config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k + pp_eval = 'decode|resize_small(256)|central_crop(224)' + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + + # Model section + config.model_name = 'bit_paper' + config.model = dict(depth=50, width=1.0) + + # Optimizer section + config.optax_name = 'big_vision.momentum_hp' + config.grad_clip_norm = 1.0 + + # linear scaling rule. Don't forget to sweep if sweeping batch_size. + config.lr = (0.03 / 256) * config.input.batch_size + config.wd = (3e-5 / 256) * config.input.batch_size + config.schedule = dict(decay_type='cosine', warmup_steps=5000) + + # Evaluations on i21k itself. 
+ def eval_i21k(split): + return dict( + type='classification', + data={**config.input.data, 'split': split}, + pp_fn=pp_eval + pp_common_i21k, + loss_name=config.loss, + log_steps=1000, # Very fast O(seconds) so it's fine to run it often. + ) + config.evals = {} + config.evals.test = eval_i21k('full[:25_600]') + config.evals.val = eval_i21k('full[25_600:51_200]') + config.evals.train = eval_i21k('full[51_200:76_800]') + + # Few-shot evaluators + config.evals.fewshot = get_fewshot_lsr() + config.evals.fewshot.log_steps = 25_000 + + return config \ No newline at end of file diff --git a/big_vision/configs/common.py b/big_vision/configs/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c1628c3ccaa554eb5d2a39e2317fb06953542a6d --- /dev/null +++ b/big_vision/configs/common.py @@ -0,0 +1,188 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A few things commonly used across A LOT of config files.""" + +import string + +import ml_collections as mlc + + +def input_for_quicktest(config_input, quicktest): + if quicktest: + config_input.batch_size = 8 + config_input.shuffle_buffer_size = 10 + config_input.cache_raw = False + + +def parse_arg(arg, lazy=False, **spec): + """Makes ConfigDict's get_config single-string argument more usable. 
+ + Example use in the config file: + + import big_vision.configs.common as bvcc + def get_config(arg): + arg = bvcc.parse_arg(arg, + res=(224, int), + runlocal=False, + schedule='short', + ) + + # ... + + config.shuffle_buffer = 250_000 if not arg.runlocal else 50 + + Ways that values can be passed when launching: + + --config amazing.py:runlocal,schedule=long,res=128 + --config amazing.py:res=128 + --config amazing.py:runlocal # A boolean needs no value for "true". + --config amazing.py:runlocal=False # Explicit false boolean. + --config amazing.py:128 # The first spec entry may be passed unnamed alone. + + Uses strict bool conversion (converting 'True', 'true' to True, and 'False', + 'false', '' to False). + + Args: + arg: the string argument that's passed to get_config. + lazy: allow lazy parsing of arguments, which are not in spec. For these, + the type is auto-extracted in dependence of most complex possible type. + **spec: the name and default values of the expected options. + If the value is a tuple, the value's first element is the default value, + and the second element is a function called to convert the string. + Otherwise the type is automatically extracted from the default value. + + Returns: + ConfigDict object with extracted type-converted values. + """ + # Normalize arg and spec layout. + arg = arg or '' # Normalize None to empty string + spec = {k: get_type_with_default(v) for k, v in spec.items()} + + result = mlc.ConfigDict(type_safe=False) # For convenient dot-access only. + + # Expand convenience-cases for a single parameter without = sign. + if arg and ',' not in arg and '=' not in arg: + # (think :runlocal) If it's the name of sth in the spec (or there is no + # spec), it's that in bool. + if arg in spec or not spec: + arg = f'{arg}=True' + # Otherwise, it is the value for the first entry in the spec. + else: + arg = f'{list(spec.keys())[0]}={arg}' + # Yes, we rely on Py3.7 insertion order! 
+ + # Now, expand the `arg` string into a dict of keys and values: + raw_kv = {raw_arg.split('=')[0]: + raw_arg.split('=', 1)[-1] if '=' in raw_arg else 'True' + for raw_arg in arg.split(',') if raw_arg} + + # And go through the spec, using provided or default value for each: + for name, (default, type_fn) in spec.items(): + val = raw_kv.pop(name, None) + result[name] = type_fn(val) if val is not None else default + + if raw_kv: + if lazy: # Process args which are not in spec. + for k, v in raw_kv.items(): + result[k] = autotype(v) + else: + raise ValueError(f'Unhandled config args remain: {raw_kv}') + + return result + + +def get_type_with_default(v): + """Returns (v, string_to_v_type) with strict, case-insensitive bool parsing.""" + # For bool, do safe string conversion. + if isinstance(v, bool): + def strict_bool(x): + assert x.lower() in {'true', 'false', ''} + return x.lower() == 'true' + return (v, strict_bool) + # If already a (default, type) tuple, use that. + if isinstance(v, (tuple, list)): + assert len(v) == 2 and isinstance(v[1], type), ( + 'List or tuple types are currently not supported because we use `,` as' + ' dumb delimiter. Contributions (probably using ast) welcome. You can' + ' unblock by using a string with eval(s.replace(";", ",")) or similar') + return (v[0], v[1]) + # Otherwise, derive the type from the default value. + return (v, type(v)) + + +def autotype(x): + """Auto-converts string to bool/int/float if possible.""" + assert isinstance(x, str) + if x.lower() in {'true', 'false'}: + return x.lower() == 'true' # Returns as bool. + try: + return int(x) # Returns as int. + except ValueError: + try: + return float(x) # Returns as float. + except ValueError: + return x # Returns as str.
+ + +def pack_arg(**kw): + """Packs key-word args as a string to be parsed by `parse_arg()`.""" + for v in kw.values(): + assert ',' not in f'{v}', f"Can't use `,` in config_arg value: {v}" + return ','.join([f'{k}={v}' for k, v in kw.items()]) + + +def arg(**kw): + """Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg.""" + return {'config_arg': pack_arg(**kw), **kw} + + +def _get_field_ref(config_dict, field_name): + path = field_name.split('.') + for field in path[:-1]: + config_dict = getattr(config_dict, field) + return config_dict.get_ref(path[-1]) + + +def format_str(format_string, config): + """Format string with reference fields from config. + + This makes it easy to build preprocess strings that contain references to + fields that are edited afterwards. E.g.: + + ``` + config = mlc.ConfigDict() + config.res = (256, 256) + config.pp = bvcc.format_str('resize({res})', config) + ... + # if config.res is modified (e.g. via sweeps) it will propagate to pp field: + config.res = (512, 512) + assert config.pp == 'resize((512, 512))' + ``` + + Args: + format_string: string to format with references. + config: ConfigDict to get references to format the string. + + Returns: + A reference field which renders a string using references to config fields. + """ + output = '' + parts = string.Formatter().parse(format_string) + for (literal_text, field_name, format_spec, conversion) in parts: + assert not format_spec and not conversion + output += literal_text + if field_name: + output += _get_field_ref(config, field_name).to_str() + return output diff --git a/big_vision/configs/common_fewshot.py b/big_vision/configs/common_fewshot.py new file mode 100644 index 0000000000000000000000000000000000000000..21fdc586a3266c51400e22d6315cd9d130821f09 --- /dev/null +++ b/big_vision/configs/common_fewshot.py @@ -0,0 +1,56 @@ +# Copyright 2024 Big Vision Authors.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Most common few-shot eval configuration.""" + +import ml_collections as mlc + + +def get_fewshot_lsr(target_resolution=224, resize_resolution=256, + runlocal=False, **kw): + """Returns a standard-ish fewshot eval configuration.""" + kw.setdefault('representation_layer', 'pre_logits') + kw.setdefault('shots', (1, 5, 10, 25)) + kw.setdefault('l2_reg', 2.0 ** 10) + kw.setdefault('num_seeds', 3) + kw.setdefault('prefix', '') # No prefix as we already use a/ z/ and zz/ + + # Backward-compatible default: + if not any(f'log_{x}' in kw for x in ['steps', 'percent', 'examples', 'epochs']): # pylint: disable=line-too-long + kw['log_steps'] = 25_000 + + config = mlc.ConfigDict(kw) + config.type = 'fewshot_lsr' + config.datasets = { + 'caltech': ('caltech101', 'train', 'test'), # copybara:srtip + 'cars': ('cars196:2.1.0', 'train', 'test'), + 'cifar100': ('cifar100', 'train', 'test'), + 'dtd': ('dtd', 'train', 'test'), + # The first 65000 ImageNet samples have at least 30 shots per any class. + # Commented out by default because needs manual download. 
+ # 'imagenet': ('imagenet2012', 'train[:65000]', 'validation'), + 'pets': ('oxford_iiit_pet', 'train', 'test'), + 'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'), + } if not runlocal else { + 'pets': ('oxford_iiit_pet', 'train', 'test'), + } + config.pp_train = (f'decode|resize({resize_resolution})|' + f'central_crop({target_resolution})|' + f'value_range(-1,1)|keep("image", "label")') + config.pp_eval = (f'decode|resize({resize_resolution})|' + f'central_crop({target_resolution})|' + f'value_range(-1,1)|keep("image", "label")') + config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)] + + return config diff --git a/big_vision/configs/load_and_eval.py b/big_vision/configs/load_and_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..7e102b0f561f2cc6ec59439f831e9e289488b7b0 --- /dev/null +++ b/big_vision/configs/load_and_eval.py @@ -0,0 +1,143 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pytype: disable=not-writable,attribute-error +# pylint: disable=line-too-long,missing-function-docstring +r"""A config to load and eval key model using the core train.py. + +The runtime varies widely depending on the model, but each one should reproduce +the corresponding paper's numbers. 
+This configuration makes use of the "arg" to get_config to select which model +to run, so a few examples are given below: + +Run and evaluate a BiT-M ResNet-50x1 model that was transferred to i1k: + +big_vision.train \ + --config big_vision/configs/load_and_eval.py:name=bit_paper,batch_size=8 \ + --config.model_init M-imagenet2012 --config.model.width 1 --config.model.depth 50 + +Run and evaluate the recommended ViT-B/32 from "how to train your vit" paper: + +big_vision.train \ + --config big_vision/configs/load_and_eval.py:name=vit_i21k,batch_size=8 \ + --config.model.variant B/32 --config.model_init howto-i21k-B/32 +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.common_fewshot import get_fewshot_lsr + + +def eval_only(config, batch_size, spec_for_init): + """Set a few configs that turn trainer into (almost) eval-only.""" + config.total_steps = 0 + config.input = {} + config.input.batch_size = batch_size + config.input.data = dict(name='bv:dummy', spec=spec_for_init) + config.optax_name = 'identity' + config.lr = 0.0 + + config.mesh = [('data', -1)] + config.sharding_strategy = [('params/.*', 'fsdp(axis="data")')] + config.sharding_rules = [('act_batch', ('data',))] + + return config + + +def get_config(arg=''): + config = bvcc.parse_arg(arg, name='bit_paper', batch_size=4) + + # Make the config eval-only by setting some dummies. + eval_only(config, config.batch_size, spec_for_init=dict( + image=dict(shape=(224, 224, 3), dtype='float32'), + )) + + config.evals = dict(fewshot=get_fewshot_lsr()) + + # Just calls the function with the name given as `config`. + # Could also be a giant if-block if you're into that kind of thing. 
+ globals()[config.name](config) + return config + + +def bit_paper(config): + config.num_classes = 1000 + + config.model_name = 'bit_paper' + config.model_init = 'M-imagenet2012' # M = i21k, -imagenet2012 = fine-tuned + config.model = dict(width=1, depth=50) + + def get_eval(split, lbl, dataset='imagenet2012_real'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + loss_name='softmax_xent', + cache='none', # Only run once, on low-mem machine. + pp_fn=( + 'decode|resize(384)|value_range(-1, 1)' + f'|onehot(1000, key="{lbl}", key_result="labels")' + '|keep("image", "labels")' + ), + ) + config.evals.test = get_eval('validation', 'original_label') + config.evals.real = get_eval('validation', 'real_label') + config.evals.v2 = get_eval('test', 'label', 'imagenet_v2') + + +def vit_i1k(config): + config.num_classes = 1000 + + config.model_name = 'vit' + config.model_init = '' # Will be set in sweep. + config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d', + rep_size=True) + + config.evals.val = dict( + type='classification', + data=dict(name='imagenet2012', split='validation'), + pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")', + loss_name='softmax_xent', + cache='none', # Only run once, on low-mem machine. + ) + + +def mlp_mixer_i1k(config): + config.num_classes = 1000 + + config.model_name = 'mlp_mixer' + config.model_init = '' # Will be set in sweep. + config.model = dict(variant='L/16') + + config.evals.val = dict( + type='classification', + data=dict(name='imagenet2012', split='validation'), + pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")', + loss_name='softmax_xent', + cache='none', # Only run once, on low-mem machine. 
+ ) + + +def vit_i21k(config): + config.num_classes = 21843 + + config.model_name = 'vit' + config.model_init = '' # Will be set in sweep. + config.model = dict(variant='B/32', pool_type='tok') + + config.evals.val = dict( + type='classification', + data=dict(name='imagenet21k', split='full[:51200]'), + pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(21843)|keep("image", "labels")', + loss_name='sigmoid_xent', + cache='none', # Only run once, on low-mem machine. + ) diff --git a/big_vision/configs/mlp_mixer_i1k.py b/big_vision/configs/mlp_mixer_i1k.py new file mode 100644 index 0000000000000000000000000000000000000000..8afe9abfd31f4ecb4e53466ea3e2b2794e8af7e7 --- /dev/null +++ b/big_vision/configs/mlp_mixer_i1k.py @@ -0,0 +1,120 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""A config for training MLP-Mixer-B/16 model on ILSVRC-2012 ("ImageNet-1k"). + +Achieves 76.3% top-1 accuracy on the test split in 2h11m on TPU v3-128 +with 300 epochs. A shorter 60 epochs run is expected to get to 70.5% in 27m. 
+ +big_vision.train \ + --config big_vision/configs/mlp_mixer_i1k.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ +""" + +from big_vision.configs.common_fewshot import get_fewshot_lsr +import ml_collections as mlc + + +def get_config(mode=None): + """Config for training Mixer on i1k.""" + config = mlc.ConfigDict() + + config.seed = 0 + config.total_epochs = 300 + config.num_classes = 1000 + config.loss = 'sigmoid_xent' + config.init_head_bias = -6.9 + + config.input = dict() + config.input.data = dict( + name='imagenet2012', + split='train[:99%]', + ) + config.input.batch_size = 4096 + config.input.cache_raw = True # Needs up to 120GB of RAM! + config.input.shuffle_buffer_size = 250_000 + + config.input.pp = ( + 'decode_jpeg_and_inception_crop(224)' + '|flip_lr' + '|randaug(2,15)' + '|value_range(-1, 1)' + '|onehot(1000, key="label", key_result="labels")' + '|keep("image", "labels")' + ) + pp_eval = ( + 'decode' + '|resize_small(256)|central_crop(224)' + '|value_range(-1, 1)' + '|onehot(1000, key="{lbl}", key_result="labels")' + '|keep("image", "labels")' + ) + + # To continue using the near-defunct randaug op. + config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug'] + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + + config.prefetch_to_device = 2 + + # Model section + config.model_name = 'mlp_mixer' + config.model = dict() + config.model.variant = 'B/16' + config.model.stoch_depth = 0.1 + + config.mixup = dict(fold_in=None, p=0.5) + + # Optimizer section + config.optax_name = 'scale_by_adam' + config.grad_clip_norm = 1. 
+ + config.lr = 0.001 + config.wd = 1e-4 + config.schedule = dict( + decay_type='linear', + warmup_steps=10_000, + linear_end=1e-5, + ) + + # Eval section + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + pp_fn=pp_eval.format(lbl='label'), + loss_name=config.loss, + log_steps=2500, # Very fast O(seconds) so it's fine to run it often. + cache_final=mode != 'gpu8', + ) + config.evals = {} + config.evals.train = get_eval('train[:2%]') + config.evals.minival = get_eval('train[99%:]') + config.evals.val = get_eval('validation') + config.evals.v2 = get_eval('test', dataset='imagenet_v2') + config.evals.real = get_eval('validation', dataset='imagenet2012_real') + config.evals.real.pp_fn = pp_eval.format(lbl='real_label') + + config.fewshot = get_fewshot_lsr() + + if mode == 'gpu8': + config.total_epochs = 60 + config.input.batch_size = 512 + config.input.cache_raw = False + if mode == 'regression_test': + config.total_epochs = 60 + + return config diff --git a/big_vision/configs/proj/cappa/README.md b/big_vision/configs/proj/cappa/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d5aa147205f0f4139bf40b2ac5d9825120c22522 --- /dev/null +++ b/big_vision/configs/proj/cappa/README.md @@ -0,0 +1,37 @@ +# Image Captioners Are Scalable Vision Learners Too + +*by Michael Tschannen, Manoj Kumar, Andreas Steiner, Xiaohua Zhai, Neil Houlsby, Lucas Beyer* [[arxiv]](https://arxiv.org/abs/2306.07915) + +![CapPa Architecture](./cappa_architecture.png) + +This directory contains a config for training a CapPa model from scratch. +Note that most models in the paper were trained on a proprietary dataset +(WebLI), but similar results can be obtained by training on [LAION](https://laion.ai/). + +By default, this config trains on COCO captions as this data set is readily +available in [TFDS](https://www.tensorflow.org/datasets) without manual steps. 
+This is not meant to produce a meaningful model, but +provides a way for the user to run the config out of the box. Please update the +config with a TFDS-wrapped variant of your favorite image/text data set to +train capable models. + +After setting up `big_vision` as described in the [main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup), training can be launched as follows + +``` +python -m big_vision.trainers.proj.cappa.generative \ + --config big_vision/configs/proj/cappa/pretrain.py \ + --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'` +``` + +To run the Cap baseline (autoregressive captioning without parallel prediction), +set `config.model.masked_pred_prob = 0.0`. + +### Citation +``` +@inproceedings{tschannen2023image, + title={Image Captioners Are Scalable Vision Learners Too}, + author={Tschannen, Michael and Kumar, Manoj and Steiner, Andreas and Zhai, Xiaohua and Houlsby, Neil and Beyer, Lucas}, + booktitle={Neural Information Processing Systems (NeurIPS)}, + year={2023} +} +``` diff --git a/big_vision/configs/proj/cappa/cappa_architecture.png b/big_vision/configs/proj/cappa/cappa_architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..5cf85b74beb97abd4e97d71bd10ca1d8a83c45ab Binary files /dev/null and b/big_vision/configs/proj/cappa/cappa_architecture.png differ diff --git a/big_vision/configs/proj/cappa/pretrain.py b/big_vision/configs/proj/cappa/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..8b0df3c1df1c12c50f820ca05237d52f14ce20ad --- /dev/null +++ b/big_vision/configs/proj/cappa/pretrain.py @@ -0,0 +1,140 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Trains a CapPa model (https://arxiv.org/abs/2306.07915) on coco_captions. + +This config is for reference, we never ran a full training on a large +image/text data set on public infrastructure. + +big_vision.trainers.proj.cappa.generative \ + --config big_vision/configs/proj/cappa/pretrain.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` +""" + + +from big_vision.configs import common_fewshot +import big_vision.configs.common as bvcc +import ml_collections + + +def get_config(arg=None): + """Returns the base config.""" + config = bvcc.parse_arg(arg, + runlocal=False, + total_steps=366_500, + batch_size=8*1024, + warmup_steps=10_000, + ) + + config.evals = {} + config.input = {} + config.input.batch_size = config.batch_size if not config.runlocal else 8 + shuffle_buffer_size = 50_000 if not config.runlocal else 50 + + res = 224 + patch_size = 16 + max_text_tokens = 64 + + pp_image = (f'resize({res})|value_range(-1,1)') + + def tokenizer(inkey, outkey): + return (f'tokenize(max_len={max_text_tokens}, model="c4_en", ' + f'eos="sticky", inkey="{inkey}", outkey="{outkey}")') + + pp_coco = (f'decode|{pp_image}|' + 'coco_captions("captions")|choice(inkey="captions", outkey="text")|' + f'{tokenizer("text", "labels")}|keep("image", "labels")') + config.input.pp = pp_coco + + # NOTE: "coco_captions" is way too small a dataset to train on. It's simply + # used here to serve as a smoke test that the implementation works correctly. 
+ config.input.data = dict(name='coco_captions', split='train') # num_examples=82_783 + config.input.shuffle_buffer_size = shuffle_buffer_size + + config.evals.val_coco = { + 'type': 'proj.cappa.perplexity', + 'pred': 'perplexity', + 'log_steps': 1000, + 'data': dict(name='coco_captions', split='val'), # num_examples=5_000 + 'pp_fn': pp_coco, + } + + # Few-shot metrics + config.evals.fewshot = common_fewshot.get_fewshot_lsr( + target_resolution=res, resize_resolution=int(256 / 224 * res)) + config.evals.fewshot.type = 'fewshot_lsr' + config.evals.fewshot.log_steps = 5_000 if not config.runlocal else 5 + config.evals.fewshot.representation_layer = 'pre_logits' + config.evals.fewshot.pred = 'enc_rep' + config.evals.fewshot.pp_eval = config.evals.fewshot.pp_train + + # NOTE: Scoring of the entire imagenet validation set is rather slow: + # ~100 secs / 1k classes / host. + config.evals['imagenet/scoring'] = dict( + type='proj.cappa.scoring_classifier', + pred='score', + log_percent=0.1, + data=dict(name='imagenet2012', split='validation'), + pp_fn=f'decode|{pp_image}|keep("image", "label")', + pp_txt=tokenizer('label', 'labels'), + ) + + for e in config.evals.values(): + e.skip_first = True + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = None # 10_000 + + # Model section + config.model_name = 'proj.cappa.cappa' + config.model = ml_collections.ConfigDict() + config.model.num_layers = 12 + config.model.num_heads = 12 + config.model.mlp_dim = 3072 + config.model.emb_dim = 768 + config.model.vocab_size = 32_000 + config.model.patches = (patch_size, patch_size) + config.model.seq_len = max_text_tokens + config.model.posemb_type = 'learn' + + # Decoder + config.model.decoder_num_layers = 6 + # 0 values here mean to use the same value as for the encoder + config.model.decoder_num_heads = 0 + config.model.decoder_mlp_dim = 0 + config.model.decoder_emb_dim = 0 + config.model.dec_dropout_rate = 0.0 + config.model.masked_pred_prob = 0.75 + 
config.model.masking_ratio = 1.0 + config.model.decoder_bias = False + + config.optax_name = 'big_vision.scale_by_adafactor' + config.optax = dict(beta2_cap=0.999) + config.grad_clip_norm = 1.0 + config.label_smoothing = 0.0 + + schedule = dict(decay_type='cosine', + warmup_steps=config.warmup_steps + if not config.runlocal else 5) + + # Standard schedule + config.lr = 0.001 + config.wd = 0.0001 + config.schedule = schedule + + config.seed = 0 + + return config \ No newline at end of file diff --git a/big_vision/configs/proj/clippo/README.md b/big_vision/configs/proj/clippo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e33f891363053fad22b1ad4f84c6af6d316e72c5 --- /dev/null +++ b/big_vision/configs/proj/clippo/README.md @@ -0,0 +1,85 @@ +## Image-and-Language Understanding from Pixels Only + +*by Michael Tschannen, Basil Mustafa, Neil Houlsby* [[arxiv]](https://arxiv.org/abs/2212.08045) [[colab]](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/clippo/clippo_colab.ipynb) + +We provide pretrained CLIP with Pixels Only (CLIPPO) models and code to train such models on image/alt-text data sets. + +### Pretrained models + +Six ViT-B/16 models trained on a mix of [`YFCC-100M`](https://arxiv.org/abs/1503.01817) and [`C4`](https://arxiv.org/abs/1910.10683) (some initialized with an [ImageNet21k-pretrained checkpoint](https://github.com/google-research/vision_transformer#vision-transformer)\) are available. +These models were trained using the schedules and hyperparameters described in the paper. We use the full `YFCC-100M` data set, sampling one of the available `title/description/tag` annotations at random for each example. We drop non-descriptive annotations (e.g. descriptions consisting of digits only) following the filtering procedure outlined in the [LiT paper](https://arxiv.org/abs/2111.07991), Appendix E. The preprocessing for the `C4` data is as described in the paper. 
+ +The tables below show details about the checkpoints and their performance on Vision & Language benchmarks, and [`GLUE`](https://arxiv.org/abs/1804.07461). We also provide a [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/clippo/clippo_colab.ipynb) to load the models, compute embeddings, and perform zero-shot classification. + +##### Checkpoint details + +| model | training dataset | #param. | steps | checkpoint | +|:-----------------|:-------------------|:----------|:--------|:-----------| +| CLIPPO | YFCC-100M | 93M | 250k | `gs://big_vision/clippo/clippo_b16_yfcc100m.npz` | +| CLIPPO I21k init | YFCC-100M | 93M | 250k | `gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init.npz` | +| CLIPPO I21k init | YFCC-100M + 25%C4 | 93M | 333k | `gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init_25c4.npz` | +| CLIPPO I21k init | YFCC-100M + 50%C4 | 93M | 500k | `gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init_50c4.npz` | +| CLIPPO I21k init | YFCC-100M + 75%C4 | 93M | 500k | `gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init_75c4.npz` | +| CLIPPO | C4 | 93M | 250k | `gs://big_vision/clippo/clippo_b16_100c4.npz` | + +##### Vision \& Language results + +| model | training dataset | ImageNet 10-shot | ImageNet 0-shot | MS-COCO I→T | MS-COCO T→I | +|:-----------------|:-------------------|-----------:|----------:|--------:|--------:| +| CLIPPO | YFCC-100M | 38.2 | 43.4 | 34.7 | 19.7 | +| CLIPPO I21k init | YFCC-100M | 44.7 | 47.4 | 36.1 | 21.3 | +| CLIPPO I21k init | YFCC-100M + 25%C4 | 43.8 | 44.8 | 33.3 | 19.4 | +| CLIPPO I21k init | YFCC-100M + 50%C4 | 41.2 | 42.0 | 31.4 | 17.8 | +| CLIPPO I21k init | YFCC-100M + 75%C4 | 34.5 | 33.4 | 26.6 | 14.6 | + +##### GLUE results + +| model | training dataset | MNLI-M/MM | QQP | QNLI | SST-2 | COLA | STS-B | MRPC | RTE | avg | +|:-----------------|:-------------------|:------------|------:|-------:|--------:|-------:|--------:|-------:|------:|------:| +| 
CLIPPO | YFCC-100M | 71.3 / 71.5 | 79.1 | 67.9 | 85.7 | 0.0 | 14.0 | 83.4 | 54.9 | 58.6 | +| CLIPPO I21k init | YFCC-100M | 70.0 / 70.1 | 83.7 | 81.6 | 86.1 | 0.0 | 18.5 | 83.0 | 53.1 | 60.7 | +| CLIPPO I21k init | YFCC-100M + 25%C4 | 75.7 / 75.1 | 85.2 | 83.5 | 89.6 | 0.0 | 82.3 | 82.7 | 52.7 | 69.7 | +| CLIPPO I21k init | YFCC-100M + 50%C4 | 77.4 / 77.4 | 86.0 | 83.9 | 91.7 | 34.5 | 84.5 | 85.1 | 56.3 | 75.2 | +| CLIPPO I21k init | YFCC-100M + 75%C4 | 79.8 / 79.1 | 86.5 | 84.3 | 92.0 | 44.5 | 85.3 | 88.2 | 58.5 | 77.6 | +| CLIPPO | C4 | 79.9 / 80.2 | 86.7 | 85.2 | 93.3 | 50.9 | 84.7 | 86.3 | 58.5 | 78.4 | + +### Training your own models + +To train your own CLIPPO model, please follow the setup instructions in the [`big_vision` main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup). In the following, we provide the CLIPPO-specific commands required in addition to the setup, assuming you are using the Google Cloud TPU setup (potentially with adapted TPU configuration, see table below). If you are using GPUs, please set up your machine directly and only execute the `--command` portions of the commands below from the `big_vision` repository root. 
+ +The text rendering preprocessing function requires manual download of the Unifont .hex files from [Unifoundry](https://unifoundry.com/unifont/) (please follow the link for license): + +```bash +gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \ +--command "bash big_vision/pp/proj/clippo/download_unifont.sh" +``` + +Launch the training by running + +```bash +gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \ +--command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.trainers.proj.image_text.contrastive --config big_vision/configs/proj/clippo/train_clippo.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`" +``` + +*Important note:* The input pipeline relies on [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets) which does not provide automatic integration with large image/alt-text datasets out of the box. The above config therefore trains by default on MS-COCO Captions which can be automatically downloaded via TFDS, and additionally initializes the CLIPPO ViT backbone with weights pretrained on ImageNet21k. This setup is not meant to produce good accuracy, but to provide the user with a way to sanity-check their setup. If you want to train on a large data set such as [`LAION-400M`](https://arxiv.org/abs/2111.02114) or [`YFCC-100M`](https://arxiv.org/abs/1503.01817), please follow [these instructions](https://www.tensorflow.org/datasets/add_dataset) to wrap your data set using TFDS, and update the dataset in the config accordingly. Also note that the ImageNet1k evaluations require manual download of the data, see [these instructions](https://github.com/google-research/big_vision#preparing-tfds-data). To train with your own data set and with ImageNet1k-based evaluations, use `--config big_vision/configs/proj/clippo/train_clippo.py:test_with_coco=False,i1k_eval=True` in the command above. 
+ +##### Expected results + +| train dataset | batch size | #steps | TPU chips | ImageNet 0-shot | MS-COCO I→T | MS-COCO T→I | Config `arg` | +| :--- | ---: | ---: | ---: | :---: | :---: | :---: | :--- | +| *MS-COCO (sanity check)* | 4000 | 400 | 32 v3 | 4.2 | 12.6 | 8.6 | `i1k_eval=True` | +| LAION-400M | 8192 | 100k |128 v2 | 51.5 | 44.8 | 29.3 | `test_with_coco=False,i1k_eval=True` | +| LAION-400M | 10240\* | 100k | 128 v3 | 53.6 | 46.7 | 30.3 | `test_with_coco=False,i1k_eval=True` | + +\* The experiments in the paper use a batch size of 10240 which requires a memory-optimized ViT implementation to run on 128 TPU v2 chips or 128 TPU v3 chips (in which case the TPU memory capacity allows increasing the batch size beyond 10240). + +### Citation + +``` +@inproceedings{tschannen2023image, + title={Image-and-Language Understanding from Pixels Only}, + author={Tschannen, Michael and Mustafa, Basil and Houlsby, Neil}, + booktitle={Proc. IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2023} +} +``` diff --git a/big_vision/configs/proj/clippo/clippo_colab.ipynb b/big_vision/configs/proj/clippo/clippo_colab.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..dd540a85114007630b052ee024bd1cd62da45204 --- /dev/null +++ b/big_vision/configs/proj/clippo/clippo_colab.ipynb @@ -0,0 +1 @@ +{"cells":[{"cell_type":"markdown","metadata":{"id":"HRDyfbk2klCU"},"source":["# CLIPPO colab\n","\n","Paper: [Image-and-Language Understanding from Pixels Only](https://arxiv.org/abs/2212.08045)\n","\n","This colab shows how to\n","- load pretrained CLIP with Pixels Only (CLIPPO) models,\n","- use them to compute image and text embeddings,\n","- perform zero-shot image and text classification.\n","\n","Six ViT-B/16 models trained on a mix of [YFCC-100M](https://arxiv.org/abs/1503.01817) and [C4](https://arxiv.org/abs/1910.10683) (some initialized with an [ImageNet21k-pretrained 
checkpoint](https://github.com/google-research/vision_transformer#vision-transformer)\\) are available. Please refer to the [GitHub readme](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/clippo/README.md) for training code and details on the checkpoints.\n","\n","This colab is derived from the [colab](https://colab.research.google.com/github/google-research/vision_transformer/blob/main/lit.ipynb) accompanying the paper [LiT: Zero-Shot Transfer with Locked-Image Text Tuning](https://arxiv.org/abs/2111.07991)."]},{"cell_type":"markdown","metadata":{"id":"Vauyt3WMehFx"},"source":["## Set up the environment"]},{"cell_type":"code","execution_count":1,"metadata":{"id":"JHkfIOBXp-2J","colab":{"base_uri":"https://localhost:8080/"},"outputId":"2f09c28a-8c36-4379-c670-164a74cd95e8","executionInfo":{"status":"ok","timestamp":1678890204634,"user_tz":-60,"elapsed":2243,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Cloning into 'big_vision'...\n","remote: Enumerating objects: 210, done.\u001b[K\n","remote: Counting objects: 100% (210/210), done.\u001b[K\n","remote: Compressing objects: 100% (181/181), done.\u001b[K\n","remote: Total 210 (delta 40), reused 102 (delta 18), pack-reused 0\u001b[K\n","Receiving objects: 100% (210/210), 470.30 KiB | 10.94 MiB/s, done.\n","Resolving deltas: 100% (40/40), done.\n","Already up to date.\n"]}],"source":["# Clone the big_vision repository\n","!git clone --branch=main --depth=1 https://github.com/google-research/big_vision\n","!cd big_vision && git pull"]},{"cell_type":"code","source":["# Install the python dependencies\n","!pip install -qr big_vision/big_vision/requirements.txt"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NWLvz5yvH7La","executionInfo":{"status":"ok","timestamp":1678890254330,"user_tz":-60,"elapsed":49697,"user":{"displayName":"Michael 
Tschannen","userId":"02619997334944082183"}},"outputId":"6388a22b-e079-4103-cc03-127b38820d65"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":[" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m96.8/96.8 KB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.6/41.6 KB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m214.2/214.2 KB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.4/5.4 MB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m26.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.8/5.8 MB\u001b[0m \u001b[31m25.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m367.1/367.1 KB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 KB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.4/8.4 MB\u001b[0m \u001b[31m22.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m238.7/238.7 KB\u001b[0m \u001b[31m20.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m74.2/74.2 KB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.9/87.9 KB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m349.9/349.9 KB\u001b[0m \u001b[31m18.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.0/3.0 MB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m32.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.5/84.5 KB\u001b[0m \u001b[31m8.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Building wheel for flaxformer (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Building wheel for optax (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Building wheel for panopticapi (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Building wheel for ml-collections (setup.py) ... \u001b[?25l\u001b[?25hdone\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n","ipython 7.9.0 requires jedi>=0.10, which is not installed.\u001b[0m\u001b[31m\n","\u001b[0m"]}]},{"cell_type":"code","source":["# Download Unifont for text rendering\n","!wget https://unifoundry.com/pub/unifont/unifont-9.0.06/font-builds/unifont-9.0.06.hex.gz https://unifoundry.com/pub/unifont/unifont-9.0.06/font-builds/unifont_upper-9.0.06.hex.gz\n","!gunzip unifont-9.0.06.hex.gz unifont_upper-9.0.06.hex.gz\n","!mv unifont-9.0.06.hex unifont_upper-9.0.06.hex big_vision/big_vision/pp/proj/clippo/"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"VAkDMh3yQF1A","executionInfo":{"status":"ok","timestamp":1678890255858,"user_tz":-60,"elapsed":1535,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}},"outputId":"0e2851b4-95c6-4260-c3b7-58c2e745aa99"},"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["--2023-03-15 14:24:14-- https://unifoundry.com/pub/unifont/unifont-9.0.06/font-builds/unifont-9.0.06.hex.gz\n","Resolving unifoundry.com (unifoundry.com)... 107.180.4.157\n","Connecting to unifoundry.com (unifoundry.com)|107.180.4.157|:443... connected.\n","HTTP request sent, awaiting response... 200 OK\n","Length: 939547 (918K) [application/x-gzip]\n","Saving to: ‘unifont-9.0.06.hex.gz’\n","\n","unifont-9.0.06.hex. 100%[===================>] 917.53K 1.14MB/s in 0.8s \n","\n","2023-03-15 14:24:15 (1.14 MB/s) - ‘unifont-9.0.06.hex.gz’ saved [939547/939547]\n","\n","--2023-03-15 14:24:15-- https://unifoundry.com/pub/unifont/unifont-9.0.06/font-builds/unifont_upper-9.0.06.hex.gz\n","Reusing existing connection to unifoundry.com:443.\n","HTTP request sent, awaiting response... 
200 OK\n","Length: 112342 (110K) [application/x-gzip]\n","Saving to: ‘unifont_upper-9.0.06.hex.gz’\n","\n","unifont_upper-9.0.0 100%[===================>] 109.71K --.-KB/s in 0.001s \n","\n","2023-03-15 14:24:15 (165 MB/s) - ‘unifont_upper-9.0.06.hex.gz’ saved [112342/112342]\n","\n","FINISHED --2023-03-15 14:24:15--\n","Total wall clock time: 1.3s\n","Downloaded: 2 files, 1.0M in 0.8s (1.27 MB/s)\n"]}]},{"cell_type":"code","source":["%cd big_vision"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Ri6EirCxRCLW","executionInfo":{"status":"ok","timestamp":1678890255858,"user_tz":-60,"elapsed":4,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}},"outputId":"e262823d-d502-4159-a731-0c1649a85a85"},"execution_count":4,"outputs":[{"output_type":"stream","name":"stdout","text":["/content/big_vision\n"]}]},{"cell_type":"markdown","source":["## Load a checkpoint and initialize the model"],"metadata":{"id":"TskButpkoWBz"}},{"cell_type":"code","execution_count":5,"metadata":{"id":"4DS88TsHsli7","executionInfo":{"status":"ok","timestamp":1678890260586,"user_tz":-60,"elapsed":4730,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[],"source":["import jax\n","import jax.numpy as jnp\n","from matplotlib import pyplot as plt\n","import numpy as np\n","import pandas as pd\n","import tensorflow as tf\n","import tensorflow_datasets as tfds\n","import tqdm\n","import importlib"]},{"cell_type":"code","source":["# Select the checkpoint and download it\n","checkpoint_paths = {\n"," 'clippo_b16_yfcc100m': 'gs://big_vision/clippo/clippo_b16_yfcc100m.npz',\n"," 'clippo_b16_yfcc100m_i21k_init': 'gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init.npz',\n"," 'clippo_b16_yfcc100m_i21k_init_25c4': 'gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init_25c4.npz',\n"," 'clippo_b16_yfcc100m_i21k_init_50c4': 'gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init_50c4.npz',\n"," 'clippo_b16_yfcc100m_i21k_init_75c4': 
'gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init_75c4.npz',\n"," 'clippo_b16_100c4': 'gs://big_vision/clippo/clippo_b16_100c4.npz'\n","}\n","\n","checkpoint = 'clippo_b16_yfcc100m_i21k_init_25c4'\n","checkpoint_path = checkpoint_paths[checkpoint]\n","!gsutil cp $checkpoint_path ."],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"y7DA8gxrSe0B","executionInfo":{"status":"ok","timestamp":1678890270393,"user_tz":-60,"elapsed":9812,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}},"outputId":"f76d861d-f55c-4a6d-cc2e-a41c6eb81dde"},"execution_count":6,"outputs":[{"output_type":"stream","name":"stdout","text":["Copying gs://big_vision/clippo/clippo_b16_yfcc100m_i21k_init_25c4.npz...\n","/ [0 files][ 0.0 B/540.4 MiB] \r==> NOTE: You are downloading one or more large file(s), which would\n","run significantly faster if you enabled sliced object downloads. This\n","feature is enabled by default but requires that compiled crcmod be\n","installed (see \"gsutil help crcmod\").\n","\n","\\ [1 files][540.4 MiB/540.4 MiB] \n","Operation completed over 1 objects/540.4 MiB. 
\n"]}]},{"cell_type":"code","source":["from big_vision.configs.proj.clippo import train_clippo\n","from big_vision import utils\n","\n","# The models are trained for resolution 224\n","RES = 224\n","\n","# Load model module\n","config = train_clippo.get_config()\n","model_module = importlib.import_module(f'big_vision.models.{config.model_name}')\n","model = model_module.Model(**config.model)\n","\n","# Load model parameters\n","params = utils.load_checkpoint(None, checkpoint_path)['params']"],"metadata":{"id":"dC2pE8r8GF4g","executionInfo":{"status":"ok","timestamp":1678890287922,"user_tz":-60,"elapsed":17532,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["# Define and load preprocessing functions\n","from big_vision.pp import builder as pp_builder\n","for pp_modules in config.pp_modules:\n"," importlib.import_module(f'big_vision.pp.{pp_modules}')\n","\n","# Unifont renderer\n","def tokenizer(inkey='text', outkey='text'):\n"," return (f'render_unifont('\n"," f'inkey=\"{inkey}\", '\n"," f'outkey=\"{outkey}\", '\n"," f'image_size={RES}, '\n"," f'lower=True, '\n"," f'font_size=16, '\n"," f'text_brightness=0, '\n"," f'background_brightness=127)|'\n"," f'value_range(-1, 1, inkey=\"{outkey}\", outkey=\"{outkey}\")')\n","\n","pp_image_str = f'resize({RES})|value_range(-1,1)'\n","pp_text_str = tokenizer()\n","\n","pp_image_fn = pp_builder.get_preprocess_fn(pp_image_str)\n","pp_text_fn = pp_builder.get_preprocess_fn(pp_text_str)\n","\n","def preprocess_images(images):\n"," return [np.array(pp_image_fn({'image': img})['image']) for img in images]\n","\n","def preprocess_texts(texts):\n"," return [np.array(pp_text_fn({'text': text})['text']) for text in texts]"],"metadata":{"id":"pSXto6vnMEcB","executionInfo":{"status":"ok","timestamp":1678890294214,"user_tz":-60,"elapsed":6296,"user":{"displayName":"Michael 
Tschannen","userId":"02619997334944082183"}}},"execution_count":8,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"fzQg_w-Cn8a2"},"source":["## Use the model"]},{"cell_type":"markdown","source":["### Compute and compare image/text embeddings"],"metadata":{"id":"Y4LxH_faqXG8"}},{"cell_type":"code","execution_count":9,"metadata":{"id":"YBd39mOzmjPK","outputId":"b090c383-bda3-4c10-c816-0f810a06304e","colab":{"base_uri":"https://localhost:8080/","height":185,"referenced_widgets":["dccdc3dac09849d89bb4e33cd84f342e","5375b788c4434c70ab1174814f8f8c01","d2ed5d9e693649dbbce268b311751c42","3bb94f09d06b4f21a36d797603167e0e","daecd79e3f4d4f8ea5bdfaddfa979047","8fdbc2c453284869844620f0b3d48118","0b6bae24732c4f4e8c6533057db1c320","eeca15fa7845458aac9d0ac4564eb683","181c3157bb5c4d0682092d4362c71e38","0dae83dd44984685b6b547ea563f2ac5","499675f6f43449e1b457a943bb89f5e6","ad89503ecf0a4ce09e5d817b6aeca758","774eb9564b324fb398ec9f4257f67673","1b3a0f324a4f477eade4332e2f40b115","1986deffd06e4ed3bf882576f218ecb8","93da5eeb980f416da18ed48ebd3d025d","2a068826d3d94f9199c42bdc4381de7c","4a14f7780cae4c9ba6dd0b3588e713ea","ecfea453250a49caadb5162d3b02ac40","9879ce5353e04f738660e92d78a3abcd","e3bc1b53a7f94ed3bd409c309ca7fa3a","f0122c02070d44e9ab72d6189e65cfb5","1756f9ca57194e05a21e86be3390c160","353e41bf89064820b79b2d0c0ea7df31","bb8c6693952a4dba95e0d16305b6bd3f","7d00a1fa21354a51805a415db2410a7c","5c4eafb7fe01476e8306e62dda1d46f6","e53866d027f64bd4b299ecfdcf4f0a1c","7f7af62a6dd14868a0b35e520fec9d6a","e1e62a44c0b74e9781e93342838a744e","a84b52168d624ca1b442e989950b3b97","c5497a566d504a6d926c4a74d24fea75","c7cf618be98a4ef4954e83ca3b49874d","b157a99f67604979b1cd918b68fca6ac","6c00492230ed4b93af80dfc8e92b029b","b86e74d1b6ff42129cffb98dec1f3f7b","651f7d28f1a24bb3b7e48b3692a220e2","3498b3fee9b54ee0889f21bb893c78be","a781b7bb93a7451ba0c80d0f29986898","0d321d1d25fd40ef83074eba17b66080","4c0e7be5d0a94baaaadc5a6500dd9b6c","d34fe0909ab249058650c2cf8ceb35c4","7669118fa50947c5b9ae662fb6267f1d","6c6e27
d70026499abaee7600b26fde29","ff8597efa78344a09d83e4a6db42abfd","0725627ce3414d1c92b340d3d3f8f281","de52e4ca1d6f4e0092d6c68274d3af8f","78fa9b95d9c649d9be746399446eec77","cdf7a866e43a4bdaa0c809b78741f757","4a10c07fa0e1434984871acb73d5f46a","53893397fee74fd5b6a34dada619b6d1","3af6c8755e7c49f6aa01f45364a7e77e","7c4f5b790f1d4c00aa975141a8ede738","b5d6c41c9a9141dc94963faa2d4a676a","8e2d89124d264321ab04e8b7716c7d33","5c711a79f5bc4b08b3dd2015e948adad","d78bc7be8f3847569d3134d5db2b5c8f","55ca92dd00c74bff961c46f74aa95558","c0cd6e7644ce4769b1cf9a15ee503d0d","9b2109cef3404964b4b3e0e299369ca3","c7a7c829ed2e40d6b9aeba7cd2792361","d103db0b223d4b6a927321aaee909f1c","df02c258cb1f47a48a5bda4dbd1cd01a","1c2f0535a5494ae1b557dd8426fd69c1","1ea41bb66d8c41639cf8d9273fa8c6d3","07a7298243ea44e5aa95adb338c2b5fa","4e71c9ec77e04e0fb85f83f739f48949","0d2f9fb372734a2cb58c1027afd2ab64","8ffb631b45d7432193da5c756028a061","9e1179561ab6434eabea9eb1238dae43","d41263cf6e4845c29612dc1f8a113a64","e0cb8d6dfb7349d7b7cfe1698140b29b","5e4a96dea2154789be89dc50523167d1","fa2922365ee043908b03abf0cc8b3d46","6c7f68bc82da44c89df1fe69883af09b","c9c3213ba5184b9ca4f46601d372a291","825b35a2d7cf4606ad7056d60f59e907","4464bdf63c684c748e530de69bdb8da3","b4a8f48c427642b7a5e574794c67c450","353d24473b444907b18c2d8eb1853d1b","785f446240fb476a8aa5cb1b8fb38573","8a813d3ace9e4c9f8defa24c29ee96ac","cf0747e7cf8a4fb6af9a8a2d2ff4885f","4cf0aab25b1c43338c8761e1ff492631","3993f8faf3794a9dae2c55e13b8938e4","33a70dcd558646f09e73060acdb80a98","af99dc606a7c4da0931365780439f632","e4047f4ac6824cacb4a0f2f1c5da6b48"]},"executionInfo":{"status":"ok","timestamp":1678890417359,"user_tz":-60,"elapsed":123162,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to 
/root/tensorflow_datasets/imagenette/full-size-v2/1.0.0...\n"]},{"output_type":"display_data","data":{"text/plain":["Dl Completed...: 0 url [00:00, ? url/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"dccdc3dac09849d89bb4e33cd84f342e"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Dl Size...: 0 MiB [00:00, ? MiB/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"ad89503ecf0a4ce09e5d817b6aeca758"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Extraction completed...: 0 file [00:00, ? file/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"1756f9ca57194e05a21e86be3390c160"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Generating splits...: 0%| | 0/2 [00:00"],"image/png":"(internal link)(internal link)\n"},"metadata":{"needs_background":"light"}}],"source":["plt.figure(figsize=(30, 12))\n","plt.subplot(211)\n","plt.imshow(np.hstack(images) * .5 + .5)\n","plt.axis('off');\n","plt.subplot(212)\n","plt.imshow(np.hstack(texts) * .5 + .5)\n","plt.axis('off')"]},{"cell_type":"code","execution_count":13,"metadata":{"id":"GOe1Miplnf7Z","executionInfo":{"status":"ok","timestamp":1678890441504,"user_tz":-60,"elapsed":18794,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[],"source":["# Embed both texts and images with a single model call\n","# See \"zero-shot evaluation\" below for how to do this separately\n","zimg, ztxt, out = model.apply({\"params\": params}, jnp.stack(images), jnp.stack(texts))"]},{"cell_type":"code","execution_count":14,"metadata":{"colab":{"height":331,"base_uri":"https://localhost:8080/"},"id":"PtATpe2Bn5I0","outputId":"21e157fd-b0e2-4f10-916e-69bb73b43f77","executionInfo":{"status":"ok","timestamp":1678890441932,"user_tz":-60,"elapsed":431,"user":{"displayName":"Michael 
Tschannen","userId":"02619997334944082183"}}},"outputs":[{"output_type":"display_data","data":{"text/plain":["
"],"image/png":"\n"},"metadata":{"needs_background":"light"}}],"source":["# Visualize embedding similarities\n","plt.imshow(ztxt @ zimg.T)\n","plt.xlabel(\"Text\")\n","plt.ylabel(\"Image\")\n","tick_labels = [\"CD Player\", \"Truck\", \"Gas Station\", \"Chainsaw\", \"Colorful houses\"]\n","plt.xticks(ticks=range(5),labels=tick_labels, rotation=45)\n","plt.yticks(ticks=range(5),labels=tick_labels, rotation=0)\n","plt.show()"]},{"cell_type":"code","execution_count":15,"metadata":{"colab":{"height":206,"base_uri":"https://localhost:8080/"},"id":"bs9fr-1kCRQW","outputId":"affa9b69-6df8-4959-9a98-8a46b6c8f594","executionInfo":{"status":"ok","timestamp":1678890443073,"user_tz":-60,"elapsed":1144,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[{"output_type":"execute_result","data":{"text/plain":[""],"text/html":["\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
 01234
itap of a cd player98.83%1.07%0.05%0.05%0.00%
a photo of a truck0.15%96.14%3.46%0.25%0.00%
gas station0.01%0.25%99.73%0.01%0.00%
chainsaw0.12%4.22%1.34%94.32%0.00%
a bad photo of colorful houses0.30%5.81%4.67%0.09%89.12%
\n"]},"metadata":{},"execution_count":15}],"source":["probs = np.array(jax.nn.softmax(out['t'] * ztxt @ zimg.T, axis=1))\n","pd.DataFrame(probs, index=text_list).style.background_gradient('Greens', vmin=0, vmax=1).format('{:.2%}')"]},{"cell_type":"markdown","source":["### Compute and compare sentence embeddings\n","\n","Since we co-train some of our models on pairs of neighboring sentences from C4\n","with the same contrastive loss as used for image/alt-text pairs, we expect the\n","embeddings to capture sentence similarities well. Indeed, our GLUE evaluations\n","show that CLIPPO learns good sentence embeddings.\n","\n","Below we visualize the similarities between pairs of neighboring sentences from\n","different Wikipedia articles. CLIPPO models with C4 in the training mix assign\n","higher similarities to sentences from the same article than from different\n","articles."],"metadata":{"id":"Fs-TpneMq2uX"}},{"cell_type":"code","source":["# Selection of sentence pairs from Wikipedia (collected on 3/7/2023)\n","sentence_pairs = [\n"," # https://en.wikipedia.org/wiki/Google_JAX\n"," ['Google JAX is a machine learning framework for transforming numerical functions.',\n"," 'It is described as bringing together a modified version of autograd (automatic obtaining of the gradient function through differentiation of a function) and TensorFlow\\'s XLA (Accelerated Linear Algebra).'],\n"," # https://en.wikipedia.org/wiki/Matterhorn\n"," ['The Matterhorn (/ˈmætərhɔːrn/, German: [ˈmatɐˌhɔʁn]; Italian: Cervino, [tʃerˈviːno]; French: Cervin, [sɛʁvɛ̃]; Romansh: Mont(e) Cervin(u)) is a mountain of the Alps, straddling the main watershed and border between Switzerland and Italy.',\n"," 'It is a large, near-symmetric pyramidal peak in the extended Monte Rosa area of the Pennine Alps, whose summit is 4,478 metres (14,692 ft) high, making it one of the highest summits in the Alps and Europe.'],\n"," # https://en.wikipedia.org/wiki/Claude_Shannon\n"," ['Claude Elwood Shannon (April 30, 
1916 – February 24, 2001) was an American mathematician, electrical engineer, and cryptographer known as a \"father of information theory\".',\n"," 'As a 21-year-old master\\'s degree student at the Massachusetts Institute of Technology (MIT), he wrote his thesis demonstrating that electrical applications of Boolean algebra could construct any logical numerical relationship.'],\n"," # https://en.wikipedia.org/wiki/Z%C3%BCrich\n"," ['Zürich (/ˈzjʊərɪk, ˈzʊərɪk/ ZURE-ik, ZOOR-ik, German: [ˈtsyːrɪç] (listen); see below) is the largest city in Switzerland and the capital of the canton of Zürich.',\n"," 'It is located in north-central Switzerland, at the northwestern tip of Lake Zürich.'],\n","]\n","\n","# Preprocess sentences\n","sentence_lists = list(zip(*sentence_pairs))\n","first_sentences = preprocess_texts(sentence_lists[0])\n","second_sentences = preprocess_texts(sentence_lists[1])"],"metadata":{"id":"_dJrKpXu43f4","executionInfo":{"status":"ok","timestamp":1678890459181,"user_tz":-60,"elapsed":16111,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"execution_count":16,"outputs":[]},{"cell_type":"code","source":["plt.figure(figsize=(30, 12))\n","plt.subplot(211)\n","plt.imshow(np.hstack(first_sentences) * .5 + .5)\n","plt.axis('off');\n","plt.subplot(212)\n","plt.imshow(np.hstack(second_sentences) * .5 + .5)\n","plt.axis('off')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":700},"id":"8xKWKxzE7vhs","executionInfo":{"status":"ok","timestamp":1678890460516,"user_tz":-60,"elapsed":1350,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}},"outputId":"6c5adddf-5ad6-4394-b7e5-adeede5cb1cd"},"execution_count":17,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(-0.5, 895.5, 223.5, -0.5)"]},"metadata":{},"execution_count":17},{"output_type":"display_data","data":{"text/plain":["
"],"image/png":"(internal link)\n"},"metadata":{"needs_background":"light"}}]},{"cell_type":"code","source":["# Compute and compare embeddings\n","ztxt1, ztxt2, out = model.apply({\"params\": params}, jnp.stack(first_sentences), jnp.stack(second_sentences))"],"metadata":{"id":"QZOuehV8SoSc","executionInfo":{"status":"ok","timestamp":1678890469988,"user_tz":-60,"elapsed":9475,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"execution_count":18,"outputs":[]},{"cell_type":"code","source":["probs_txt = np.array(jax.nn.softmax(out['t'] * ztxt1 @ ztxt2.T, axis=1))\n","pd.DataFrame(probs_txt).style.background_gradient('Greens', vmin=0, vmax=1).format('{:.2%}')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":175},"id":"8VljxbXbTHf1","executionInfo":{"status":"ok","timestamp":1678890470393,"user_tz":-60,"elapsed":416,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}},"outputId":"9ab19f19-7b8e-41ea-c595-60794e6c992c"},"execution_count":19,"outputs":[{"output_type":"execute_result","data":{"text/plain":[""],"text/html":["\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
 0123
098.74%0.00%1.26%0.00%
10.00%98.67%0.00%1.33%
22.96%0.01%96.99%0.05%
30.00%0.00%0.00%100.00%
\n"]},"metadata":{},"execution_count":19}]},{"cell_type":"markdown","metadata":{"id":"_w1fANZM0FZY"},"source":["## `tfds` zero-shot evaluation\n","\n","We provide two zero-shot classification examples\n","- Image classification on CIFAR-100 (run cells after \"Dataset and preprocessing preparation CIFAR-100\")\n","- Text classification on SST-2 (run cells after \"Dataset and preprocessing preparation SST-2\")\n","\n","After that, run the cells after \"Zero-shot classification\" to perform the classification."]},{"cell_type":"markdown","source":["#### Dataset and preprocessing preparation CIFAR-100"],"metadata":{"id":"QCo1J3x0c2gG"}},{"cell_type":"code","execution_count":20,"metadata":{"id":"Ojt_6bny09Nt","outputId":"c03fabc7-6366-4dc2-e6a1-e81502300179","colab":{"base_uri":"https://localhost:8080/","height":321,"referenced_widgets":["664319014bea49ddb0682ca183f37d38","7b80afb39bd34ffbb6ec71999f043a0d","7f8cfe19a0774252b9a2e0acf8712b95","24291661ee6a4e1ea39fe36a2461ad03","456418b864a04211a89f7ded9893182a","4e035e47a511466aaee1f410141dc219","1fe28e969c15490b8d5deddf404bbfc8","9b493b2b10d749af94288e1da31082de","7910356c0b024c48995d1215cfc40d92","57bca4c7a4e048118328e261f2816485","1981f2194ce14b9d858b4e66ad749a74","7bf24c0b526849148f79ddbe489fecbe","d96f75be285b4ece84a79b80525e07ff","eae9d0ad46014301a4104063786c6513","33daed476d1e48aea2c175d7d165007c","25ac8ad9428446638d2e831154b67e9e","096d4f03e30642ecb75b856d32192b77","ff4aa1191d5546e29cfc0b2099deaf32","5bd7bfbeae434bd0ad6d1a004866e8bf","49a2dd8bfc2440dc87edd2a379b3b9df","117b9b987ba646239156062a37e44881","7bcffca94bd349ce8e36b41eed824e9f","14202fa017664237abee369e0b787b2a","e77432e0ad7a48cda8a84a59c7c4fcd9","81cb50b67148443a97315c6004b9bd9e","072d46906efe4a07ac2a04bf532ca56e","d4a3cb590ed2442682da8cc156ed9ad0","df162d04da2a41209af80c51808caf2b","c1ed3a33cca94d0f8e91ea8cafbeffca","621ed7453b7a48dfa3fb5d5c0da0bec3","8c3a583468f54969a8203d83a3b511ad","259cda89a60d4ae28a23d222bf17a4f4","5d02bdf68b19451dad6399a596c700
f9","31d1942f4b104c048fe34b06fd5612e4","e22728de4ffb4c2ca07c8a6ac6da10de","282d4ebc49794e869d299726b52bcf21","f74afaef34b14da7a6e5f0c2f569f68d","70dfcbb20077452a95f7b585fd182be6","99cdd80186aa49b891f7a5d4de962f18","5f8b6e4a048a4fd5b944e4d7853b9a71","1418a4fa26a64e6bbdcca38fb682e39d","c1e9314ff7904b8dbc50a8708aab1eb1","0bf192b5c7034c2191c5f92c7df1c9dd","2f71d1430e9c45799e8949ec0158a222","07689a70f74446bcba36184e47560bd1","65f301b7b02d4c4dba0f233f9e9e425a","e4412596145f4c01af01d51402ed45d6","179510514bfd4e3ea9c46889b13da181","3021ffddf75e42759b0749f82cb17eb4","5c101b460ebd4a27a01a2ada97f777a7","807aa95840eb41b4859a300500d8beec","46d2eb5e2620481f9495b472a8c84775","d7cbb87d3c9245a8b339ecbcea83e8c2","6d3ac9db950848f3baae1fa93b8e68ea","dee93ec00db64a9d8f3f1edd65a90ac8","4dc9469114be4125b4d344d644f5aff9","274ee0dd728f4a33be912fcd1786f058","5708d1ee150f46029173593d2ec4c21f","33749fbfb6d9403f962568ba42c1e8ed","a586b099410a4223b0c62b21e94b6c10","22b60f7e0f234a7e83c01a0b7f5ad9b1","a39974f014e24fc5ac43e72b34342e9d","f10a575c66664f618e179e4df8dc6fa3","5c3235bc34964d51b54fa096fc92c1f9","f43832732153414a951c1810c899c531","84bce2199a5e4d1c8526b1a6d1059cf5","40ad9dad27c747f5afc95f69b5bf42fa","3978ee7fa0f444bebf6da70d16c26095","f01d3d3fc2c84fb2b2bd14560cdd7fa3","8f28e75e20b24966b582137d4dfa31ee","4aaf504fe9fc4d618d4174ae9844b347","44a3a4df887a4bebbabd04033546d7c9","7fd53f280d914a8395c6c1c3cdc677ce","b83e550fff71444480e85f9155edb0d0","3742d1e304584823a975a123bfe2ccb4","135831c870fd49aeaf21294f8256f1f9","2d8ab80c4b5742a2b8ac5154d58d52d3","4be15f66f88c4401a769ce297c325705","e09cae9d4eb44f84892d1e4f705b13f8","ad869a7a2c3948f6b2485510529df75b","8e343e5259294118806a3144ef125d44","07ed03f79ec643028f7ed7c741d284a9","bb2436bfe4b047979ad69dbf58b12836","9e683e88a8e042889c004fd9e34e5ba9","d247a7b15d814991933dcf6705f5904b","b8130ff7e9f64911b7f11dd947d9eea4","cca976a0f62f41c9a6161e60be7f757c","47fb22e77cd44a75b752f4b8ee16c026"]},"executionInfo":{"status":"ok","timestamp":1678890542909,"user_tz":-6
0,"elapsed":72522,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/cifar100/3.0.2...\n"]},{"output_type":"display_data","data":{"text/plain":["Dl Completed...: 0 url [00:00, ? url/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"664319014bea49ddb0682ca183f37d38"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Dl Size...: 0 MiB [00:00, ? MiB/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"7bf24c0b526849148f79ddbe489fecbe"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Extraction completed...: 0 file [00:00, ? file/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"14202fa017664237abee369e0b787b2a"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Generating splits...: 0%| | 0/2 [00:00"],"image/png":"\n"},"metadata":{"needs_background":"light"}}],"source":["imgs = next(iter(ds_test.batch(4)))['image']\n","\n","plt.figure(figsize=(15, 4))\n","plt.imshow(np.hstack(imgs) * .5 + .5)\n","plt.axis('off');"]},{"cell_type":"code","source":["#@markdown *Prompt engineering*\\\n","#@markdown The [official CLIP Colab](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb)\n","#@markdown lists two sets of prompts: the 80 prompts mentioned in the [CLIP paper](https://arxiv.org/abs/2103.00020)\n","#@markdown as well as a shortlist of 7 prompts. 
That will be used by default for speed,\n","#@markdown but using 80 prompts boosts performance a bit.\n","use_80_prompts = False #@param {\"type\": \"boolean\"}\n","if use_80_prompts:\n"," PROMPTS = [\n"," 'a bad photo of a {}.',\n"," 'a photo of many {}.',\n"," 'a sculpture of a {}.',\n"," 'a photo of the hard to see {}.',\n"," 'a low resolution photo of the {}.',\n"," 'a rendering of a {}.',\n"," 'graffiti of a {}.',\n"," 'a bad photo of the {}.',\n"," 'a cropped photo of the {}.',\n"," 'a tattoo of a {}.',\n"," 'the embroidered {}.',\n"," 'a photo of a hard to see {}.',\n"," 'a bright photo of a {}.',\n"," 'a photo of a clean {}.',\n"," 'a photo of a dirty {}.',\n"," 'a dark photo of the {}.',\n"," 'a drawing of a {}.',\n"," 'a photo of my {}.',\n"," 'the plastic {}.',\n"," 'a photo of the cool {}.',\n"," 'a close-up photo of a {}.',\n"," 'a black and white photo of the {}.',\n"," 'a painting of the {}.',\n"," 'a painting of a {}.',\n"," 'a pixelated photo of the {}.',\n"," 'a sculpture of the {}.',\n"," 'a bright photo of the {}.',\n"," 'a cropped photo of a {}.',\n"," 'a plastic {}.',\n"," 'a photo of the dirty {}.',\n"," 'a jpeg corrupted photo of a {}.',\n"," 'a blurry photo of the {}.',\n"," 'a photo of the {}.',\n"," 'a good photo of the {}.',\n"," 'a rendering of the {}.',\n"," 'a {} in a video game.',\n"," 'a photo of one {}.',\n"," 'a doodle of a {}.',\n"," 'a close-up photo of the {}.',\n"," 'a photo of a {}.',\n"," 'the origami {}.',\n"," 'the {} in a video game.',\n"," 'a sketch of a {}.',\n"," 'a doodle of the {}.',\n"," 'a origami {}.',\n"," 'a low resolution photo of a {}.',\n"," 'the toy {}.',\n"," 'a rendition of the {}.',\n"," 'a photo of the clean {}.',\n"," 'a photo of a large {}.',\n"," 'a rendition of a {}.',\n"," 'a photo of a nice {}.',\n"," 'a photo of a weird {}.',\n"," 'a blurry photo of a {}.',\n"," 'a cartoon {}.',\n"," 'art of a {}.',\n"," 'a sketch of the {}.',\n"," 'a embroidered {}.',\n"," 'a pixelated photo of a {}.',\n"," 'itap 
of the {}.',\n"," 'a jpeg corrupted photo of the {}.',\n"," 'a good photo of a {}.',\n"," 'a plushie {}.',\n"," 'a photo of the nice {}.',\n"," 'a photo of the small {}.',\n"," 'a photo of the weird {}.',\n"," 'the cartoon {}.',\n"," 'art of the {}.',\n"," 'a drawing of the {}.',\n"," 'a photo of the large {}.',\n"," 'a black and white photo of a {}.',\n"," 'the plushie {}.',\n"," 'a dark photo of a {}.',\n"," 'itap of a {}.',\n"," 'graffiti of the {}.',\n"," 'a toy {}.',\n"," 'itap of my {}.',\n"," 'a photo of a cool {}.',\n"," 'a photo of a small {}.',\n"," 'a tattoo of the {}.',\n"," ]\n","else:\n"," PROMPTS = [\n"," 'itap of a {}.',\n"," 'a bad photo of the {}.',\n"," 'a origami {}.',\n"," 'a photo of the large {}.',\n"," 'a {} in a video game.',\n"," 'art of the {}.',\n"," 'a photo of the small {}.',\n"," '{}',\n"," ]"],"metadata":{"cellView":"form","id":"rB9324hHgWvh","executionInfo":{"status":"ok","timestamp":1678890543807,"user_tz":-60,"elapsed":5,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"execution_count":22,"outputs":[]},{"cell_type":"code","execution_count":23,"metadata":{"id":"8BQtRYhd0F94","outputId":"1878af0f-14f4-4389-97d4-5870f51fe977","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1678890543807,"user_tz":-60,"elapsed":5,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[{"output_type":"execute_result","data":{"text/plain":["800"]},"metadata":{},"execution_count":23}],"source":["# Construct prompts for all templates\n","class_prompts = [\n"," prompt.format(classname)\n"," for classname in classnames\n"," for prompt in 
PROMPTS\n","]\n","len(class_prompts)"]},{"cell_type":"code","execution_count":24,"metadata":{"id":"Ye2MkSEVBcuF","outputId":"1930d7c2-4f2b-49ef-c498-ac2b9b209680","colab":{"base_uri":"https://localhost:8080/","height":241},"executionInfo":{"status":"ok","timestamp":1678890545724,"user_tz":-60,"elapsed":1920,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[{"output_type":"display_data","data":{"text/plain":["
"],"image/png":"\n"},"metadata":{"needs_background":"light"}}],"source":["# Render prompt images for all classes and visualize them\n","text_images = preprocess_texts(class_prompts [:7])\n","\n","plt.figure(figsize=(30, 6))\n","plt.imshow(np.hstack(text_images) * .5 + .5)\n","plt.axis('off');"]},{"cell_type":"markdown","source":["#### Dataset and preprocessing preparation SST-2\n","\n","The goal of SST-2 is sentiment classification (`positive`, `negative`) of movie\n","reviews. To this end we embed the reviews and a few prompts for positive/negative sentiment using CLIPPO and compare the embeddings to perform zero-shot\n","classification."],"metadata":{"id":"n4vkj4zejNR4"}},{"cell_type":"code","execution_count":30,"metadata":{"outputId":"64b6f081-453e-40fd-990e-037b212d5c99","colab":{"base_uri":"https://localhost:8080/","height":165,"referenced_widgets":["bb1c1dafddfc4ab1bf0ea96916d088dc","412d72f0cab84cdc9ccb202015b8f09b","321f978c7d154fa9a793612bd45e34fb","c98c2b0e046e4d11bdb41466c8b9bbed","3d795eb785364e04b3458db07a689231","47defbe6eded463aa913f6c94eba650c","90a7c3c7f67c4b9dabae3124d4f4807a","f9c8321259e64b5da15a5780e48387cc","7d9c02eab6b14e64b0d54535a51d3f26","7261da104bf4488489651c59b57c810e","292cacad99804debab40eaff1f7241f0","16c4f6ca9f104713aaed23017b953711","c9382c9cc43144ccbc94ff9c126c75cb","5d35963be08f4f5e855c86ad7360b061","9ba65fbf4ff54df8832d0790b326bd74","26e7683fe07e46f2b6e42ab2cf525d1c","2e2f1aca25d64fc6a5eb6c5984ea1884","4b045a43035c4f4aa8f72acfc44361a0","999d70d6ea1c43f0a1b7005999dc2d9d","4e45f77fa3fe477eaca52d987b4d0217","86979c617e50477db737303ad9042a7c","23c0c48a8eb04a6aa12a39e102a71322","60999d5a3045449b87f83734c86adcce","7e297ca7eaf64ce1b3bdf8685acbf74f","e9b3045f23b24882b847fcf7c821f3bd","a6322a45f87d43dd8cb4edc4f366742e","56a2a0d2d23243cf8f5c6424c6a704bd","bb36af1417914c8aa0728cf0294be51c","9cd358bace4a46288e5a11e03256eee1","71ea73a230af420c8c49d9cfaa5ab007","1199e807dbea492486a7faf9aeb2e8d2","b0d8ca926f9e469ebaaf8216ad9ede8c","c2867984ecb1
465f95a65d126c939991","523eaf4ccda54016b3fbbc7bcba69d09","e4a736a0c2134a51bf605daa8882db31","f288cf7496284b1bb1ee1b8c5dbf9afe","ed84f6bbac894e62ad0599dae35a9b8d","0ca6bb008cdf457fa868d567e4c5cbc9","652284bf59e846cd980fef70006cbbd7","5842f56fc48440169916b714284e80fe","aa19d3fac97b48678fbf88350a462765","0afb2cb83138438daa94145fafa706cc","af89e62625854f85bce7d651bfc4acfe","7ab055add3f54e939b70b9937ed100cd","959675b37c284184a97ddfe7fed30088","f82ad804566d4c018941d9efe57aa753","6e5a097140b14e0b9d18bf49ad747ea1","29f7931b409b48d7b23ca62965c11447","a00b6e479fed421eb64cc26089d127ec","e6b51524f5f34035a302c1e25c905f6f","2b7e38d56e6c406d85d4c2f8dbd14381","9da221b94d634814a2b2f5b2598adfbf","fe302d219a9d48f386835333902f745c","7a64baf19db946708ac3cfe9bba0ecff","9f6f17c6ffb44fa09431423d27740cdf","e5ad5de128d4472dbe7c8c722685dd63","6baf1d03147a4d509ffeca5b46bfaef8","bd1b3ebfa0f040b6a3409f83f0689612","a14de20b05b74107912737505564a9b6","ea8f4fe35c91417ea8b9f5411ff3609e","9c09190e2fdd4f97affc2234e86c5bee","864e3bd5051a47a8901a28f7aaab9ccd","1afb48bfcdc64803b2e63dac7380ffcc","a437ae7a7d67409eb066f470b01896e8","246027ed95ea40a9a037bb1218a5bab6","e02e735df1f24875ae30e06c37434a96","50f38b8c1de741eb895e3de8d10a6e97","0727fde380a54a26b84117fb58b34814","b3806705ac714b7caa823d184d09dd0a","df6c3e7278a0402cbdcf56861181d1a0","dac1bd568a134f1ab6752ccea84d2cd7","9b4029630887481f895100433ea29ca5","48e4c9c0a9ad45f488833b9533db4a67","225d0c8c8ab54188bb8823f2618a7abb","79bde2b3bba444ff83109656e67760d1","ac5c7cce9d7641d89ac2921a6ced5480","463efc760f3c424c944084d4853ccb81","de8dc6464f554376834078c8fbeb4e9b","2c301232781c44c5a9883bb0e3d6d5a8","3e2396bc62d74a5ba3dd9a11731de392","4062069fd10849e48aa2f42f0ccad406","d058e12b7c9a47b6957916f0a5735a51","a50bd51c0e774c63a9a0c23de522c9e7","13a1ce102a2045b69e4565e7a63449ae","5176c6d1c19c434aa88256aa83bccd57","274282e991c142e2a4f62410e9fd29d3","d48e121c27884d7eadfda819c78e9617","56546bf856364f65b289cdf7ef5895fa","98133af35daf47b38b785da163078430","e99cda177fcf45528
1a2a2b606d0d8ef","87591ac7a7e247c88cd2149b259802b9","9dba966e37cd44398d55e7d4e191f1f0","ae3d380509cf4aba8e3a7dbe3b54000a","691c71d141d34e79ab4becc4f26252df","ab344b76b2e84abb8bebacd105941428","993f4775c78243ca8241ce3de9111b7f","4dbba02a274a488e9e4874752a0ded05","6552a8e5158644e099d74dd9e1f561df","2e8ff466b6b94290aef03bb2a39389c7","009034a493bc4bc7b3658f1381bce4cc","9d9eef4bb31f4183aa14346b7aa5cb20","6fadc7132d0d42ef9c3fc83dc8ee1564","a0cb3869861e41f88d8ae0b7e9627122","1776a1f9c103417bab4873d44ab31202","fff0dce5a48745978722580454349072","891a5a8d28f5434bada9cdaea176aea7","07b3506aa2e444379a0e41486d8154c4","fcbfa3898e0a4d01962c7502974db9d5","4303aa6e9a424f46b1291aac04db28f4","8635d996805741ce87e22fac012a9501"]},"executionInfo":{"status":"ok","timestamp":1678890966974,"user_tz":-60,"elapsed":13213,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}},"id":"7DEsxcvtRsEX"},"outputs":[{"output_type":"stream","name":"stdout","text":["Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/glue/sst2/2.0.0...\n"]},{"output_type":"display_data","data":{"text/plain":["Dl Completed...: 0 url [00:00, ? url/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"bb1c1dafddfc4ab1bf0ea96916d088dc"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Dl Size...: 0 MiB [00:00, ? MiB/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"16c4f6ca9f104713aaed23017b953711"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Extraction completed...: 0 file [00:00, ? 
file/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"60999d5a3045449b87f83734c86adcce"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Generating splits...: 0%| | 0/3 [00:00"]},"metadata":{},"execution_count":31}]},{"cell_type":"code","source":["PROMPTS = (\n"," 'a {} review',\n"," 'a {} movie review',\n"," 'a {} sentiment',\n"," 'the movie was received {}',\n"," 'the reception was {}',\n"," )\n","class_prompts = [p.format(c) for c in classnames for p in PROMPTS]\n","len(class_prompts)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"a9-hXwnGSGMN","executionInfo":{"status":"ok","timestamp":1678890967745,"user_tz":-60,"elapsed":3,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}},"outputId":"b04b954c-1bf0-4c4d-ef0f-9c91a89a951d"},"execution_count":32,"outputs":[{"output_type":"execute_result","data":{"text/plain":["10"]},"metadata":{},"execution_count":32}]},{"cell_type":"code","source":["# Render prompt images for all classes and visualize them\n","text_images = preprocess_texts(class_prompts [:5])\n","\n","plt.figure(figsize=(20, 5))\n","plt.imshow(np.hstack(text_images) * .5 + .5)\n","plt.axis('off');"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":254},"id":"bHPRePTVkn4k","executionInfo":{"status":"ok","timestamp":1678890969903,"user_tz":-60,"elapsed":2160,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}},"outputId":"359dcb12-823d-400a-92c8-c31de895ad79"},"execution_count":33,"outputs":[{"output_type":"display_data","data":{"text/plain":["
"],"image/png":"\n"},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","source":["### Zero-shot classification"],"metadata":{"id":"5J7DEWgQk3y7"}},{"cell_type":"code","execution_count":25,"metadata":{"id":"iO85T44aEmMs","executionInfo":{"status":"ok","timestamp":1678890545725,"user_tz":-60,"elapsed":5,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[],"source":["# JIT-compile image embedding function to speed up the processing\n","@jax.jit\n","def embed_images(params, images):\n"," zimg, _, _ = model.apply({\"params\": params}, image=images)\n"," return zimg"]},{"cell_type":"code","source":["# Compute class embeddings\n","zclass = []\n","for i in range(0, len(class_prompts), 100):\n"," batch = class_prompts[i : i + 100]\n"," batch = np.stack(preprocess_texts(batch))\n"," zbatch = embed_images(params, batch)\n"," zclass.append(zbatch)\n","\n","zclass = np.concatenate(zclass)\n","zclass.shape\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"nM_1-0rj2ksS","executionInfo":{"status":"ok","timestamp":1678890755919,"user_tz":-60,"elapsed":210198,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}},"outputId":"03e58c30-614d-4aea-b360-14c0bb90e263"},"execution_count":26,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(800, 768)"]},"metadata":{},"execution_count":26}]},{"cell_type":"code","execution_count":27,"metadata":{"id":"8Drfz7TACX2P","outputId":"b7d05c68-ebec-49d2-9ad4-bf8ebb4a3298","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1678890953378,"user_tz":-60,"elapsed":197470,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[{"output_type":"stream","name":"stderr","text":["100%|██████████| 20/20 [03:17<00:00, 9.87s/it]\n"]},{"output_type":"execute_result","data":{"text/plain":["(10000, 768)"]},"metadata":{},"execution_count":27}],"source":["# Compute all image/sentence 
embeddings and collect the correct labels\n","zimgs = []\n","labels = []\n","\n","for batch in tqdm.tqdm(ds_test.batch(500)):\n"," labels += list(batch['label'].numpy())\n"," zimg = embed_images(params, batch['image'].numpy())\n"," zimgs.append(np.array(zimg))\n","zimgs = np.concatenate(zimgs)\n","zimgs.shape"]},{"cell_type":"code","execution_count":28,"metadata":{"id":"uU_8LyMcE7zp","outputId":"183dad76-b4dd-42d9-fe0a-5d286e1bc5b0","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1678890953379,"user_tz":-60,"elapsed":21,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[{"output_type":"execute_result","data":{"text/plain":["(10000, 100)"]},"metadata":{},"execution_count":28}],"source":["# Compute similarities ...\n","sims = zimgs @ zclass.reshape([len(classnames), len(PROMPTS), -1]).mean(axis=1).T\n","sims.shape"]},{"cell_type":"code","execution_count":29,"metadata":{"id":"0mKZW4gNFl_6","outputId":"8eb4b896-dde1-4f72-8808-f219210f7006","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1678890953379,"user_tz":-60,"elapsed":15,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"outputs":[{"output_type":"execute_result","data":{"text/plain":["0.5092"]},"metadata":{},"execution_count":29}],"source":["# ... 
and use most similar embedding to predict label.\n","(sims.argmax(axis=1) == np.array(labels)).mean()"]},{"cell_type":"code","source":["# CIFAR-100 expected results (7 prompts)\n","#\n","# clippo_b16_yfcc100m: 0.4535\n","# clippo_b16_yfcc100m_i21k_init: 0.5183\n","# clippo_b16_yfcc100m_i21k_init_25c4: 0.5092\n","# clippo_b16_yfcc100m_i21k_init_50c4: 0.4862\n","# clippo_b16_yfcc100m_i21k_init_75c4: 0.4204\n","#\n","# SST-2 expected results\n","#\n","# clippo_b16_yfcc100m_i21k_init_25c4: 0.6754\n","# clippo_b16_yfcc100m_i21k_init_75c4: 0.7006"],"metadata":{"id":"f-N90sgl6aKN","executionInfo":{"status":"ok","timestamp":1678889595136,"user_tz":-60,"elapsed":17,"user":{"displayName":"Michael Tschannen","userId":"02619997334944082183"}}},"execution_count":30,"outputs":[]}],"metadata":{"accelerator":"GPU","colab":{"provenance":[{"file_id":"https://github.com/google-research/vision_transformer/blob/main/lit.ipynb","timestamp":1678125470076}],"toc_visible":true},"kernelspec":{"display_name":"Python 
3","name":"python3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"dccdc3dac09849d89bb4e33cd84f342e":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_5375b788c4434c70ab1174814f8f8c01","IPY_MODEL_d2ed5d9e693649dbbce268b311751c42","IPY_MODEL_3bb94f09d06b4f21a36d797603167e0e"],"layout":"IPY_MODEL_daecd79e3f4d4f8ea5bdfaddfa979047"}},"5375b788c4434c70ab1174814f8f8c01":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_8fdbc2c453284869844620f0b3d48118","placeholder":"​","style":"IPY_MODEL_0b6bae24732c4f4e8c6533057db1c320","value":"Dl Completed...: 
100%"}},"d2ed5d9e693649dbbce268b311751c42":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_eeca15fa7845458aac9d0ac4564eb683","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_181c3157bb5c4d0682092d4362c71e38","value":1}},"3bb94f09d06b4f21a36d797603167e0e":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_0dae83dd44984685b6b547ea563f2ac5","placeholder":"​","style":"IPY_MODEL_499675f6f43449e1b457a943bb89f5e6","value":" 1/1 [01:42<00:00, 48.00s/ 
url]"}},"daecd79e3f4d4f8ea5bdfaddfa979047":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"8fdbc2c453284869844620f0b3d48118":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overf
low_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"0b6bae24732c4f4e8c6533057db1c320":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"eeca15fa7845458aac9d0ac4564eb683":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"181c3157bb5c4d0682092d4362c71e38":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"0dae83dd44984685b6b547ea563f2ac5":{"model_modu
le":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"499675f6f43449e1b457a943bb89f5e6":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"ad89503ecf0a4ce09e5d817b6aeca758":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_774eb9564b324fb398ec9f4257f67673","IPY_MODEL_1b3a0f324a4f477eade4332e2f40b115","IPY_MODEL_1986deffd06e4ed3bf882576f218ecb8"],"layout":"IPY_MODEL_93da5eeb980f416da18ed48ebd3d025d"}},
"774eb9564b324fb398ec9f4257f67673":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_2a068826d3d94f9199c42bdc4381de7c","placeholder":"​","style":"IPY_MODEL_4a14f7780cae4c9ba6dd0b3588e713ea","value":"Dl Size...: 100%"}},"1b3a0f324a4f477eade4332e2f40b115":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_ecfea453250a49caadb5162d3b02ac40","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_9879ce5353e04f738660e92d78a3abcd","value":1}},"1986deffd06e4ed3bf882576f218ecb8":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_e3bc1b53a7f94ed3bd409c309ca7fa3a","placeholder":"​","style":"IPY_MODEL_f0122c02070d44e9ab72d6189e65cfb5","value":" 1485/1485 [01:42<00:00, 32.32 
MiB/s]"}},"93da5eeb980f416da18ed48ebd3d025d":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2a068826d3d94f9199c42bdc4381de7c":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"ove
rflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"4a14f7780cae4c9ba6dd0b3588e713ea":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"ecfea453250a49caadb5162d3b02ac40":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"9879ce5353e04f738660e92d78a3abcd":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"e3bc1b53a7f94ed3bd409c309ca7fa3a":{"model_mo
dule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"f0122c02070d44e9ab72d6189e65cfb5":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"1756f9ca57194e05a21e86be3390c160":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_353e41bf89064820b79b2d0c0ea7df31","IPY_MODEL_bb8c6693952a4dba95e0d16305b6bd3f","IPY_MODEL_7d00a1fa21354a51805a415db2410a7c"],"layout":"IPY_MODEL_5c4eafb7fe01476e8306e62dda1d46f6"}
},"353e41bf89064820b79b2d0c0ea7df31":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_e53866d027f64bd4b299ecfdcf4f0a1c","placeholder":"​","style":"IPY_MODEL_7f7af62a6dd14868a0b35e520fec9d6a","value":"Extraction completed...: 100%"}},"bb8c6693952a4dba95e0d16305b6bd3f":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_e1e62a44c0b74e9781e93342838a744e","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_a84b52168d624ca1b442e989950b3b97","value":0}},"7d00a1fa21354a51805a415db2410a7c":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_c5497a566d504a6d926c4a74d24fea75","placeholder":"​","style":"IPY_MODEL_c7cf618be98a4ef4954e83ca3b49874d","value":" 13395/13395 [01:43<00:00, 664.73 
file/s]"}},"5c4eafb7fe01476e8306e62dda1d46f6":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"e53866d027f64bd4b299ecfdcf4f0a1c":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"ov
erflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"7f7af62a6dd14868a0b35e520fec9d6a":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"e1e62a44c0b74e9781e93342838a744e":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"a84b52168d624ca1b442e989950b3b97":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"c5497a566d504a6d926c4a74d24fea75":{"model_m
odule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"c7cf618be98a4ef4954e83ca3b49874d":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"b157a99f67604979b1cd918b68fca6ac":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_6c00492230ed4b93af80dfc8e92b029b","IPY_MODEL_b86e74d1b6ff42129cffb98dec1f3f7b","IPY_MODEL_651f7d28f1a24bb3b7e48b3692a220e2"],"layout":"IPY_MODEL_3498b3fee9b54ee0889f21bb893c78be"
}},"6c00492230ed4b93af80dfc8e92b029b":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_a781b7bb93a7451ba0c80d0f29986898","placeholder":"​","style":"IPY_MODEL_0d321d1d25fd40ef83074eba17b66080","value":"Generating splits...: 100%"}},"b86e74d1b6ff42129cffb98dec1f3f7b":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_4c0e7be5d0a94baaaadc5a6500dd9b6c","max":2,"min":0,"orientation":"horizontal","style":"IPY_MODEL_d34fe0909ab249058650c2cf8ceb35c4","value":2}},"651f7d28f1a24bb3b7e48b3692a220e2":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_7669118fa50947c5b9ae662fb6267f1d","placeholder":"​","style":"IPY_MODEL_6c6e27d70026499abaee7600b26fde29","value":" 2/2 [00:17<00:00, 7.75s/ 
splits]"}},"3498b3fee9b54ee0889f21bb893c78be":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"a781b7bb93a7451ba0c80d0f29986898":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null
,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"0d321d1d25fd40ef83074eba17b66080":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"4c0e7be5d0a94baaaadc5a6500dd9b6c":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d34fe0909ab249058650c2cf8ceb35c4":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"7669118fa50947c5b9ae662fb6267f1d":{"model
_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"6c6e27d70026499abaee7600b26fde29":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"ff8597efa78344a09d83e4a6db42abfd":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_0725627ce3414d1c92b340d3d3f8f281","IPY_MODEL_de52e4ca1d6f4e0092d6c68274d3af8f","IPY_MODEL_78fa9b95d9c649d9be746399446eec77"],"layout":"IPY_MODEL_cdf7a866e43a4bdaa0c809b78741f75
7"}},"0725627ce3414d1c92b340d3d3f8f281":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_4a10c07fa0e1434984871acb73d5f46a","placeholder":"​","style":"IPY_MODEL_53893397fee74fd5b6a34dada619b6d1","value":"Generating train examples...: "}},"de52e4ca1d6f4e0092d6c68274d3af8f":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"info","description":"","description_tooltip":null,"layout":"IPY_MODEL_3af6c8755e7c49f6aa01f45364a7e77e","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_7c4f5b790f1d4c00aa975141a8ede738","value":1}},"78fa9b95d9c649d9be746399446eec77":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_b5d6c41c9a9141dc94963faa2d4a676a","placeholder":"​","style":"IPY_MODEL_8e2d89124d264321ab04e8b7716c7d33","value":" 9408/? 
[00:10<00:00, 272.24 examples/s]"}},"cdf7a866e43a4bdaa0c809b78741f757":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"4a10c07fa0e1434984871acb73d5f46a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflo
w":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"53893397fee74fd5b6a34dada619b6d1":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"3af6c8755e7c49f6aa01f45364a7e77e":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"7c4f5b790f1d4c00aa975141a8ede738":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"b5d6c41c9a9141
dc94963faa2d4a676a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"8e2d89124d264321ab04e8b7716c7d33":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"5c711a79f5bc4b08b3dd2015e948adad":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_d78bc7be8f3847569d3134d5db2b5c8f","IPY_MODEL_55ca92dd00c74bff961c46f74aa95558","IPY_MODEL_c0cd6e7644ce4769b1cf9a15ee503d0d"],"layout":"IPY_MODEL_9b21
09cef3404964b4b3e0e299369ca3"}},"d78bc7be8f3847569d3134d5db2b5c8f":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_c7a7c829ed2e40d6b9aeba7cd2792361","placeholder":"​","style":"IPY_MODEL_d103db0b223d4b6a927321aaee909f1c","value":"Shuffling /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0.incompleteZXF06J/imagenette-train.tfrecord*...: 100%"}},"55ca92dd00c74bff961c46f74aa95558":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_df02c258cb1f47a48a5bda4dbd1cd01a","max":9469,"min":0,"orientation":"horizontal","style":"IPY_MODEL_1c2f0535a5494ae1b557dd8426fd69c1","value":9469}},"c0cd6e7644ce4769b1cf9a15ee503d0d":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_1ea41bb66d8c41639cf8d9273fa8c6d3","placeholder":"​","style":"IPY_MODEL_07a7298243ea44e5aa95adb338c2b5fa","value":" 9454/9469 [00:04<00:00, 2234.99 
examples/s]"}},"9b2109cef3404964b4b3e0e299369ca3":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"c7a7c829ed2e40d6b9aeba7cd2792361":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":
null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d103db0b223d4b6a927321aaee909f1c":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"df02c258cb1f47a48a5bda4dbd1cd01a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"1c2f0535a5494ae1b557dd8426fd69c1":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"1ea41bb66d8c41639cf8d9273fa8c6d3":{"m
odel_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"07a7298243ea44e5aa95adb338c2b5fa":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"4e71c9ec77e04e0fb85f83f739f48949":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_0d2f9fb372734a2cb58c1027afd2ab64","IPY_MODEL_8ffb631b45d7432193da5c756028a061","IPY_MODEL_9e1179561ab6434eabea9eb1238dae43"],"layout":"IPY_MODEL_d41263cf6e4845c29612dc1f8a1
13a64"}},"0d2f9fb372734a2cb58c1027afd2ab64":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_e0cb8d6dfb7349d7b7cfe1698140b29b","placeholder":"​","style":"IPY_MODEL_5e4a96dea2154789be89dc50523167d1","value":"Generating validation examples...: "}},"8ffb631b45d7432193da5c756028a061":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"info","description":"","description_tooltip":null,"layout":"IPY_MODEL_fa2922365ee043908b03abf0cc8b3d46","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_6c7f68bc82da44c89df1fe69883af09b","value":1}},"9e1179561ab6434eabea9eb1238dae43":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_c9c3213ba5184b9ca4f46601d372a291","placeholder":"​","style":"IPY_MODEL_825b35a2d7cf4606ad7056d60f59e907","value":" 3925/? 
[00:01<00:00, 2737.37 examples/s]"}},"d41263cf6e4845c29612dc1f8a113a64":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"e0cb8d6dfb7349d7b7cfe1698140b29b":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overfl
ow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"5e4a96dea2154789be89dc50523167d1":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"fa2922365ee043908b03abf0cc8b3d46":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"6c7f68bc82da44c89df1fe69883af09b":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"c9c3213ba5184
b9ca4f46601d372a291":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"825b35a2d7cf4606ad7056d60f59e907":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"4464bdf63c684c748e530de69bdb8da3":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_b4a8f48c427642b7a5e574794c67c450","IPY_MODEL_353d24473b444907b18c2d8eb1853d1b","IPY_MODEL_785f446240fb476a8aa5cb1b8fb38573"],"layout":"IPY_MODEL_8a8
13d3ace9e4c9f8defa24c29ee96ac"}},"b4a8f48c427642b7a5e574794c67c450":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_cf0747e7cf8a4fb6af9a8a2d2ff4885f","placeholder":"​","style":"IPY_MODEL_4cf0aab25b1c43338c8761e1ff492631","value":"Shuffling /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0.incompleteZXF06J/imagenette-validation.tfrecord*...: 97%"}},"353d24473b444907b18c2d8eb1853d1b":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_3993f8faf3794a9dae2c55e13b8938e4","max":3925,"min":0,"orientation":"horizontal","style":"IPY_MODEL_33a70dcd558646f09e73060acdb80a98","value":3925}},"785f446240fb476a8aa5cb1b8fb38573":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_af99dc606a7c4da0931365780439f632","placeholder":"​","style":"IPY_MODEL_e4047f4ac6824cacb4a0f2f1c5da6b48","value":" 3825/3925 [00:01<00:00, 2164.36 
examples/s]"}},"8a813d3ace9e4c9f8defa24c29ee96ac":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"cf0747e7cf8a4fb6af9a8a2d2ff4885f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":
null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"4cf0aab25b1c43338c8761e1ff492631":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"3993f8faf3794a9dae2c55e13b8938e4":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"33a70dcd558646f09e73060acdb80a98":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"af99dc606a7c4da0931365780439f632":{"m
odel_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"e4047f4ac6824cacb4a0f2f1c5da6b48":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"664319014bea49ddb0682ca183f37d38":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_7b80afb39bd34ffbb6ec71999f043a0d","IPY_MODEL_7f8cfe19a0774252b9a2e0acf8712b95","IPY_MODEL_24291661ee6a4e1ea39fe36a2461ad03"],"layout":"IPY_MODEL_456418b864a04211a89f7ded989
3182a"}},"7b80afb39bd34ffbb6ec71999f043a0d":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_4e035e47a511466aaee1f410141dc219","placeholder":"​","style":"IPY_MODEL_1fe28e969c15490b8d5deddf404bbfc8","value":"Dl Completed...: 100%"}},"7f8cfe19a0774252b9a2e0acf8712b95":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_9b493b2b10d749af94288e1da31082de","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_7910356c0b024c48995d1215cfc40d92","value":1}},"24291661ee6a4e1ea39fe36a2461ad03":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_57bca4c7a4e048118328e261f2816485","placeholder":"​","style":"IPY_MODEL_1981f2194ce14b9d858b4e66ad749a74","value":" 1/1 [00:08<00:00, 6.55s/ 
url]"}},"456418b864a04211a89f7ded9893182a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"4e035e47a511466aaee1f410141dc219":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overf
low_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"1fe28e969c15490b8d5deddf404bbfc8":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"9b493b2b10d749af94288e1da31082de":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"7910356c0b024c48995d1215cfc40d92":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"57bca4c7a4e048118328e261f2816485":{"model_modu
le":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"1981f2194ce14b9d858b4e66ad749a74":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"7bf24c0b526849148f79ddbe489fecbe":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_d96f75be285b4ece84a79b80525e07ff","IPY_MODEL_eae9d0ad46014301a4104063786c6513","IPY_MODEL_33daed476d1e48aea2c175d7d165007c"],"layout":"IPY_MODEL_25ac8ad9428446638d2e831154b67e9e"}},
"d96f75be285b4ece84a79b80525e07ff":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_096d4f03e30642ecb75b856d32192b77","placeholder":"​","style":"IPY_MODEL_ff4aa1191d5546e29cfc0b2099deaf32","value":"Dl Size...: 100%"}},"eae9d0ad46014301a4104063786c6513":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_5bd7bfbeae434bd0ad6d1a004866e8bf","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_49a2dd8bfc2440dc87edd2a379b3b9df","value":1}},"33daed476d1e48aea2c175d7d165007c":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_117b9b987ba646239156062a37e44881","placeholder":"​","style":"IPY_MODEL_7bcffca94bd349ce8e36b41eed824e9f","value":" 160/160 [00:08<00:00, 27.88 
MiB/s]"}},"25ac8ad9428446638d2e831154b67e9e":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"096d4f03e30642ecb75b856d32192b77":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"ove
rflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"ff4aa1191d5546e29cfc0b2099deaf32":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"5bd7bfbeae434bd0ad6d1a004866e8bf":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"49a2dd8bfc2440dc87edd2a379b3b9df":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"117b9b987ba646239156062a37e44881":{"model_mo
dule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"7bcffca94bd349ce8e36b41eed824e9f":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"14202fa017664237abee369e0b787b2a":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_e77432e0ad7a48cda8a84a59c7c4fcd9","IPY_MODEL_81cb50b67148443a97315c6004b9bd9e","IPY_MODEL_072d46906efe4a07ac2a04bf532ca56e"],"layout":"IPY_MODEL_d4a3cb590ed2442682da8cc156ed9ad0"}
},"e77432e0ad7a48cda8a84a59c7c4fcd9":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_df162d04da2a41209af80c51808caf2b","placeholder":"​","style":"IPY_MODEL_c1ed3a33cca94d0f8e91ea8cafbeffca","value":"Extraction completed...: 100%"}},"81cb50b67148443a97315c6004b9bd9e":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_621ed7453b7a48dfa3fb5d5c0da0bec3","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_8c3a583468f54969a8203d83a3b511ad","value":1}},"072d46906efe4a07ac2a04bf532ca56e":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_259cda89a60d4ae28a23d222bf17a4f4","placeholder":"​","style":"IPY_MODEL_5d02bdf68b19451dad6399a596c700f9","value":" 4/4 [00:08<00:00, 8.23s/ 
file]"}},"d4a3cb590ed2442682da8cc156ed9ad0":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"df162d04da2a41209af80c51808caf2b":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"over
flow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"c1ed3a33cca94d0f8e91ea8cafbeffca":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"621ed7453b7a48dfa3fb5d5c0da0bec3":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"8c3a583468f54969a8203d83a3b511ad":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"259cda89a60d4ae28a23d222bf17a4f4":{"model_mod
ule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"5d02bdf68b19451dad6399a596c700f9":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"31d1942f4b104c048fe34b06fd5612e4":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_e22728de4ffb4c2ca07c8a6ac6da10de","IPY_MODEL_282d4ebc49794e869d299726b52bcf21","IPY_MODEL_f74afaef34b14da7a6e5f0c2f569f68d"],"layout":"IPY_MODEL_70dfcbb20077452a95f7b585fd182be6"}}
,"e22728de4ffb4c2ca07c8a6ac6da10de":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_99cdd80186aa49b891f7a5d4de962f18","placeholder":"​","style":"IPY_MODEL_5f8b6e4a048a4fd5b944e4d7853b9a71","value":"Generating splits...: 100%"}},"282d4ebc49794e869d299726b52bcf21":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_1418a4fa26a64e6bbdcca38fb682e39d","max":2,"min":0,"orientation":"horizontal","style":"IPY_MODEL_c1e9314ff7904b8dbc50a8708aab1eb1","value":2}},"f74afaef34b14da7a6e5f0c2f569f68d":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_0bf192b5c7034c2191c5f92c7df1c9dd","placeholder":"​","style":"IPY_MODEL_2f71d1430e9c45799e8949ec0158a222","value":" 2/2 [01:03<00:00, 28.13s/ 
splits]"}},"70dfcbb20077452a95f7b585fd182be6":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"99cdd80186aa49b891f7a5d4de962f18":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null
,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"5f8b6e4a048a4fd5b944e4d7853b9a71":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"1418a4fa26a64e6bbdcca38fb682e39d":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"c1e9314ff7904b8dbc50a8708aab1eb1":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"0bf192b5c7034c2191c5f92c7df1c9dd":{"model
_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2f71d1430e9c45799e8949ec0158a222":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"07689a70f74446bcba36184e47560bd1":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_65f301b7b02d4c4dba0f233f9e9e425a","IPY_MODEL_e4412596145f4c01af01d51402ed45d6","IPY_MODEL_179510514bfd4e3ea9c46889b13da181"],"layout":"IPY_MODEL_3021ffddf75e42759b0749f82cb17eb
4"}},"65f301b7b02d4c4dba0f233f9e9e425a":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_5c101b460ebd4a27a01a2ada97f777a7","placeholder":"​","style":"IPY_MODEL_807aa95840eb41b4859a300500d8beec","value":"Generating train examples...: "}},"e4412596145f4c01af01d51402ed45d6":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"info","description":"","description_tooltip":null,"layout":"IPY_MODEL_46d2eb5e2620481f9495b472a8c84775","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_d7cbb87d3c9245a8b339ecbcea83e8c2","value":1}},"179510514bfd4e3ea9c46889b13da181":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_6d3ac9db950848f3baae1fa93b8e68ea","placeholder":"​","style":"IPY_MODEL_dee93ec00db64a9d8f3f1edd65a90ac8","value":" 49940/? 
[00:52<00:00, 1017.08 examples/s]"}},"3021ffddf75e42759b0749f82cb17eb4":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"5c101b460ebd4a27a01a2ada97f777a7":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overfl
ow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"807aa95840eb41b4859a300500d8beec":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"46d2eb5e2620481f9495b472a8c84775":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"d7cbb87d3c9245a8b339ecbcea83e8c2":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"6d3ac9db95084
8f3baae1fa93b8e68ea":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"dee93ec00db64a9d8f3f1edd65a90ac8":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"4dc9469114be4125b4d344d644f5aff9":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_274ee0dd728f4a33be912fcd1786f058","IPY_MODEL_5708d1ee150f46029173593d2ec4c21f","IPY_MODEL_33749fbfb6d9403f962568ba42c1e8ed"],"layout":"IPY_MODEL_a58
6b099410a4223b0c62b21e94b6c10"}},"274ee0dd728f4a33be912fcd1786f058":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_22b60f7e0f234a7e83c01a0b7f5ad9b1","placeholder":"​","style":"IPY_MODEL_a39974f014e24fc5ac43e72b34342e9d","value":"Shuffling /root/tensorflow_datasets/cifar100/3.0.2.incompleteW3MSAW/cifar100-train.tfrecord*...: 80%"}},"5708d1ee150f46029173593d2ec4c21f":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_f10a575c66664f618e179e4df8dc6fa3","max":50000,"min":0,"orientation":"horizontal","style":"IPY_MODEL_5c3235bc34964d51b54fa096fc92c1f9","value":50000}},"33749fbfb6d9403f962568ba42c1e8ed":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_f43832732153414a951c1810c899c531","placeholder":"​","style":"IPY_MODEL_84bce2199a5e4d1c8526b1a6d1059cf5","value":" 40132/50000 [00:00<00:00, 140860.69 
examples/s]"}},"a586b099410a4223b0c62b21e94b6c10":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"22b60f7e0f234a7e83c01a0b7f5ad9b1":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":
null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"a39974f014e24fc5ac43e72b34342e9d":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"f10a575c66664f618e179e4df8dc6fa3":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"5c3235bc34964d51b54fa096fc92c1f9":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"f43832732153414a951c1810c899c531":{"m
odel_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"84bce2199a5e4d1c8526b1a6d1059cf5":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"40ad9dad27c747f5afc95f69b5bf42fa":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_3978ee7fa0f444bebf6da70d16c26095","IPY_MODEL_f01d3d3fc2c84fb2b2bd14560cdd7fa3","IPY_MODEL_8f28e75e20b24966b582137d4dfa31ee"],"layout":"IPY_MODEL_4aaf504fe9fc4d618d4174ae984
4b347"}},"3978ee7fa0f444bebf6da70d16c26095":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_44a3a4df887a4bebbabd04033546d7c9","placeholder":"​","style":"IPY_MODEL_7fd53f280d914a8395c6c1c3cdc677ce","value":"Generating test examples...: "}},"f01d3d3fc2c84fb2b2bd14560cdd7fa3":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"info","description":"","description_tooltip":null,"layout":"IPY_MODEL_b83e550fff71444480e85f9155edb0d0","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_3742d1e304584823a975a123bfe2ccb4","value":1}},"8f28e75e20b24966b582137d4dfa31ee":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_135831c870fd49aeaf21294f8256f1f9","placeholder":"​","style":"IPY_MODEL_2d8ab80c4b5742a2b8ac5154d58d52d3","value":" 9955/? 
[00:10<00:00, 837.05 examples/s]"}},"4aaf504fe9fc4d618d4174ae9844b347":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"44a3a4df887a4bebbabd04033546d7c9":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflo
w":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"7fd53f280d914a8395c6c1c3cdc677ce":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"b83e550fff71444480e85f9155edb0d0":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"3742d1e304584823a975a123bfe2ccb4":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"135831c870fd49
aeaf21294f8256f1f9":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2d8ab80c4b5742a2b8ac5154d58d52d3":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"4be15f66f88c4401a769ce297c325705":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_e09cae9d4eb44f84892d1e4f705b13f8","IPY_MODEL_ad869a7a2c3948f6b2485510529df75b","IPY_MODEL_8e343e5259294118806a3144ef125d44"],"layout":"IPY_MODEL_07ed
03f79ec643028f7ed7c741d284a9"}},"e09cae9d4eb44f84892d1e4f705b13f8":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_bb2436bfe4b047979ad69dbf58b12836","placeholder":"​","style":"IPY_MODEL_9e683e88a8e042889c004fd9e34e5ba9","value":"Shuffling /root/tensorflow_datasets/cifar100/3.0.2.incompleteW3MSAW/cifar100-test.tfrecord*...: 75%"}},"ad869a7a2c3948f6b2485510529df75b":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_d247a7b15d814991933dcf6705f5904b","max":10000,"min":0,"orientation":"horizontal","style":"IPY_MODEL_b8130ff7e9f64911b7f11dd947d9eea4","value":10000}},"8e343e5259294118806a3144ef125d44":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_cca976a0f62f41c9a6161e60be7f757c","placeholder":"​","style":"IPY_MODEL_47fb22e77cd44a75b752f4b8ee16c026","value":" 7540/10000 [00:00<00:00, 75389.29 
examples/s]"}},"07ed03f79ec643028f7ed7c741d284a9":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"bb2436bfe4b047979ad69dbf58b12836":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":
null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"9e683e88a8e042889c004fd9e34e5ba9":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"d247a7b15d814991933dcf6705f5904b":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"b8130ff7e9f64911b7f11dd947d9eea4":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"cca976a0f62f41c9a6161e60be7f757c":{"m
odel_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"47fb22e77cd44a75b752f4b8ee16c026":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"bb1c1dafddfc4ab1bf0ea96916d088dc":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_412d72f0cab84cdc9ccb202015b8f09b","IPY_MODEL_321f978c7d154fa9a793612bd45e34fb","IPY_MODEL_c98c2b0e046e4d11bdb41466c8b9bbed"],"layout":"IPY_MODEL_3d795eb785364e04b3458db07a6
89231"}},"412d72f0cab84cdc9ccb202015b8f09b":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_47defbe6eded463aa913f6c94eba650c","placeholder":"​","style":"IPY_MODEL_90a7c3c7f67c4b9dabae3124d4f4807a","value":"Dl Completed...: 100%"}},"321f978c7d154fa9a793612bd45e34fb":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_f9c8321259e64b5da15a5780e48387cc","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_7d9c02eab6b14e64b0d54535a51d3f26","value":1}},"c98c2b0e046e4d11bdb41466c8b9bbed":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_7261da104bf4488489651c59b57c810e","placeholder":"​","style":"IPY_MODEL_292cacad99804debab40eaff1f7241f0","value":" 1/1 [00:02<00:00, 2.14s/ 
url]"}},"3d795eb785364e04b3458db07a689231":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"47defbe6eded463aa913f6c94eba650c":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overf
low_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"90a7c3c7f67c4b9dabae3124d4f4807a":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"f9c8321259e64b5da15a5780e48387cc":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"7d9c02eab6b14e64b0d54535a51d3f26":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"7261da104bf4488489651c59b57c810e":{"model_modu
le":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"292cacad99804debab40eaff1f7241f0":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"16c4f6ca9f104713aaed23017b953711":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_c9382c9cc43144ccbc94ff9c126c75cb","IPY_MODEL_5d35963be08f4f5e855c86ad7360b061","IPY_MODEL_9ba65fbf4ff54df8832d0790b326bd74"],"layout":"IPY_MODEL_26e7683fe07e46f2b6e42ab2cf525d1c"}},
"c9382c9cc43144ccbc94ff9c126c75cb":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_2e2f1aca25d64fc6a5eb6c5984ea1884","placeholder":"​","style":"IPY_MODEL_4b045a43035c4f4aa8f72acfc44361a0","value":"Dl Size...: 100%"}},"5d35963be08f4f5e855c86ad7360b061":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_999d70d6ea1c43f0a1b7005999dc2d9d","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_4e45f77fa3fe477eaca52d987b4d0217","value":1}},"9ba65fbf4ff54df8832d0790b326bd74":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_86979c617e50477db737303ad9042a7c","placeholder":"​","style":"IPY_MODEL_23c0c48a8eb04a6aa12a39e102a71322","value":" 7/7 [00:02<00:00, 3.97 
MiB/s]"}},"26e7683fe07e46f2b6e42ab2cf525d1c":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2e2f1aca25d64fc6a5eb6c5984ea1884":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"ove
rflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"4b045a43035c4f4aa8f72acfc44361a0":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"999d70d6ea1c43f0a1b7005999dc2d9d":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"4e45f77fa3fe477eaca52d987b4d0217":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"86979c617e50477db737303ad9042a7c":{"model_mo
dule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"23c0c48a8eb04a6aa12a39e102a71322":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"60999d5a3045449b87f83734c86adcce":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_7e297ca7eaf64ce1b3bdf8685acbf74f","IPY_MODEL_e9b3045f23b24882b847fcf7c821f3bd","IPY_MODEL_a6322a45f87d43dd8cb4edc4f366742e"],"layout":"IPY_MODEL_56a2a0d2d23243cf8f5c6424c6a704bd"}
},"7e297ca7eaf64ce1b3bdf8685acbf74f":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_bb36af1417914c8aa0728cf0294be51c","placeholder":"​","style":"IPY_MODEL_9cd358bace4a46288e5a11e03256eee1","value":"Extraction completed...: 100%"}},"e9b3045f23b24882b847fcf7c821f3bd":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_71ea73a230af420c8c49d9cfaa5ab007","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_1199e807dbea492486a7faf9aeb2e8d2","value":1}},"a6322a45f87d43dd8cb4edc4f366742e":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_b0d8ca926f9e469ebaaf8216ad9ede8c","placeholder":"​","style":"IPY_MODEL_c2867984ecb1465f95a65d126c939991","value":" 11/11 [00:02<00:00, 2.25s/ 
file]"}},"56a2a0d2d23243cf8f5c6424c6a704bd":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"bb36af1417914c8aa0728cf0294be51c":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"over
flow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"9cd358bace4a46288e5a11e03256eee1":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"71ea73a230af420c8c49d9cfaa5ab007":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"1199e807dbea492486a7faf9aeb2e8d2":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"b0d8ca926f9e469ebaaf8216ad9ede8c":{"model_mod
ule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"c2867984ecb1465f95a65d126c939991":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"523eaf4ccda54016b3fbbc7bcba69d09":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_e4a736a0c2134a51bf605daa8882db31","IPY_MODEL_f288cf7496284b1bb1ee1b8c5dbf9afe","IPY_MODEL_ed84f6bbac894e62ad0599dae35a9b8d"],"layout":"IPY_MODEL_0ca6bb008cdf457fa868d567e4c5cbc9"}}
,"e4a736a0c2134a51bf605daa8882db31":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_652284bf59e846cd980fef70006cbbd7","placeholder":"​","style":"IPY_MODEL_5842f56fc48440169916b714284e80fe","value":"Generating splits...: 100%"}},"f288cf7496284b1bb1ee1b8c5dbf9afe":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_aa19d3fac97b48678fbf88350a462765","max":3,"min":0,"orientation":"horizontal","style":"IPY_MODEL_0afb2cb83138438daa94145fafa706cc","value":3}},"ed84f6bbac894e62ad0599dae35a9b8d":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_af89e62625854f85bce7d651bfc4acfe","placeholder":"​","style":"IPY_MODEL_7ab055add3f54e939b70b9937ed100cd","value":" 3/3 [00:07<00:00, 1.85s/ 
splits]"}},"0ca6bb008cdf457fa868d567e4c5cbc9":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"652284bf59e846cd980fef70006cbbd7":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null
,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"5842f56fc48440169916b714284e80fe":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"aa19d3fac97b48678fbf88350a462765":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"0afb2cb83138438daa94145fafa706cc":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"af89e62625854f85bce7d651bfc4acfe":{"model
_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"7ab055add3f54e939b70b9937ed100cd":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"959675b37c284184a97ddfe7fed30088":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_f82ad804566d4c018941d9efe57aa753","IPY_MODEL_6e5a097140b14e0b9d18bf49ad747ea1","IPY_MODEL_29f7931b409b48d7b23ca62965c11447"],"layout":"IPY_MODEL_a00b6e479fed421eb64cc26089d127e
c"}},"f82ad804566d4c018941d9efe57aa753":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_e6b51524f5f34035a302c1e25c905f6f","placeholder":"​","style":"IPY_MODEL_2b7e38d56e6c406d85d4c2f8dbd14381","value":"Generating train examples...: "}},"6e5a097140b14e0b9d18bf49ad747ea1":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"info","description":"","description_tooltip":null,"layout":"IPY_MODEL_9da221b94d634814a2b2f5b2598adfbf","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_fe302d219a9d48f386835333902f745c","value":1}},"29f7931b409b48d7b23ca62965c11447":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_7a64baf19db946708ac3cfe9bba0ecff","placeholder":"​","style":"IPY_MODEL_9f6f17c6ffb44fa09431423d27740cdf","value":" 67299/? 
[00:07<00:00, 10076.71 examples/s]"}},"a00b6e479fed421eb64cc26089d127ec":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"e6b51524f5f34035a302c1e25c905f6f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overf
low":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2b7e38d56e6c406d85d4c2f8dbd14381":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"9da221b94d634814a2b2f5b2598adfbf":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"fe302d219a9d48f386835333902f745c":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"7a64baf19db9
46708ac3cfe9bba0ecff":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"9f6f17c6ffb44fa09431423d27740cdf":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"e5ad5de128d4472dbe7c8c722685dd63":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_6baf1d03147a4d509ffeca5b46bfaef8","IPY_MODEL_bd1b3ebfa0f040b6a3409f83f0689612","IPY_MODEL_a14de20b05b74107912737505564a9b6"],"layout":"IPY_MODEL_ea
8f4fe35c91417ea8b9f5411ff3609e"}},"6baf1d03147a4d509ffeca5b46bfaef8":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_9c09190e2fdd4f97affc2234e86c5bee","placeholder":"​","style":"IPY_MODEL_864e3bd5051a47a8901a28f7aaab9ccd","value":"Shuffling /root/tensorflow_datasets/glue/sst2/2.0.0.incompleteEYA068/glue-train.tfrecord*...: 87%"}},"bd1b3ebfa0f040b6a3409f83f0689612":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_1afb48bfcdc64803b2e63dac7380ffcc","max":67349,"min":0,"orientation":"horizontal","style":"IPY_MODEL_a437ae7a7d67409eb066f470b01896e8","value":67349}},"a14de20b05b74107912737505564a9b6":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_246027ed95ea40a9a037bb1218a5bab6","placeholder":"​","style":"IPY_MODEL_e02e735df1f24875ae30e06c37434a96","value":" 58529/67349 [00:00<00:00, 318506.08 
examples/s]"}},"ea8f4fe35c91417ea8b9f5411ff3609e":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"9c09190e2fdd4f97affc2234e86c5bee":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":
null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"864e3bd5051a47a8901a28f7aaab9ccd":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"1afb48bfcdc64803b2e63dac7380ffcc":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"a437ae7a7d67409eb066f470b01896e8":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"246027ed95ea40a9a037bb1218a5bab6":{"m
odel_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"e02e735df1f24875ae30e06c37434a96":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"50f38b8c1de741eb895e3de8d10a6e97":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_0727fde380a54a26b84117fb58b34814","IPY_MODEL_b3806705ac714b7caa823d184d09dd0a","IPY_MODEL_df6c3e7278a0402cbdcf56861181d1a0"],"layout":"IPY_MODEL_dac1bd568a134f1ab6752ccea84
d2cd7"}},"0727fde380a54a26b84117fb58b34814":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_9b4029630887481f895100433ea29ca5","placeholder":"​","style":"IPY_MODEL_48e4c9c0a9ad45f488833b9533db4a67","value":"Generating validation examples...: "}},"b3806705ac714b7caa823d184d09dd0a":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"info","description":"","description_tooltip":null,"layout":"IPY_MODEL_225d0c8c8ab54188bb8823f2618a7abb","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_79bde2b3bba444ff83109656e67760d1","value":1}},"df6c3e7278a0402cbdcf56861181d1a0":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_ac5c7cce9d7641d89ac2921a6ced5480","placeholder":"​","style":"IPY_MODEL_463efc760f3c424c944084d4853ccb81","value":" 348/? 
[00:00<00:00, 3478.61 examples/s]"}},"dac1bd568a134f1ab6752ccea84d2cd7":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"9b4029630887481f895100433ea29ca5":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overfl
ow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"48e4c9c0a9ad45f488833b9533db4a67":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"225d0c8c8ab54188bb8823f2618a7abb":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"79bde2b3bba444ff83109656e67760d1":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"ac5c7cce9d764
1d89ac2921a6ced5480":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"463efc760f3c424c944084d4853ccb81":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"de8dc6464f554376834078c8fbeb4e9b":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_2c301232781c44c5a9883bb0e3d6d5a8","IPY_MODEL_3e2396bc62d74a5ba3dd9a11731de392","IPY_MODEL_4062069fd10849e48aa2f42f0ccad406"],"layout":"IPY_MODEL_d05
8e12b7c9a47b6957916f0a5735a51"}},"2c301232781c44c5a9883bb0e3d6d5a8":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_a50bd51c0e774c63a9a0c23de522c9e7","placeholder":"​","style":"IPY_MODEL_13a1ce102a2045b69e4565e7a63449ae","value":"Shuffling /root/tensorflow_datasets/glue/sst2/2.0.0.incompleteEYA068/glue-validation.tfrecord*...: 0%"}},"3e2396bc62d74a5ba3dd9a11731de392":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_5176c6d1c19c434aa88256aa83bccd57","max":872,"min":0,"orientation":"horizontal","style":"IPY_MODEL_274282e991c142e2a4f62410e9fd29d3","value":872}},"4062069fd10849e48aa2f42f0ccad406":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_d48e121c27884d7eadfda819c78e9617","placeholder":"​","style":"IPY_MODEL_56546bf856364f65b289cdf7ef5895fa","value":" 0/872 [00:00<?, ? 
examples/s]"}},"d058e12b7c9a47b6957916f0a5735a51":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"a50bd51c0e774c63a9a0c23de522c9e7":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":
null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"13a1ce102a2045b69e4565e7a63449ae":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"5176c6d1c19c434aa88256aa83bccd57":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"274282e991c142e2a4f62410e9fd29d3":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"d48e121c27884d7eadfda819c78e9617":{"m
odel_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"56546bf856364f65b289cdf7ef5895fa":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"98133af35daf47b38b785da163078430":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_e99cda177fcf455281a2a2b606d0d8ef","IPY_MODEL_87591ac7a7e247c88cd2149b259802b9","IPY_MODEL_9dba966e37cd44398d55e7d4e191f1f0"],"layout":"IPY_MODEL_ae3d380509cf4aba8e3a7dbe3b5
4000a"}},"e99cda177fcf455281a2a2b606d0d8ef":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_691c71d141d34e79ab4becc4f26252df","placeholder":"​","style":"IPY_MODEL_ab344b76b2e84abb8bebacd105941428","value":"Generating test examples...: "}},"87591ac7a7e247c88cd2149b259802b9":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"info","description":"","description_tooltip":null,"layout":"IPY_MODEL_993f4775c78243ca8241ce3de9111b7f","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_4dbba02a274a488e9e4874752a0ded05","value":1}},"9dba966e37cd44398d55e7d4e191f1f0":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_6552a8e5158644e099d74dd9e1f561df","placeholder":"​","style":"IPY_MODEL_2e8ff466b6b94290aef03bb2a39389c7","value":" 1220/? 
[00:00<00:00, 6451.57 examples/s]"}},"ae3d380509cf4aba8e3a7dbe3b54000a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"691c71d141d34e79ab4becc4f26252df":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overfl
ow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"ab344b76b2e84abb8bebacd105941428":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"993f4775c78243ca8241ce3de9111b7f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":"20px"}},"4dbba02a274a488e9e4874752a0ded05":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"6552a8e515864
4e099d74dd9e1f561df":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2e8ff466b6b94290aef03bb2a39389c7":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"009034a493bc4bc7b3658f1381bce4cc":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_9d9eef4bb31f4183aa14346b7aa5cb20","IPY_MODEL_6fadc7132d0d42ef9c3fc83dc8ee1564","IPY_MODEL_a0cb3869861e41f88d8ae0b7e9627122"],"layout":"IPY_MODEL_177
6a1f9c103417bab4873d44ab31202"}},"9d9eef4bb31f4183aa14346b7aa5cb20":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_fff0dce5a48745978722580454349072","placeholder":"​","style":"IPY_MODEL_891a5a8d28f5434bada9cdaea176aea7","value":"Shuffling /root/tensorflow_datasets/glue/sst2/2.0.0.incompleteEYA068/glue-test.tfrecord*...: 0%"}},"6fadc7132d0d42ef9c3fc83dc8ee1564":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_07b3506aa2e444379a0e41486d8154c4","max":1821,"min":0,"orientation":"horizontal","style":"IPY_MODEL_fcbfa3898e0a4d01962c7502974db9d5","value":1821}},"a0cb3869861e41f88d8ae0b7e9627122":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_4303aa6e9a424f46b1291aac04db28f4","placeholder":"​","style":"IPY_MODEL_8635d996805741ce87e22fac012a9501","value":" 0/1821 [00:00<?, ? 
examples/s]"}},"1776a1f9c103417bab4873d44ab31202":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":"hidden","width":null}},"fff0dce5a48745978722580454349072":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":
null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"891a5a8d28f5434bada9cdaea176aea7":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"07b3506aa2e444379a0e41486d8154c4":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"fcbfa3898e0a4d01962c7502974db9d5":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"4303aa6e9a424f46b1291aac04db28f4":{"m
odel_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"8635d996805741ce87e22fac012a9501":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"nbformat":4,"nbformat_minor":0} \ No newline at end of file diff --git a/big_vision/configs/proj/clippo/train_clippo.py b/big_vision/configs/proj/clippo/train_clippo.py new file mode 100644 index 0000000000000000000000000000000000000000..8fc683b985d096a010744bb565c1d3b47d03c96e --- /dev/null +++ b/big_vision/configs/proj/clippo/train_clippo.py @@ -0,0 +1,199 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Trains CLIP with Pixels Only (CLIPPO), https://arxiv.org/abs/2212.08045 + +IMPORTANT NOTE: This config uses coco_captions by default for demonstration +purposes since the TFDS catalog does not provide any large image/alt-text data +set; the training will not produce a model with useful accuracy. Please +replace the data set below (marked by a comment) with an appropriate image/ +alt-text data set wrapped in TFDS (for example LAION-400M) and run the config +with the suffix `:test_with_coco=False` to train on your data set. Refer to +the following guide to build a TFDS wrapper for your favorite image/alt-text +data set: +https://www.tensorflow.org/datasets/add_dataset + +Also note that evaluation on ImageNet requires manual TFDS setup, see +https://github.com/google-research/big_vision#preparing-tfds-data + + +Example training: + +big_vision.trainers.proj.image_text.contrastive \ + --config big_vision/configs/proj/clippo/train_clippo.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'` + +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.common_fewshot import get_fewshot_lsr +from big_vision.configs.proj.image_text import common +from ml_collections import ConfigDict + + +def get_config(arg=None): + """The base configuration.""" + arg = bvcc.parse_arg( + arg, res=224, runlocal=False, variant='B/16', + test_with_coco=True, i1k_eval=False) + config = ConfigDict() + + config.input = {} + if arg.test_with_coco: + # Use COCO Captions for sanity-checking + config.input.data = 
dict(name='coco_captions', split='train') + val_data = dict(config.input.data) + val_data['split'] = 'val' + config.input.batch_size = 4000 if not arg.runlocal else 32 + config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 50 + config.total_steps = 400 if not arg.runlocal else 10 + else: + # Please add your favorite image/alt-text dataset here + config.input.data = None + val_data = None + assert config.input.data is not None and val_data is not None, ( + config.input.data, val_data) + + # The value in the paper is 10 * 1024, which requires 128 TPUv3 cores or a + # memory optimized ViT implementation when running on 128 TPUv2 cores. + config.input.batch_size = 8 * 1024 if not arg.runlocal else 32 + config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50 + config.total_steps = 100_000 if not arg.runlocal else 10 + + def tokenizer(inkey, outkey='labels'): + return (f'render_unifont(' + f'inkey="{inkey}", ' + f'outkey="{outkey}", ' + f'image_size={arg.res}, ' + f'lower=True, ' + f'font_size=16, ' + f'text_brightness=0, ' + f'background_brightness=127)|' + f'value_range(-1, 1, inkey="{outkey}", outkey="{outkey}")') + + pp_image = f'decode|resize({arg.res})|value_range(-1,1)' + if arg.test_with_coco: + # Train with augmentation when sanity-checking + pp_image_aug = ( + f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)') + config.input.pp = pp_eval = ( + f'{pp_image_aug}|flatten|{tokenizer("captions/text")}|' + f'keep("image", "labels")') + else: + config.input.pp = pp_eval = ( + f'{pp_image}|flatten|{tokenizer("text")}|keep("image", "labels")') + + config.pp_modules = [ + 'ops_general', 'ops_image', 'ops_text', 'proj.clippo.pp_ops'] + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = 5000 + + config.loss_use_global_batch = True + + # Define the model + config.model_name = 'proj.clippo.one_tower' + + config.model = ConfigDict() + config.model.image_model = 'vit' + config.model.image = 
ConfigDict({ + 'variant': arg.variant, + 'pool_type': 'map', + 'head_zeroinit': False, + }) + + if arg.test_with_coco: + # Initialize with ImageNet21k pretrained checkpoint for sanity-checking + assert arg.variant == 'B/16', arg.variant + config.model_init = {'image': 'howto-i21k-B/16'} + config.model_load = {} + config.model_load['img_load_kw'] = { + 'dont_load': ['^head/.*', '^MAPHead_0/.*', 'cls']} + + config.model.temperature_init = 10.0 + config.model.out_dim = 768 + + # Define the optimizer + config.optax_name = 'big_vision.scale_by_adafactor' + config.grad_clip_norm = 1.0 + + if arg.test_with_coco: + # Short schedule for sanity-checking + config.lr = 0.0001 + config.wd = 0.0003 + config.schedule = dict(decay_type='rsqrt', + timescale=100, + warmup_steps=100 if not arg.runlocal else 5, + cooldown_steps=100 if not arg.runlocal else 5) + else: + config.lr = 0.001 + config.wd = 0.0001 + config.schedule = dict(decay_type='rsqrt', + timescale=10_000, + warmup_steps=10_000 if not arg.runlocal else 5, + cooldown_steps=10_000 if not arg.runlocal else 5) + + # Eval section (Both few-shot and zero-shot) + eval_common = dict( + type='proj.image_text.contrastive', + use_global_batch=config.loss_use_global_batch, + log_steps=1000 if not arg.runlocal else 5, + ) + config.evals = {} + sub = '[:4]' if arg.runlocal else '' + config.evals.val = { + **eval_common, + 'data': val_data, + 'pp_fn': pp_eval, + } + config.evals.coco = { + **eval_common, + 'data': dict(name='coco_captions', split=f'val{sub}'), + 'pp_fn': ( + f'{pp_image}|flatten|{tokenizer("captions/text")}|' + f'keep("image", "labels")'), + } + + if arg.i1k_eval: + # Requires manual download, see + # https://github.com/google-research/big_vision#preparing-tfds-data + config.evals.imagenet = { + **eval_common, + 'data': dict(name='imagenet2012', split=f'validation{sub}'), + 'pp_fn': ( + f'{pp_image}|clip_i1k_label_names|' + f'{tokenizer("labels")}|keep("image", "labels")'), + } + config.evals.disclf = dict( + 
type='proj.image_text.discriminative_classifier', + pp_txt=tokenizer('texts', 'labels'), + prefix='z/0shot/', + log_steps=5_000 if not arg.runlocal else 5) + + config.evals.retrieval_coco = common.get_coco( + pp_img=f'resize({arg.res})|value_range(-1, 1)', + pp_txt=tokenizer('texts'), + log_steps=5_000 if not arg.runlocal else 5, + ) + + # Few-shot metrics + config.evals.fewshot = get_fewshot_lsr() + config.evals.fewshot.log_steps = 5_000 if not arg.runlocal else 5 + config.evals.fewshot.representation_layer = 'img/pre_logits' + + config.seed = 0 + + return config diff --git a/big_vision/configs/proj/distill/README.md b/big_vision/configs/proj/distill/README.md new file mode 100644 index 0000000000000000000000000000000000000000..44e10c3670d40307b74e367a9e6725ca1f848c61 --- /dev/null +++ b/big_vision/configs/proj/distill/README.md @@ -0,0 +1,43 @@ +# Knowledge distillation: A good teacher is patient and consistent +*by Lucas Beyer, Xiaohua Zhai, Amélie Royer, Larisa Markeeva, Rohan Anil, Alexander Kolesnikov* + +## Introduction +We publish all teacher models, and configurations for the main experiments of +the paper, as well as training logs and student models. + +Please read the main [big_vision README](/README.md) to learn how to run +configs, and remember that each config file contains an example invocation in +the top-level comment. + +## Results + +We provide the following [colab to read and plot the logfiles](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing) +of a few runs that we reproduced on Cloud. + +### ImageNet-1k + +The file [bit_i1k.py](bit_i1k.py) is the configuration which reproduces our +distillation runs on ImageNet-1k reported in Figures 1 and 5(left) and the first +row of Table1. + +We release both student and teacher models: + +| Model | Download link | Resolution | ImageNet top-1 acc. 
(paper) | +| :--- | :---: | :---: | :---: | +| BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_160.npz) | 160 | 80.5 | +| BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_224.npz) | 224 | 82.8 | +| BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz) | 224 | 83.0 | +| BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz) | 384 | 84.3 | + +### Flowers/Pet/Food/Sun + +The files [bigsweep_flowers_pet.py](bigsweep_flowers_pet.py) and +[bigsweep_food_sun.py](bigsweep_food_sun.py) can be used to reproduce the +distillation runs on these datasets, which are shown in Figures 3, 4, 9-12, and Table 4. + +While our open-source release does not currently support doing hyper-parameter +sweeps, we still provide an example of the sweeps at the end of the configs +for reference. + +### Teacher models +Links to all teacher models we used can be found in [common.py](common.py). diff --git a/big_vision/configs/proj/distill/bigsweep_flowers_pet.py b/big_vision/configs/proj/distill/bigsweep_flowers_pet.py new file mode 100644 index 0000000000000000000000000000000000000000..977f7d2e5cf672bae6032b108148d123ecca4265 --- /dev/null +++ b/big_vision/configs/proj/distill/bigsweep_flowers_pet.py @@ -0,0 +1,164 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pylint: disable=line-too-long +r"""Distilling BiT-R152x2 into BiT-R50x1 on Flowers/Pet as in https://arxiv.org/abs/2106.05237 + +While many epochs are required, this is a small dataset, and thus overall it +is still fast and possible to run on the relatively small v3-8TPUs (or GPUs). + +This configuration contains the recommended settings from Fig3/Tab4 of the +paper, which can be selected via the fast/medium/long config argument. +(best settings were selected on a 10% minival) + +For Flowers: +- The `fast` variant takes ~1h10m on a v2-8 TPU. + Example logs at gs://big_vision/distill/bit_flowers_fast_06-18_2008/big_vision_metrics.txt +- The `long` variant takes ~25h on a v3-32 TPU. + Example logs at gs://big_vision/distill/bit_flowers_long_06-19_0524/big_vision_metrics.txt +For Pet: +- The `fast` variant takes ~28min on a v2-8 TPU. + Example logs at gs://big_vision/distill/bit_pet_fast_06-16_2338/big_vision_metrics.txt +- The `long` variant takes ~11h on a v2-8 and ~8h on a v3-32. + Example logs at gs://big_vision/distill/bit_pet_long_06-17_0050/big_vision_metrics.txt + +big_vision.trainers.proj.distill.distill \ + --config big_vision/configs/proj/distill/bigsweep_flowers_pet.py:data=flowers,variant=fast \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ +""" + +import big_vision.configs.common as bvcc +import big_vision.configs.proj.distill.common as cd +import ml_collections as mlc + +NCLS = dict(flowers=102, pet=37) + + +def get_config(arg=None): + """Config for massive hypothesis-test on pet.""" + arg = bvcc.parse_arg(arg, runlocal=False, data='flowers', variant='medium', crop='inception_crop(128)') + config = mlc.ConfigDict() + + config.input = {} + config.input.data = dict( + name=dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data], + split=dict(flowers='train', pet='train[:90%]')[arg.data], + ) + config.input.batch_size = 512 + config.input.cache_raw = True + config.input.shuffle_buffer_size = 50_000 + 
config.prefetch_to_device = 4 + + config.num_classes = NCLS[arg.data] + config.total_epochs = { + 'flowers': {'fast': 10_000, 'medium': 100_000, 'long': 1_000_000}, + 'pet': {'fast': 1000, 'medium': 3000, 'long': 30_000}, + }[arg.data][arg.variant] + + config.log_training_steps = 100 + config.ckpt_steps = 2500 + + # Model section + config.student_name = 'bit_paper' + config.student = dict(depth=50, width=1) + + config.teachers = ['prof_m'] + config.prof_m_name = 'bit_paper' + config.prof_m_init = cd.inits[f'BiT-M R152x2 {arg.data} rc128'] + config.prof_m = dict(depth=152, width=2) + + # Preprocessing pipeline for student & teacher. + pp_common = ( + '|value_range(-1, 1)' + f'|onehot({config.num_classes}, key="label", key_result="labels")' + '|keep("image", "labels")' + ) + config.input.pp = f'decode|{arg.crop}|flip_lr' + pp_common + ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common + + config.mixup = dict(p=1.0) + + # Distillation settings + config.distance = 'kl' + config.distance_kw = dict(t={ + 'flowers': {'fast': 10., 'medium': 1., 'long': 1.}, + 'pet': {'fast': 5., 'medium': 10., 'long': 2.}, + }[arg.data][arg.variant]) + + # Optimizer section + config.grad_clip_norm = 1.0 + config.optax_name = 'scale_by_adam' + config.optax = dict(mu_dtype='bfloat16') + + config.lr = { + 'flowers': {'fast': 0.003, 'medium': 0.001, 'long': 0.0003}, + 'pet': {'fast': 0.01, 'medium': 0.003, 'long': 0.003}, + }[arg.data][arg.variant] + config.wd = { + 'flowers': {'fast': 3e-4, 'medium': 1e-4, 'long': 1e-5}, + 'pet': {'fast': 1e-3, 'medium': 3e-4, 'long': 1e-5}, + }[arg.data][arg.variant] + config.schedule = dict(warmup_steps=1500, decay_type='cosine') + config.optim_name = 'adam_hp' + + # Eval section + minitrain_split = 'train[:512]' if not arg.runlocal else 'train[:16]' + if arg.data == 'flowers': + val_split = 'validation' if not arg.runlocal else 'validation[:16]' + test_split = 'test' if not arg.runlocal else 'test[:16]' + elif arg.data == 'pet': + val_split = 
'train[90%:]' if not arg.runlocal else 'train[:16]' + test_split = 'test' if not arg.runlocal else 'test[:16]' + + def get_eval(split): + return dict( + type='classification', + pred='student_fwd', + data=dict(name=config.input.data.name, split=split), + pp_fn=ppv, + loss_name='softmax_xent', + log_steps=500, + ) + config.evals = {} + config.evals.student_train = get_eval(minitrain_split) + config.evals.student_val = get_eval(val_split) + config.evals.student_test = get_eval(test_split) + + # Teacher is fixed, so rare evals. + teacher = dict(log_steps=100_000, pred='prof_m_fwd') + config.evals.teacher_train = {**config.evals.student_train, **teacher} + config.evals.teacher_val = {**config.evals.student_val, **teacher} + config.evals.teacher_test = {**config.evals.student_test, **teacher} + + # Could in principle also look at agreement on other datasets! + def get_dist(split): + return dict( + type='proj.distill.distance', + pred='student_prof_m_fwd', + data=dict(name=config.input.data.name, split=split), + pp_fn=ppv + '|keep("image")', + log_steps=1000, + distances=({'kind': 'kl'}, {'kind': 'euclidean'}, + {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}), + ) + config.evals.dist_train = get_dist(minitrain_split) + config.evals.dist_val = get_dist(val_split) + config.evals.dist_test = get_dist(test_split) + + # Make a few things much smaller for quick local debugging testruns. + if arg.runlocal: + config.input.shuffle_buffer_size = 10 + config.input.batch_size = 8 + + return config \ No newline at end of file diff --git a/big_vision/configs/proj/distill/bigsweep_food_sun.py b/big_vision/configs/proj/distill/bigsweep_food_sun.py new file mode 100644 index 0000000000000000000000000000000000000000..36362e658cc0b0cec13f3b33d2f31cd7cd45ff37 --- /dev/null +++ b/big_vision/configs/proj/distill/bigsweep_food_sun.py @@ -0,0 +1,213 @@ +# Copyright 2024 Big Vision Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Distilling BiT-R152x2 into BiT-R50x1 on Food101/Sun397 as in https://arxiv.org/abs/2106.05237 + +While many epochs are required, this is a small dataset, and thus overall it +is still fast and possible to run on the relatively small v3-8TPUs (or GPUs). + +This configuration contains the recommended settings from Fig3/Tab4 of the +paper, which can be selected via the fast/medium/long config argument. +(best settings were selected on a 10% minival) + +For Food101: +- The `fast` variant takes ~45min on a v2-8 TPU. + Example logs at gs://big_vision/distill/bit_food_fast_06-19_0547/big_vision_metrics.txt + Example logs at gs://big_vision/distill/bit_sun_fast_06-20_1839/big_vision_metrics.txt +- The `long` variant takes ~14h on a v3-8 TPU. 
+ Example logs at gs://big_vision/distill/bit_food_long_06-19_0614/big_vision_metrics.txt + Example logs at gs://big_vision/distill/bit_sun_long_06-20_1912/big_vision_metrics.txt + +big_vision.trainers.proj.distill.distill \ + --config big_vision/configs/proj/distill/bigsweep_food_sun.py:data=food,variant=fast \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ +""" + +import big_vision.configs.common as bvcc +import big_vision.configs.proj.distill.common as cd +import ml_collections as mlc + +H, L = 160, 128 +NCLS = dict(food=101, sun=397) + + +def get_config(arg=None): + """Config for massive hypothesis-test on food/sun.""" + arg = bvcc.parse_arg(arg, runlocal=False, data='food', variant='medium', crop='inception_crop(128)') + config = mlc.ConfigDict() + + config.input = {} + config.input.data = dict( + name=dict(food='food101', sun='sun397')[arg.data], + split=dict(food='train[:90%]', sun='train')[arg.data], + ) + config.input.batch_size = 512 + config.input.cache_raw = True + config.input.shuffle_buffer_size = 50_000 + config.prefetch_to_device = 4 + + config.num_classes = NCLS[arg.data] + config.total_epochs = {'fast': 100, 'medium': 1000, 'long': 3000}[arg.variant] + + config.log_training_steps = 50 + config.ckpt_steps = 2500 + + # Model section + config.student_name = 'bit_paper' + config.student = dict(depth=50, width=1) + + config.teachers = ['prof_m'] + config.prof_m_name = 'bit_paper' + config.prof_m_init = cd.inits[f'BiT-M R152x2 {arg.data} rc128'] + config.prof_m = dict(depth=152, width=2) + + # Preprocessing pipeline for student & teacher. 
+ pp_common = ( + '|value_range(-1, 1)' + f'|onehot({config.num_classes}, key="label", key_result="labels")' + '|keep("image", "labels")' + ) + config.input.pp = f'decode|{arg.crop}|flip_lr' + pp_common + ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common + + config.mixup = dict(p=1.0) + + # Distillation settings + config.distance = 'kl' + config.distance_kw = dict(t={ + 'food': {'fast': 10., 'medium': 10., 'long': 5.}, + 'sun': {'fast': 10., 'medium': 10., 'long': 10.}, + }[arg.data][arg.variant]) + + # Optimizer section + config.grad_clip_norm = 1.0 + config.optax_name = 'scale_by_adam' + config.optax = dict(mu_dtype='bfloat16') + + config.lr = { + 'food': {'fast': 0.01, 'medium': 0.001, 'long': 0.01}, + 'sun': {'fast': 0.01, 'medium': 0.001, 'long': 0.01}, + }[arg.data][arg.variant] + config.wd = { + 'food': {'fast': 1e-3, 'medium': 3e-4, 'long': 1e-4}, + 'sun': {'fast': 1e-3, 'medium': 1e-4, 'long': 3e-5}, + }[arg.data][arg.variant] + config.schedule = dict(warmup_steps=1500, decay_type='cosine') + config.optim_name = 'adam_hp' + + # Eval section + minitrain_split = 'train[:1024]' if not arg.runlocal else 'train[:16]' + if arg.data == 'food': + val_split = 'train[90%:]' if not arg.runlocal else 'train[:16]' + test_split = 'validation' if not arg.runlocal else 'test[:16]' + elif arg.data == 'sun': + val_split = 'validation' if not arg.runlocal else 'validation[:16]' + test_split = 'test' if not arg.runlocal else 'test[:16]' + + def get_eval(split): + return dict( + type='classification', + pred='student_fwd', + data=dict(name=config.input.data.name, split=split), + pp_fn=ppv, + loss_name='softmax_xent', + log_steps=500, + ) + config.evals = {} + config.evals.student_train = get_eval(minitrain_split) + config.evals.student_val = get_eval(val_split) + config.evals.student_test = get_eval(test_split) + + # Teacher is fixed, so rare evals. 
+ teacher = dict(log_steps=100_000, pred='prof_m_fwd') + config.evals.teacher_train = {**config.evals.student_train, **teacher} + config.evals.teacher_val = {**config.evals.student_val, **teacher} + config.evals.teacher_test = {**config.evals.student_test, **teacher} + + # Could in principle also look at agreement on other datasets! + def get_dist(split): + return dict( + type='proj.distill.distance', + pred='student_prof_m_fwd', + data=dict(name=config.input.data.name, split=split), + pp_fn=ppv + '|keep("image")', + log_steps=1000, + distances=({'kind': 'kl'}, {'kind': 'euclidean'}, + {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}), + ) + config.evals.dist_train = get_dist(minitrain_split) + config.evals.dist_val = get_dist(val_split) + config.evals.dist_test = get_dist(test_split) + + # Make a few things much smaller for quick local debugging testruns. + if arg.runlocal: + config.input.shuffle_buffer_size = 10 + config.input.batch_size = 8 + + return config + + +def get_hyper(hyper): + """Hyper sweep.""" + # TODO: update, similar to flowers_pet sweep. + # By default, not running the MASSIVE sweep, just the recommended setting + # across durations. However, code for sweep is left for reference/convenience. + return hyper.zipit([ + hyper.sweep('config.total_epochs', [100, 1_000]), + hyper.sweep('config.mixup.p', [0.0, 1.0]), + hyper.sweep('config.weight_decay', [1e-3, 1e-5]), + ]) + + # pylint: disable=unreachable + + def fix(**kw): + return hyper.product([hyper.fixed(f'config.{k}', v, length=1) + for k, v in kw.items()]) + + def setting(p, l, m, crop, pp_end=None, **extra): + pp_end = pp_end or ( + f'|value_range(-1, 1, key="image")' + f'|onehot({NCLS}, key="label", key_result="labels")' + f'|keep("image", "labels")' + ) + return hyper.product([ + fix(**{'mixup.p': p}), + fix(l=l, m=m, crop=crop), + fix(pp_train=f'decode|{crop}|flip_lr|randaug({l},{m})' + pp_end), + fix(**extra) + ]) + + # Mixup, Layers and Mag in randaug. 
+ plm = [(0.0, 0, 0), (0.1, 0, 0), (0.5, 0, 0), (1.0, 0, 0)] + return hyper.product([ + hyper.sweep('config.total_epochs', [100, 1000, 3000]), + hyper.sweep('config.lr.base', [0.001, 0.003, 0.01]), + hyper.sweep('config.distance_kw.t', [1.0, 2.0, 5.0, 10.0]), + hyper.sweep('config.weight_decay', [1e-5, 3e-5, 1e-4, 3e-4, 1e-3]), + hyper.chainit( + [setting(p=p, l=l, m=m, + crop=(f'resize({H})' + f'|inception_crop({L}, outkey="student")' + f'|central_crop({L}, outkey="teacher")'), + pp_end=( + f'|value_range(-1, 1, key="student")' + f'|value_range(-1, 1, key="teacher")' + f'|onehot({NCLS}, key="label", key_result="labels")' + f'|keep("student", "teacher", "labels")')) + for p, l, m in plm] + + [setting(p=p, l=l, m=m, crop=f'inception_crop({L})') for + p, l, m in plm], + ) + ]) diff --git a/big_vision/configs/proj/distill/bit_i1k.py b/big_vision/configs/proj/distill/bit_i1k.py new file mode 100644 index 0000000000000000000000000000000000000000..24c2a0388109de1e3de6504422c54f0317cbed8a --- /dev/null +++ b/big_vision/configs/proj/distill/bit_i1k.py @@ -0,0 +1,167 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Distilling BiT-R152x2 into BiT-R50x1 on ILSVRC-2012 as in https://arxiv.org/abs/2106.05237 + +Note that as per paper title, good results require many epochs and thus +a lot of _patience_. For experimentation/exploration, consider +using the smaller datasets. 
+ +300ep take about 15h on a v3-32 TPU, an example log is available at: + Example logs at gs://big_vision/distill/bit_i1k_300ep_06-16/big_vision_metrics.txt + +big_vision.trainers.proj.distill.distill \ + --config big_vision/configs/proj/distill/bit_i1k.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ + --config.total_epochs 1200 +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.common_fewshot import get_fewshot_lsr +import big_vision.configs.proj.distill.common as cd +import ml_collections as mlc + + +def get_config(arg=None): + """Config for distilling on ImageNet.""" + arg = bvcc.parse_arg(arg, runlocal=False) + config = mlc.ConfigDict() + + config.input = {} + config.input.data = dict(name='imagenet2012', split='train[:98%]') + config.input.batch_size = 4096 + config.input.shuffle_buffer_size = 250_000 + + config.num_classes = 1000 + config.total_epochs = 1200 # A good middle-ground + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = 20000 + + # Model section + config.student_name = 'bit_paper' + config.student = dict(depth=50, width=1) + + config.teachers = ['prof_m'] # You could even add multiple. + + # TODO: use public checkpoint name. 
+ config.prof_m_name = 'bit_paper' + config.prof_m_init = cd.inits['BiT-M R152x2 imagenet2012 ic224'] + config.prof_m = dict(depth=152, width=2) + + pp_common = ( + '|value_range(-1, 1)' + '|onehot(1000, key="{lbl}", key_result="labels")' + '|keep("image", "labels")' + ) + config.input.pp = ( + 'decode_jpeg_and_inception_crop(224)|flip_lr' + + pp_common.format(lbl='label') + ) + ppv = 'decode|resize_small(256)|central_crop(224)' + pp_common + + config.mixup = dict(p=1.0) + + # Distillation settings + config.distance = 'kl' + config.distance_kw = dict(t=1.0) + + # Optimizer section + config.grad_clip_norm = 1.0 + config.optax_name = 'scale_by_adam' + config.optax = dict(mu_dtype='bfloat16') + + config.lr = 0.03 + config.wd = 0.0003 + config.schedule = dict(warmup_steps=5000, decay_type='cosine') + + # Eval section + minitrain_split = 'train[:2%]' if not arg.runlocal else 'train[:16]' + minival_split = 'train[99%:]' if not arg.runlocal else 'train[:16]' + val_split = 'validation' if not arg.runlocal else 'validation[:16]' + real_split = 'validation' if not arg.runlocal else 'validation[:16]' + v2_split = 'test' if not arg.runlocal else 'test[:16]' + + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + pred='student_fwd', + data=dict(name=dataset, split=split), + pp_fn=ppv.format(lbl='label'), + loss_name='softmax_xent', + log_steps=1000, + ) + + config.evals = {} + config.evals.student_train = get_eval(minitrain_split) + config.evals.student_minival = get_eval(minival_split) + config.evals.student_val = get_eval(val_split) + config.evals.student_v2 = get_eval(v2_split, dataset='imagenet_v2') + config.evals.student_real = get_eval(real_split, dataset='imagenet2012_real') + config.evals.student_real.pp_fn = ppv.format(lbl='real_label') + + config.evals.student_fewshot = get_fewshot_lsr(runlocal=arg.runlocal) + config.evals.student_fewshot.pred = 'student_fwd' + config.evals.student_fewshot.log_steps = 10_000 + + teacher_eval = dict( 
+ log_steps=100_000, # Teacher is fixed, so rare evals. + pred='prof_m_fwd', + ) + config.evals.teacher_train = {**config.evals.student_train, **teacher_eval} + config.evals.teacher_minival = {**config.evals.student_minival, **teacher_eval} + config.evals.teacher_val = {**config.evals.student_val, **teacher_eval} + config.evals.teacher_v2 = {**config.evals.student_v2, **teacher_eval} + config.evals.teacher_real = {**config.evals.student_real, **teacher_eval} + config.evals.teacher_fewshot = {**config.evals.student_fewshot, **teacher_eval} + config.evals.teacher_fewshot.prefix = 'z_teacher/' + + # Could in principle also look at agreement on other datasets! + def get_dist(split, dataset='imagenet2012'): + return dict( + type='proj.distill.distance', + pred='student_prof_m_fwd', + data=dict(name=dataset, split=split), + pp_fn=ppv.format(lbl='label') + '|keep("image")', + log_steps=1000, + distances=({'kind': 'kl'}, {'kind': 'euclidean'}, + {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}), + ) + config.evals.dist_train = get_dist(minitrain_split) + config.evals.dist_minival = get_dist(minival_split) + config.evals.dist_val = get_dist(val_split) + config.evals.dist_v2 = get_dist(v2_split, dataset='imagenet_v2') + + # NOTE: CKA evaluator does not work with batch padding, so the size of the + # split must be a multiple of the batch size. + def get_cka(split): + return dict( + type='proj.distill.cka', + pred='student_prof_m_fwd', + data=dict(name='imagenet2012', split=split), + pp_fn=ppv.format(lbl='label') + '|keep("image")', + log_steps=1000, + ) + config.evals.cka_train = get_cka('train[:24576]' if not arg.runlocal else 'train[:16]') + config.evals.cka_minival = get_cka('train[-24576:]' if not arg.runlocal else 'train[:16]') + config.evals.cka_val = get_cka('validation[:49152]' if not arg.runlocal else 'validation[:16]') + + # Make a few things much smaller for quick local debugging testruns. 
+ if arg.runlocal: + config.input.shuffle_buffer_size = 10 + config.input.batch_size = 8 + + return config \ No newline at end of file diff --git a/big_vision/configs/proj/distill/common.py b/big_vision/configs/proj/distill/common.py new file mode 100644 index 0000000000000000000000000000000000000000..f4bb89ab540878672424e95856194ff98f3616f5 --- /dev/null +++ b/big_vision/configs/proj/distill/common.py @@ -0,0 +1,27 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Most common teachers for distillation.""" + +# pylint: disable=line-too-long +inits = { # pylint: disable=duplicate-key Internally, we override some paths for convenience. 
+ 'BiT-M R152x2 imagenet2012 ic224': 'gs://bit_models/distill/R152x2_T_224.npz', + 'BiT-M R152x2 imagenet2012 rc384': 'gs://bit_models/distill/R152x2_T_384.npz', + 'BiT-M R152x2 flowers rc128': 'gs://bit_models/distill/R152x2_T_flowers128.npz', + 'BiT-M R152x2 pet rc128': 'gs://bit_models/distill/R152x2_T_pet128.npz', + 'BiT-M R152x2 food rc128': 'gs://bit_models/distill/R152x2_T_food128.npz', + 'BiT-M R152x2 sun rc128': 'gs://bit_models/distill/R152x2_T_sun128.npz', + +} +# pylint: enable=line-too-long diff --git a/big_vision/configs/proj/flexivit/README.md b/big_vision/configs/proj/flexivit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2d41d8a954cb199a5b481e30eb32a1095a8523dc --- /dev/null +++ b/big_vision/configs/proj/flexivit/README.md @@ -0,0 +1,64 @@ +# FlexiViT: One Model for All Patch Sizes +*by Lucas Beyer, Pavel Izmailov, Alexander Kolesnikov, Mathilde Caron, Simon Kornblith, Xiaohua Zhai, Matthias Minderer, Michael Tschannen, Ibrahim Alabdulmohsin, Filip Pavetic* + +## Introduction +We publish all pre-trained FlexiViT models, and configurations for training +those, as well as training logs for one run. + +Please read the main [big_vision README](/README.md) to learn how to run +configs, and remember that each config file contains an example invocation in +the top-level comment. + +## Pre-trained paper models + +Here are the models that we used as backbones in the paper. See Tables in the +appendix of the paper for expected scores at various patch-sizes and on various +datasets. + +First, the recommended models we used for all experiments. 
+Remember that the input is 240px, not 224px: + +| Dataset | Model | Download link | Notes | +| :--- | :---: | :---: | :---: | +| ImageNet-1k | FlexiViT-L | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz) | 1200ep version | +| ImageNet-1k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz) | 1200ep version | +| ImageNet-1k | FlexiViT-S | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz) | 1200ep version | +| ImageNet-21k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz) | 300ep version. 1000ep version below is better but was not used in the paper for fair comparison to baselines. | +| ImageNet-21k | ViT-B/16 | [link](https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz) | Apples-to-apples non-flexi baseline used throughout the paper. | +| ImageNet-21k | ViT-B/30 | [link](https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz) | Apples-to-apples non-flexi baseline used throughout the paper. | + +These models can be used directly in our codebase by specifying +`model_name = "proj.flexi.vit"` and `model_init = "FlexiViT-L i1k"` for example. +See the file `models/proj/flexi/vit.py` for more names. + +*Important detail:* When further re-using these models with a flexible patch +size, it is recommended to keep the patch-embedding parameter buffer at its +original size, and change patch-size on the fly using pi-resize, as opposed to +changing the parameter buffer's size at load-time. +For re-using the models with a fixed patch size, either way is fine. +(The reason is that it is impossible to chain multiple resizes without loss, +eg doing 32->8->32 does not result in the original weights.) 
+ +Second, the list of all released models for completeness: + +| Dataset | Model | Download link | Notes | +| :--- | :---: | :---: | :---: | +| ImageNet-21k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz) | 1000ep version. Should be the best available -B model. | +| ImageNet-21k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_90ep.npz) | 90ep version | +| ImageNet-1k | FlexiViT-L | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz) | 600ep version | +| ImageNet-1k | FlexiViT-L | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz) | 300ep version | +| ImageNet-1k | FlexiViT-L | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_90ep.npz) | 90ep version | +| ImageNet-1k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz) | 600ep version | +| ImageNet-1k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz) | 300ep version | +| ImageNet-1k | FlexiViT-B | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_90ep.npz) | 90ep version | +| ImageNet-1k | FlexiViT-S | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz) | 600ep version | +| ImageNet-1k | FlexiViT-S | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz) | 300ep version | +| ImageNet-1k | FlexiViT-S | [link](https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_90ep.npz) | 90ep version | + +## Results + +We provide full training logs for a run with this public code on Cloud that +reproduces the FlexiViT-S 90ep on i1k results: + - [metrics](https://storage.googleapis.com/big_vision/flexivit/deit3_i1k_s_90ep_12-15_2254/big_vision_metrics.txt) + - [config](https://storage.googleapis.com/big_vision/flexivit/deit3_i1k_s_90ep_12-15_2254/config.json) + - 
or `gs://big_vision/flexivit/deit3_i1k_s_90ep_12-15_2254`. diff --git a/big_vision/configs/proj/flexivit/i1k_deit3_distill.py b/big_vision/configs/proj/flexivit/i1k_deit3_distill.py new file mode 100644 index 0000000000000000000000000000000000000000..fd7bf3c14b7ef48279fe75ff653fb4f25434762a --- /dev/null +++ b/big_vision/configs/proj/flexivit/i1k_deit3_distill.py @@ -0,0 +1,187 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Distillation of ViT models into FlexiViT on ImageNet1k. 
+ +Run training of the -S variant for 90ep: + +big_vision.trainers.proj.flexi.distill \ + --config big_vision/configs/proj/flexivit/i1k_deit3_distill.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ + --config.total_epochs 90 --config.variant S + +Logdir for one reproduction run: + - gs://big_vision/flexivit/deit3_i1k_s_90ep_12-15_2254 + +Timing on Cloud: + - S on v3-32: Walltime:10h16m (4h39m eval) + +Note that we did not optimize the input for Cloud, +with tuned caching and prefetching, we should be able to get: + - S on v3-32: Walltime: ~6h30m (~1h30m eval) + - B on v3-32: Walltime: ~16h00m (~2h30m eval) +""" + +import big_vision.configs.common as bvcc + + +def get_config(arg=None): + """Config for distilling ViT on ImageNet1k.""" + c = bvcc.parse_arg(arg, runlocal=False, res=240) + + c.seed = 0 + c.total_epochs = 90 + c.num_classes = 1000 + c.loss = 'softmax_xent' + + c.input = {} + c.input.data = dict( + name='imagenet2012', + split='train[:99%]', + ) + c.input.batch_size = 1024 if not c.runlocal else 8 + c.input.cache_raw = False # Needs up to 120GB of RAM! + c.input.shuffle_buffer_size = 250_000 if not c.runlocal else 10 + + c.log_training_steps = 50 + c.ckpt_steps = 1000 + + # Model section + c.variant = 'B' + init = bvcc.format_str('deit3_{variant}_384_1k', c) + c.student_name = 'proj.flexi.vit' + c.student_init = init + c.student = dict(variant=c.get_ref('variant'), pool_type='tok', patch_size=(16, 16)) + + c.teachers = ['prof'] # You could even add multiple. 
+ c.prof_name = 'vit' + c.prof_init = init + c.prof = dict(variant=c.get_ref('variant'), pool_type='tok', patch_size=(16, 16)) + + pp_label = '|onehot(1000, key="{lbl}", key_result="labels")|keep("image", "prof", "labels")' + c.input.pp = ( + f'decode|inception_crop|flip_lr' + '|copy("image", "prof")' + f'|resize({c.res})|value_range' + '|resize(384, key="prof")|value_range(key="prof")' + + pp_label.format(lbl='label') + ) + pp_eval_both = ( + 'decode|copy("image", "prof")|' + f'|resize({c.res//7*8})|central_crop({c.res})|value_range' + f'|resize({384//7*8}, key="prof")|central_crop(384, key="prof")|value_range(key="prof")|' + ) + pp_eval_student = ( + f'decode|resize({c.res//7*8})|central_crop({c.res})|value_range(-1, 1)' + ) + pp_eval_prof = ( + f'decode|resize({384//7*8})|central_crop(384)|value_range(outkey="prof")' + ) + + c.mixup = dict(p=1.0, n=2) + + # Distillation settings + c.distance = 'kl' + c.distance_kw = dict(t=1.0) + + # Optimizer section + c.grad_clip_norm = 1.0 + c.optax_name = 'scale_by_adam' + c.optax = dict(mu_dtype='bfloat16') + + c.lr = 1e-4 + c.wd = 1e-5 + c.schedule = dict(warmup_steps=5000, decay_type='cosine') + + # Define the model parameters which are flexible: + c.flexi = dict() + c.flexi.seqhw = dict( + # The settings to sample from. Corresponding patch-sizes at 240px: + # 48, 40, 30, 24, 20, 16, 15, 12, 10, 8 + v=(5, 6, 8, 10, 12, 15, 16, 20, 24, 30), + # The probabilities/weights of them. Default uniform. 
+ p=(1, 1, 1, 1, 1, 1, 1, 1, 1, 1), + ) + + # Eval section + def mksplit(split): + if c.runlocal: + return split.split('[')[0] + '[:16]' + return split + + minitrain_split = mksplit('train[:2%]') + minival_split = mksplit('train[99%:]') + val_split = mksplit('validation') + test_split = mksplit('test') + c.aggressive_cache = False + + def get_eval(s, split, dataset='imagenet2012'): + return dict( + type='classification', + pred=f'student_seqhw={s}', + data=dict(name=dataset, split=split), + pp_fn=pp_eval_student + pp_label.format(lbl='label'), + loss_name='sigmoid_xent', + log_percent=0.05, + cache_final=False, + ) + + c.evals = {} + for s in c.flexi.seqhw.v: + c.evals[f'student_minitrain_{s:02d}'] = get_eval(s, minitrain_split) + c.evals[f'student_minival_{s:02d}'] = get_eval(s, minival_split) + c.evals[f'student_val_{s:02d}'] = get_eval(s, val_split) + c.evals[f'student_v2_{s:02d}'] = get_eval(s, test_split, 'imagenet_v2') + c.evals[f'student_a_{s:02d}'] = get_eval(s, test_split, 'imagenet_a') + c.evals[f'student_r_{s:02d}'] = get_eval(s, test_split, 'imagenet_r') + c.evals[f'student_real_{s:02d}'] = get_eval(s, val_split, 'imagenet2012_real') + c.evals[f'student_real_{s:02d}'].pp_fn = pp_eval_student + pp_label.format(lbl='real_label') + + def get_eval_t(split, dataset='imagenet2012'): + return dict( + type='classification', + pred='prof', + data=dict(name=dataset, split=split), + pp_fn=pp_eval_prof + pp_label.format(lbl='label'), + loss_name='sigmoid_xent', + log_percent=0.5, # Teacher is fixed, so eval just for plots. 
+ cache_final=False, + ) + c.evals.teacher_minitrain = get_eval_t(minitrain_split) + c.evals.teacher_minival = get_eval_t(minival_split) + c.evals.teacher_val = get_eval_t(val_split) + c.evals.teacher_v2 = get_eval_t(test_split, 'imagenet_v2') + c.evals.teacher_a = get_eval_t(test_split, 'imagenet_a') + c.evals.teacher_r = get_eval_t(test_split, 'imagenet_r') + c.evals.teacher_real = get_eval_t(val_split, 'imagenet2012_real') + c.evals.teacher_real.pp_fn = pp_eval_prof + pp_label.format(lbl='real_label') + + # Distance evaluators + def get_dist(split, s): + return dict( + type='proj.distill.distance', + pred=f'student_seqhw={s}_prof', + data=dict(name='imagenet2012', split=split), + pp_fn=pp_eval_both + '|keep("image", "prof")', + log_percent=0.05, + distances=({'kind': 'kl'}, {'kind': 'logsoftmax_euclidean'}, + {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}), + cache_final=False, + ) + for s in c.flexi.seqhw.v: + c.evals[f'dist_minitrain_{s:02d}'] = get_dist(minitrain_split, s) + c.evals[f'dist_val_{s:02d}'] = get_dist(val_split, s) + + return c \ No newline at end of file diff --git a/big_vision/configs/proj/flexivit/i21k_distill.py b/big_vision/configs/proj/flexivit/i21k_distill.py new file mode 100644 index 0000000000000000000000000000000000000000..30eddec5785e9a551fa5caf825820a43dfdf76ae --- /dev/null +++ b/big_vision/configs/proj/flexivit/i21k_distill.py @@ -0,0 +1,216 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pylint: disable=line-too-long +r"""Distill flexible-seqlen ViT on ImageNet-21k from (internal link) B/8. + +This config is for reference, we never ran it on public infrastructure. + +big_vision.trainers.proj.flexi.distill \ + --config big_vision/configs/proj/flexivit/i21k_distill.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ + --config.total_epochs 90 +""" + +import big_vision.configs.common as bvcc + + +def get_config(arg=None): + """Config for training.""" + # 240px is nice because it's divisible by + # [240, 120, 80, 60, 48, 40, 30, 24, 20, 16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1] + c = bvcc.parse_arg(arg, runlocal=False, res=240) + + c.seed = 0 + c.total_epochs = 90 + c.num_classes = 21843 + c.init_head_bias = -10.0 + c.loss = 'sigmoid_xent' + + c.input = dict() + c.input.data = dict( + name='imagenet21k', + split='full[51200:]', + ) + c.input.batch_size = 4096 if not c.runlocal else 8 + c.input.shuffle_buffer_size = 250_000 if not c.runlocal else 25 + + pp_label_i21k = f'|onehot({c.num_classes})|keep("image", "prof", "labels")' + pp_label_i1k = '|onehot(1000, key="{lbl}", key_result="labels")|keep("image", "prof", "labels")' + c.input.pp = ( + f'decode|inception_crop|flip_lr|copy("image", "prof")' + f'|resize({c.res})|value_range(-1, 1)' + f'|resize(224, outkey="prof")|value_range(-1, 1, key="prof")' + + pp_label_i21k + ) + pp_eval_both = ( + 'decode|copy("image", "prof")|' + f'|resize_small({c.res//7*8})|central_crop({c.res})|value_range(-1, 1)' + f'|resize_small(256, key="prof")|central_crop(224, key="prof")|value_range(-1, 1, key="prof")|' + ) + pp_eval_student = ( + f'decode|resize({c.res//7*8})|central_crop({c.res})|value_range(-1, 1)' + ) + pp_eval_prof = ( + 'decode|resize(256)|central_crop(224)|value_range(-1, 1, outkey="prof")' + ) + + # Aggressive pre-fetching because our models here are small, so we not only + # can afford it, but we also need it for the smallest models to not be + # bottle-necked by the input pipeline. 
Play around with it for -L models tho. + c.input.prefetch = 8 + c.prefetch_to_device = 4 + + c.log_training_steps = 50 + c.ckpt_steps = 1000 + + # Model section + init = 'howto-i21k-B/8' + c.student_name = 'proj.flexi.vit' + c.student_init = init + c.student = dict(variant='B', pool_type='tok', patch_size=(8, 8)) + + c.teachers = ['prof'] # You could even add multiple. + c.prof_name = 'vit' + c.prof_init = init + c.prof = dict(variant='B/8', pool_type='tok') + + # Define the model parameters which are flexible: + c.flexi = dict() + c.flexi.seqhw = dict( + # The settings to sample from. Corresponding patch-sizes at 240px: + # 48, 40, 30, 24, 20, 16, 15, 12, 10, 8 + v=(5, 6, 8, 10, 12, 15, 16, 20, 24, 30), + # The probabilities/weights of them. Default uniform. + p=(1, 1, 1, 1, 1, 1, 1, 1, 1, 1), + ) + + # Distillation settings + c.distance = 'kl' + c.distance_kw = dict(t=1.0) + + # Optimizer section + c.optax_name = 'scale_by_adam' + c.optax = dict(mu_dtype='bfloat16') + c.grad_clip_norm = 1.0 + + c.lr = 1e-4 + c.wd = 1e-5 + c.schedule = dict(warmup_steps=5000, decay_type='cosine') + + c.mixup = dict(p=1.0) + + #### + # Preparing for evals + c.evals = {} + def mksplit(split): + if c.runlocal: + return split.split('[')[0] + '[:16]' + return split + + #### + # Student evals + + # Evaluations on i21k itself. + def eval_i21k(s, split): + return dict( + type='classification', + pred=f'student_seqhw={s}', + data={**c.input.data, 'split': mksplit(split)}, + pp_fn=pp_eval_student + pp_label_i21k, + loss_name=c.loss, + log_steps=5000, # Very fast O(seconds) so it's fine to run it often. + ) + + for s in c.flexi.seqhw.v: + c.evals[f'student_test{s:02d}'] = eval_i21k(s, 'full[:25_600]') + c.evals[f'student_val{s:02d}'] = eval_i21k(s, 'full[25_600:51_200]') + c.evals[f'student_minitrain{s:02d}'] = eval_i21k(s, 'full[51_200:76_800]') + + # Evaluations on ImageNet1k variants by label-mapping. 
+ def eval_i1k(s, dataset, split, lblmap): + return dict( + type='classification_with_labelmap', + pred=f'student_seqhw={s}', + data=dict(name=dataset, split=mksplit(split)), + pp_fn=pp_eval_student + pp_label_i1k.format(lbl='label'), + loss_name=c.loss, + log_steps=5000, # Very fast O(seconds) so it's fine to run it often. + label_mapping=lblmap, + ) + for s in c.flexi.seqhw.v: + c.evals[f'student_i1k_val{s:02d}'] = eval_i1k(s, 'imagenet2012', 'validation', 'i1k_i21k') + c.evals[f'student_i1k_v2{s:02d}'] = eval_i1k(s, 'imagenet_v2', 'test', 'i1k_i21k') + c.evals[f'student_i1k_a{s:02d}'] = eval_i1k(s, 'imagenet_a', 'test', 'i1ka_i21k') + c.evals[f'student_i1k_r{s:02d}'] = eval_i1k(s, 'imagenet_r', 'test', 'i1kr_i21k') + c.evals[f'student_i1k_real{s:02d}'] = eval_i1k(s, 'imagenet2012_real', 'validation', 'i1k_i21k') + c.evals[f'student_i1k_real{s:02d}'].pp_fn = pp_eval_student + pp_label_i1k.format(lbl='real_label') + # TODO: add objectnet. + + #### + # Teacher evals + + # Evaluations on i21k itself. + def eval_i21k_t(split): + return dict( + type='classification', + pred='prof', + data={**c.input.data, 'split': mksplit(split)}, + pp_fn=pp_eval_prof + pp_label_i21k, + loss_name=c.loss, + log_steps=5000, # Very fast O(seconds) so it's fine to run it often. + ) + + c.evals.teacher_test = eval_i21k_t('full[:25_600]') + c.evals.teacher_val = eval_i21k_t('full[25_600:51_200]') + c.evals.teacher_minitrain = eval_i21k_t('full[51_200:76_800]') + + # Evaluations on ImageNet1k variants by label-mapping. + def eval_i1k_t(dataset, split, lblmap): + return dict( + type='classification_with_labelmap', + pred='prof', + data=dict(name=dataset, split=mksplit(split)), + pp_fn=pp_eval_prof + pp_label_i1k.format(lbl='label'), + loss_name=c.loss, + log_percent=0.5, # Teacher is fixed, so eval just for plots. 
+ label_mapping=lblmap, + ) + c.evals.teacher_i1k_val = eval_i1k_t('imagenet2012', 'validation', 'i1k_i21k') + c.evals.teacher_i1k_v2 = eval_i1k_t('imagenet_v2', 'test', 'i1k_i21k') + c.evals.teacher_i1k_a = eval_i1k_t('imagenet_a', 'test', 'i1ka_i21k') + c.evals.teacher_i1k_r = eval_i1k_t('imagenet_r', 'test', 'i1kr_i21k') + c.evals.teacher_i1k_real = eval_i1k_t('imagenet2012_real', 'validation', 'i1k_i21k') + c.evals.teacher_i1k_real.pp_fn = pp_eval_prof + pp_label_i1k.format(lbl='real_label') + # TODO: add objectnet. + + #### + # Combined evals + + def get_dist(split, s): + return dict( + type='proj.distill.distance', + pred=f'student_seqhw={s}_prof', + data=dict(name='imagenet2012', split=mksplit(split)), + pp_fn=pp_eval_both + '|keep("image", "prof")', + log_percent=0.05, + distances=({'kind': 'kl'}, {'kind': 'logsoftmax_euclidean'}, + {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}), + ) + for s in c.flexi.seqhw.v: + c.evals[f'dist_minitrain_{s:02d}'] = get_dist('full[51_200:76_800]', s) + c.evals[f'dist_val_{s:02d}'] = get_dist('full[25_600:51_200]', s) + + # Few-shot evaluators not added for overkill reasons for now. + return c diff --git a/big_vision/configs/proj/flexivit/i21k_sup.py b/big_vision/configs/proj/flexivit/i21k_sup.py new file mode 100644 index 0000000000000000000000000000000000000000..cca2443024bc4a4fe9c9474a2e34f9d53d4e52a9 --- /dev/null +++ b/big_vision/configs/proj/flexivit/i21k_sup.py @@ -0,0 +1,144 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Pre-training flexible-seqlen ViT on ImageNet-21k following (internal link). + +This config is for reference, we never ran it on public infrastructure. + +big_vision.trainers.proj.flexi.train \ + --config big_vision/configs/proj/flexivit/i21k_sup.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ + --config.total_epochs 90 +""" + +import big_vision.configs.common as bvcc + + +def get_config(arg=None): + """Config for training.""" + # 240px is nice because it's divisible by + # [240, 120, 80, 60, 48, 40, 30, 24, 20, 16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1] + c = bvcc.parse_arg(arg, runlocal=False, res=240) + + c.seed = 0 + c.total_epochs = 90 + c.num_classes = 21843 + c.init_head_bias = -10.0 + c.loss = 'sigmoid_xent' + + c.input = dict() + c.input.data = dict( + name='imagenet21k', + split='full[51200:]', + ) + c.input.batch_size = 4096 if not c.runlocal else 8 + c.input.shuffle_buffer_size = 250_000 if not c.runlocal else 25 + + pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")' + pp_common_i21k = pp_common.format(onehot_args=f'{c.num_classes}') + pp_common_i1k = pp_common.format(onehot_args='1000, key="{lbl}", key_result="labels"') + c.input.pp = f'decode_jpeg_and_inception_crop({c.res})|flip_lr|randaug(2,10)' + pp_common_i21k + def pp_eval(res=c.res): + return f'decode|resize_small({res//7*8})|central_crop({res})' + + # To continue using the near-defunct randaug op. + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug'] + + # Aggressive pre-fetching because our models here are small, so we not only + # can afford it, but we also need it for the smallest models to not be + # bottle-necked by the input pipeline. Play around with it for -L models tho. 
+ c.input.prefetch = 8 + c.prefetch_to_device = 4 + + c.log_training_steps = 50 + c.ckpt_steps = 1000 + + # Model section + c.model_name = 'proj.flexi.vit' + c.model = dict( + variant='B', + pool_type='tok', + posemb='learn', + # patch_size=(32, 32), + patch_size=(8, 8), + posemb_size=(7, 7), + seqhw=None, # Dynamic! + ) + + # Define the model parameters which are flexible: + c.flexi = dict() + c.flexi.seqhw = dict( + # The settings to sample from. Corresponding patch-sizes at 240px: + # 48, 40, 30, 24, 20, 16, 15, 12, 10, 8 + v=(5, 6, 8, 10, 12, 15, 16, 20, 24, 30), + # The probabilities/weights of them. Default uniform. + p=(1, 1, 1, 1, 1, 1, 1, 1, 1, 1), + ) + + # Optimizer section + c.optax_name = 'scale_by_adam' + c.optax = dict(mu_dtype='bfloat16') + c.grad_clip_norm = 1.0 + + c.lr = 0.001 + c.wd = 0.0001 + c.schedule = dict(warmup_steps=10_000, decay_type='cosine') + + c.mixup = dict(p=0.2, fold_in=None) + + def mksplit(split): + if c.runlocal: + return split.split('[')[0] + '[:16]' + return split + + # Evaluations on i21k itself. + def eval_i21k(s, split): + return dict( + type='classification', + pred=f'predict_seqhw={s}', + data={**c.input.data, 'split': mksplit(split)}, + pp_fn=pp_eval() + pp_common_i21k, + loss_name=c.loss, + log_steps=5000, # Very fast O(seconds) so it's fine to run it often. + ) + + c.evals = {} + for s in c.flexi.seqhw.v: + c.evals[f'test{s:02d}'] = eval_i21k(s, 'full[:25_600]') + c.evals[f'val{s:02d}'] = eval_i21k(s, 'full[25_600:51_200]') + c.evals[f'train{s:02d}'] = eval_i21k(s, 'full[51_200:76_800]') + + # Evaluations on ImageNet1k variants by label-mapping. + def eval_i1k(s, dataset, split, lblmap): + return dict( + type='classification_with_labelmap', + pred=f'predict_seqhw={s}', + data=dict(name=dataset, split=mksplit(split)), + pp_fn=pp_eval() + pp_common_i1k.format(lbl='label'), + loss_name=c.loss, + log_steps=5000, # Very fast O(seconds) so it's fine to run it often. 
+ label_mapping=lblmap, + ) + for s in c.flexi.seqhw.v: + c.evals[f'i1k_val{s:02d}'] = eval_i1k(s, 'imagenet2012', 'validation', 'i1k_i21k') + c.evals[f'i1k_v2{s:02d}'] = eval_i1k(s, 'imagenet_v2', 'test', 'i1k_i21k') + c.evals[f'i1k_a{s:02d}'] = eval_i1k(s, 'imagenet_a', 'test', 'i1ka_i21k') + c.evals[f'i1k_r{s:02d}'] = eval_i1k(s, 'imagenet_r', 'test', 'i1kr_i21k') + c.evals[f'i1k_real{s:02d}'] = eval_i1k(s, 'imagenet2012_real', 'validation', 'i1k_i21k') + c.evals[f'i1k_real{s:02d}'].pp_fn = pp_eval() + pp_common_i1k.format(lbl='real_label') + # TODO: add objectnet. + + # Few-shot evaluators not added for overkill reasons for now. + return c diff --git a/big_vision/configs/proj/flexivit/timing.py b/big_vision/configs/proj/flexivit/timing.py new file mode 100644 index 0000000000000000000000000000000000000000..07514087765f6959309bf473a90932590346613f --- /dev/null +++ b/big_vision/configs/proj/flexivit/timing.py @@ -0,0 +1,53 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long,missing-function-docstring +r"""A config to run timing for FlexiViT (only inference, no I/O etc.). 
+ +big_vision.tools.eval_only \ + --config big_vision/configs/proj/flexivit/timing.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ + --config.total_epochs 90 +""" + +from ml_collections import ConfigDict + + +def get_config(): + c = ConfigDict() + + shape = (240, 240, 3) + c.batch_size = 8 # swept + c.init_shapes = [(1, *shape)] + c.representation_layer = 'pre_logits' + + # Creating complete model using all params, the sweep will go over variants. + c.model_name = 'xp.flexivit.vit' + c.model = dict( + variant='B', + pool_type='tok', + patch_size=(10, 10), # Like deit@384 + seqhw=(24, 24), + ) + c.num_classes = 0 + + c.evals = {} + c.evals.timing = dict( + type='timing', + input_shapes=[shape], + timing=True, + pred_kw=dict(outputs=('pre_logits',)), + ) + + return c \ No newline at end of file diff --git a/big_vision/configs/proj/givt/README.md b/big_vision/configs/proj/givt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a6818a55aeb9846e3c8fe89ebcc44fcf08d51def --- /dev/null +++ b/big_vision/configs/proj/givt/README.md @@ -0,0 +1,111 @@ +# GIVT: Generative Infinite-Vocabulary Transformers + +*by Michael Tschannen, Cian Eastwood, Fabian Mentzer* [[arxiv]](https://arxiv.org/abs/2312.02116) [[colab]](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/givt/givt_demo_colab.ipynb) + +![GIVT overview](givt_overview.png) + + +### Summary + +We introduce generative infinite-vocabulary transformers (GIVT) which generate vector sequences with real-valued entries, instead of discrete tokens from a finite vocabulary. 
+To this end, we propose two surprisingly simple modifications to decoder-only transformers: 1) at the input, we replace the finite-vocabulary lookup table with a linear projection of the input vectors; and 2) at the output, we replace the logits prediction (usually mapped to a categorical distribution) with the parameters of a multivariate Gaussian mixture model. +Inspired by the image-generation paradigm of VQ-GAN and MaskGIT, where transformers are used to model the discrete latent sequences of a VQ-VAE, we use GIVT to model the unquantized real-valued latent sequences of a β-VAE. +In class-conditional image generation GIVT outperforms VQ-GAN (and improved variants thereof) as well as MaskGIT, and achieves performance competitive with recent latent diffusion models. +Finally, we obtain strong results outside of image generation when applying GIVT to panoptic segmentation and depth estimation with a VAE variant of the UViM framework. + +### Checkpoints + +We provide model checkpoints for a subset of the models from the paper. +These are meant as small-scale baselines for researchers interested in exploring GIVT, and are not optimized to provide the best possible visual quality (e.g. scaling the model size can substantially improve visual quality as shown in the paper). +See below for instructions to train your own models. 
+ +**ImageNet 2012 VAEs** + +| β | 1e-5 | 2.5e-5 | 5e-5 | 1e-4 | 2e-4 | +|:-----------|:------:|:----:|:----:|:----:|:----:| +| checkpoint | [link][vae_i1k_0] | [link][vae_i1k_1] | [link][vae_i1k_2] | [link][vae_i1k_3] | [link][vae_i1k_4] | + +[vae_i1k_0]: https://storage.googleapis.com/big_vision/givt/vae_imagenet_2012_beta_1e-5_params +[vae_i1k_1]: https://storage.googleapis.com/big_vision/givt/vae_imagenet_2012_beta_2p5e-5_params +[vae_i1k_2]: https://storage.googleapis.com/big_vision/givt/vae_imagenet_2012_beta_5e-5_params +[vae_i1k_3]: https://storage.googleapis.com/big_vision/givt/vae_imagenet_2012_beta_1e-4_params +[vae_i1k_4]: https://storage.googleapis.com/big_vision/givt/vae_imagenet_2012_beta_2e-4_params + +**Class-conditional ImageNet 2012 generative models** + +| model | resolution | β | inference | FID | checkpoint | +|:------|:----------:|:------:|:-------------|:---:|:-----------| +| GIVT-Causal | 256 x 256 | 5e-5 | t=0.95, DB-CFG=0.4 | 3.35 | [link][givt_i1k_1] | +| GIVT-MaskGIT | 256 x 256 | 5e-5 | t_C=35, DB-CFG=0.1 | 4.53 | [link][givt_i1k_2] | +| GIVT-MaskGIT | 512 x 512 | 5e-5 | t_C=140 | 4.86 | [link][givt_i1k_3] | + +[givt_i1k_1]: https://storage.googleapis.com/big_vision/givt/givt_imagenet_2012_causal_params.npz +[givt_i1k_2]: https://storage.googleapis.com/big_vision/givt/givt_imagenet_2012_maskgit_params.npz +[givt_i1k_3]: https://storage.googleapis.com/big_vision/givt/givt_imagenet_2012_maskgit_512_params.npz + + +**UViM** + +| task | model | dataset | accuracy | checkpoint | +|:-----|:------|:--------|---------:|:-----------| +| Panoptic segmentation | VAE (stage 1) | [COCO (2017)] | 71.0 (PQ) | [link][vae_coco_panoptic] | +| Panoptic segmentation | GIVT (stage 2) | [COCO (2017)] | 40.2 (PQ) | [link][givt_coco_panoptic] | +| Depth estimation | VAE (stage 1) | [NYU Depth v2] | 0.195 (RMSE) | [link][vae_nyu_depth] | +| Depth estimation | GIVT (stage 2) | [NYU Depth v2] | 0.474 (RMSE) | [link][givt_nyu_depth] | + +[NYU Depth v2]: 
https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html +[COCO (2017)]: https://cocodataset.org/#home +[vae_coco_panoptic]: https://storage.googleapis.com/big_vision/givt/vae_coco_panoptic_params.npz +[givt_coco_panoptic]: https://storage.googleapis.com/big_vision/givt/givt_coco_panoptic_params.npz +[vae_nyu_depth]: https://storage.googleapis.com/big_vision/givt/vae_nyu_depth_params.npz +[givt_nyu_depth]: https://storage.googleapis.com/big_vision/givt/givt_nyu_depth_params.npz + +### Training models + +This directory contains configs to train GIVT models as well as VAEs (for the UViM variants). +For training the ImageNet 2012 VAE models we used a modified version of the [MaskGIT code](https://github.com/google-research/maskgit). + +The `big_vision` input pipeline relies on [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets) +which supports some data sets out-of-the-box, whereas others require manual download of the data +(for example ImageNet and COCO (2017), see the `big_vision` [main README](../../../../#cloud-tpu-vm-setup) and the [UViM README](../uvim), respectively, for details). + +After setting up `big_vision` as described in the [main README](../../../../#cloud-tpu-vm-setup), training can be launched locally as follows + +``` +python -m big_vision.trainers.proj.givt.generative \ + --config big_vision/configs/proj/givt/givt_imagenet2012.py \ + --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'` +``` + +Add the suffix `:key1=value1,key2=value2,...` to the config path in the launch +command to modify the config with predefined arguments (see config for details). For example: +`--config big_vision/configs/proj/givt/givt_imagenet2012.py:model_size=large`. +Note that `givt_imagenet2012.py` uses [Imagenette](https://github.com/fastai/imagenette) to ensure that the config is runnable without manual ImageNet download. +This is only meant for testing and will overfit immediately. Please download ImageNet to reproduce the paper results. 
+ +VAE trainings for the GIVT variant of UViM can be launched as + +``` +python -m big_vision.trainers.proj.givt.vae \ + --config big_vision/configs/proj/givt/vae_nyu_depth.py \ + --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'` +``` + +Please refer to the [main README](../../../../#cloud-tpu-vm-setup) +for details on how to launch training on a (multi-host) TPU setup. + + +### Disclaimer + +This is not an official Google Product. + + +### Citation +``` +@article{tschannen2023givt, + title={GIVT: Generative Infinite-Vocabulary Transformers}, + author={Tschannen, Michael and Eastwood, Cian and Mentzer, Fabian}, + journal={arXiv:2312.02116}, + year={2023} +} +``` \ No newline at end of file diff --git a/big_vision/configs/proj/givt/givt_coco_panoptic.py b/big_vision/configs/proj/givt/givt_coco_panoptic.py new file mode 100644 index 0000000000000000000000000000000000000000..e89d43609d7fb317391c1e380cd9921abb2c10a8 --- /dev/null +++ b/big_vision/configs/proj/givt/givt_coco_panoptic.py @@ -0,0 +1,186 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pylint: disable=line-too-long +r"""Train a GIVT encoder-decoder model on COCO panoptic.""" + +import itertools +import ml_collections + +ConfigDict = ml_collections.ConfigDict + +VTT_MODELS = { + 'base': dict(num_layers=12, num_decoder_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768), + 'large': dict(num_layers=24, num_decoder_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024), +} + +RES = 512 +PATCH_SIZE = 16 +LABEL_RES = 512 +LABEL_PATCH_SIZE = 16 + + +def get_config(runlocal=False): + """Config for training.""" + config = ConfigDict() + + config.input = {} + config.input.pp = ( + f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|' + f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|' + f'inception_box|crop_box(key="image")|crop_box(key="labels")|' + f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|' + f'value_range(-1, 1)|make_canonical|' + f'copy("image", "cond_image")|copy("labels", "image")|' + f'keep("image", "cond_image")' + ) + pp_eval = ( + f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|' + f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|' + f'value_range(-1, 1)|make_canonical|' + f'copy("image", "cond_image")|copy("labels", "image")|' + f'keep("image", "cond_image")' + ) + pp_predict = ( + f'decode|resize({RES})|value_range(-1, 1)|copy("image", "cond_image")|' + f'keep("cond_image", "image/id")' # image/id used for rng seeds. 
+ ) + + config.input.data = dict(name='coco/2017_panoptic', split='train[4096:]') + config.input.batch_size = 512 + config.input.shuffle_buffer_size = 50_000 + + config.total_epochs = 200 + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = None + config.prefetch_to_device = 2 + config.seed = 0 + + # Optimizer section + config.optax_name = 'big_vision.scale_by_adafactor' + config.optax = dict(beta2_cap=0.95) + + config.ar_generation_config = ml_collections.ConfigDict() + config.ar_generation_config.temp = 0.85 + config.ar_generation_config.temp_probs = 1.0 + config.ar_generation_config.beam_size = 4 + config.ar_generation_config.fan_size = 8 + config.ar_generation_config.rand_top_k = False + config.ar_generation_config.rand_top_k_temp = 1.0 + + config.lr = 0.001 + config.wd = 0.000001 + config.lr_mults = [ + ('pos_embedding_encoder.*', 0.1), + ('EmbedPatches.*', 0.1), + ('encoder.*', 0.1), + ('decoder.*', 1.0) + ] + config.schedule = dict(decay_type='cosine', warmup_steps=4_000) + + # Oracle section + config.vae = ConfigDict() + config.vae.model_name = 'proj.givt.vit' + config.vae.model = ConfigDict() + config.vae.model.input_size = (RES, RES) + config.vae.model.patch_size = (PATCH_SIZE, PATCH_SIZE) + config.vae.model.code_len = 256 + config.vae.model.width = 768 + config.vae.model.enc_depth = 6 + config.vae.model.dec_depth = 12 + config.vae.model.mlp_dim = 3072 + config.vae.model.num_heads = 12 + config.vae.model.codeword_dim = 16 + config.vae.model.code_dropout = 'none' + config.vae.model.bottleneck_resize = True + # values: (channel index in source image, number of classes) + config.vae.model.inout_specs = { + 'semantics': (0, 133 + 1), # +1 for void label + 'instances': (1, 100), # COCO: actually 98 train/78 validation. 
+ } + config.vae.model_init = 'gs://big_vision/givt/vae_coco_panoptic_params.npz' + + # Model section + config.model_name = 'proj.givt.givt' + # # Base model (for exploration) + # config.model_init = {'encoder': 'howto-i21k-B/16'} + # config.model = ConfigDict(VTT_MODELS['base']) + # Large model + config.model_init = {'encoder': 'howto-i21k-L/16'} + config.model_load = dict(dont_load=('cls', 'head/bias', 'head/kernel')) + config.model = ConfigDict(VTT_MODELS['large']) + config.model.patches = (PATCH_SIZE, PATCH_SIZE) + config.model.input_size = (RES, RES) + config.model.posemb_type = 'learn' + config.model.seq_len = config.vae.model.code_len + config.model.num_labels = None + config.model.num_mixtures = 1 + config.model.fix_square_plus = True + config.model.out_dim = config.vae.model.codeword_dim + config.model.scale_tol = 1e-6 + config.model.dec_dropout_rate = 0.0 + + # Evaluation section + config.evals = {} + config.evals.val = ConfigDict() + config.evals.val.type = 'mean' + config.evals.val.pred = 'validation' + config.evals.val.data = dict(name=config.input.data.name, split='train[:4096]') + config.evals.val.pp_fn = pp_eval + config.evals.val.log_steps = 1000 + + config.eval_only = False + + base = { + 'type': 'proj.givt.coco_panoptic', + 'data': {**config.input.data}, + 'pp_fn': pp_predict, + 'log_steps': 10_000, + 'pred': 'sample_panoptic', + # Filters objects that occupy less than 0.03^2 fraction of all pixels. 
+ # 'pred_kw': {'min_fraction': 0.03 ** 2}, + } + config.evals.coco_panoptic_train = dict(base) + config.evals.coco_panoptic_train.data.split = 'train[4096:8192]' + config.evals.coco_panoptic_holdout = dict(base) + config.evals.coco_panoptic_holdout.data.split = 'train[:4096]' + config.evals.coco_panoptic = dict(base) + config.evals.coco_panoptic.data.split = 'validation' + + config.evals.save_pred = dict(type='proj.givt.save_predictions') + config.evals.save_pred.pred = 'sample_panoptic' + config.evals.save_pred.pp_fn = pp_eval + config.evals.save_pred.log_steps = 100_000 + config.evals.save_pred.data = dict(config.input.data) + config.evals.save_pred.data.split = 'validation[:1024]' + config.evals.save_pred.outfile = 'inference.npz' + + if runlocal: + config.input.batch_size = 4 + config.input.shuffle_buffer_size = 10 + config.evals.val.data.split = 'train[:16]' + config.evals.val.log_steps = 20 + config.model.num_layers = 1 + config.model.num_decoder_layers = 1 + del config.model_init + config.evals.val.data.split = 'validation[:4]' + config.evals.coco_panoptic.data.split = 'validation[:4]' + config.evals.save_pred.data.split = 'validation[:4]' + for k in config.evals.keys(): + if k not in ['val', 'coco_panoptic', 'save_pred']: + del config.evals[k] + + return config diff --git a/big_vision/configs/proj/givt/givt_demo_colab.ipynb b/big_vision/configs/proj/givt/givt_demo_colab.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3e8c93004a87e42c31c02bc6fb3061413b0fd494 --- /dev/null +++ b/big_vision/configs/proj/givt/givt_demo_colab.ipynb @@ -0,0 +1,309 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# GIVT Demo colab\n", + "\n", + 
"[[paper]](https://arxiv.org/abs/2312.02116) [[GitHub]](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/givt/README.md)\n", + "\n", + "This colab implements class-conditional image generation using GIVT-Causal and GIVT-MaskGIT for the 1k ImageNet2012 classes.\n", + "\n", + "The available model checkpoints are meant as small-scale baselines (~300M parameters) for researchers interested in exploring GIVT, and are not optimized to provide the best possible visual quality (e.g. scaling the model size can substantially improve visual quality as shown in the paper).\n", + "\n", + "The colab was tested with the CPU and T4 GPU runtimes. We recommend the T4 GPU runtime (the CPU runtime is very slow).\n", + "\n", + "_Disclaimer: This is not an official Google Product._" + ], + "metadata": { + "id": "botgo-GZiWI_" + } + }, + { + "cell_type": "markdown", + "source": [ + "### `big_vision` setup" + ], + "metadata": { + "id": "jQxc9UZ-mVrQ" + } + }, + { + "cell_type": "code", + "source": [ + "#@markdown Clone and set up repository\n", + "!git clone --branch=main --depth=1 https://github.com/google-research/big_vision\n", + "!cd big_vision && git pull\n", + "\n", + "# Install dependencies - pin TensorFlow-related packages to ensure compatibility\n", + "# which might not be needed in the future\n", + "!echo -e \"keras==3.0.5\\ntensorflow==2.16.1\\ntensorflow-probability==0.24.0\" > big_vision/big_vision/constraints.txt\n", + "!pip install -r big_vision/big_vision/requirements.txt -c big_vision/big_vision/constraints.txt\n", + "%cd big_vision" + ], + "metadata": { + "id": "ZAXiVta3n2jL", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qYS7JNups4MU", + "cellView": "form" + }, + "outputs": [], + "source": [ + "#@markdown Imports\n", + "import jax\n", + "from functools import partial\n", + "import ml_collections\n", + "import matplotlib.pyplot as
plt\n", + "\n", + "from big_vision.configs.proj.givt import givt_imagenet2012\n", + "from big_vision.datasets.imagenet import class_names as imagenet_class_names\n", + "from big_vision.models.proj.givt import givt, cnn, decode, parallel_decode\n", + "\n", + "jnp = jax.numpy" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Select and download model\n", + "\n" + ], + "metadata": { + "id": "MaCM_PIcd2Rb" + } + }, + { + "cell_type": "code", + "source": [ + "model = \"GIVT-Causal 256x256\" #@param [\"GIVT-Causal 256x256\", \"GIVT-MaskGIT 256x256\", \"GIVT-MaskGIT 512x512\"]\n", + "\n", + "givt_ckpt_path, cfg_w, temp, is_ar, res = {\n", + " \"GIVT-Causal 256x256\": (\n", + " \"gs://big_vision/givt/givt_imagenet_2012_causal_params.npz\", 0.4, 0.95, True, 256),\n", + " \"GIVT-MaskGIT 256x256\": (\n", + " \"gs://big_vision/givt/givt_imagenet_2012_maskgit_params.npz\", 0.0, 35.0, False, 256),\n", + " \"GIVT-MaskGIT 512x512\": (\n", + " \"gs://big_vision/givt/givt_imagenet_2012_maskgit_512_params.npz\", 0.0, 140.0, False, 512),\n", + "}[model]\n", + "\n", + "config = givt_imagenet2012.get_config(arg=f\"res={res},style={'ar' if is_ar else 'masked'}\")\n", + "\n", + "print(\"Loading VAE model...\")\n", + "vae_model = cnn.Model(**config.vae.model)\n", + "vae_params = cnn.load(None, config.vae.model_init, **config.vae.model_load)\n", + "\n", + "print(\"Loading GIVT model...\")\n", + "givt_model = givt.Model(**config.model)\n", + "givt_params = jax.device_put(\n", + " givt.load(None, givt_ckpt_path), jax.devices()[0])" + ], + "metadata": { + "id": "7l6QIjdyN3dg", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### VAE encode/decode and sampling loop" + ], + "metadata": { + "id": "SUj5k1bxd6wr" + } + }, + { + "cell_type": "code", + "source": [ + "@jax.jit\n", + "def vae_encode(images, rng):\n", + " \"\"\"Encode image with VAE encoder.\"\"\"\n", + " mu, logvar = vae_model.apply(\n", + " 
{\"params\": vae_params}, images, method=vae_model.encode,\n", + " )\n", + " return vae_model.apply(\n", + " {\"params\": vae_params},\n", + " mu,\n", + " logvar,\n", + " method=vae_model.reparametrize,\n", + " rngs={\"dropout\": rng},\n", + " )\n", + "\n", + "@jax.jit\n", + "def vae_decode(z):\n", + " \"\"\"Reconstruct image with VAE decoder from latent code z.\"\"\"\n", + " return vae_model.apply({\"params\": vae_params}, z, method=vae_model.decode)\n", + "\n", + "### jit-compilation seems to go OOM (RAM) on the free tier GPU colab, but might\n", + "### lead to speedups on machines with more resources\n", + "# @partial(jax.jit, static_argnums=(2, 3))\n", + "def sample(labels, rng, ar_generation_config=None, masked_generation_config=None):\n", + " \"\"\"Sample from GIVT-Causal or GIVT-MaskGIT.\"\"\"\n", + " print(f\"Sampling, style={givt_model.style}\")\n", + " shared_kwargs = dict(\n", + " labels=labels,\n", + " model=givt_model,\n", + " seq_len=config.model.seq_len,\n", + " feature_dim=config.model.out_dim,\n", + " )\n", + "\n", + " match givt_model.style:\n", + " case \"ar\":\n", + " sampled_codes, _ = decode.generate(\n", + " params={\"params\": givt_params},\n", + " seed=rng,\n", + " config=dict(ar_generation_config),\n", + " **shared_kwargs,\n", + " )\n", + " info = sampled_codes\n", + " case \"masked\":\n", + " masked_out = parallel_decode.decode_masked(\n", + " rng=rng,\n", + " variables={\"params\": givt_params},\n", + " config=masked_generation_config,\n", + " **shared_kwargs,\n", + " )\n", + " sampled_codes = masked_out.current_inputs_q\n", + " info = masked_out\n", + " case _:\n", + " raise NotImplementedError\n", + " return sampled_codes, info" + ], + "metadata": { + "id": "vSn7Si2FS1zi" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Generate images for class label" + ], + "metadata": { + "id": "tOnWaJZVeOIX" + } + }, + { + "cell_type": "code", + "source": [ + "rng = 0 #@param = 'int'\n", + 
"label = 'goldfish' #@param [\"tench\", \"goldfish\", \"great white shark\", \"tiger shark\", \"hammerhead shark\", \"electric ray\", \"stingray\", \"rooster\", \"hen\", \"ostrich\", \"brambling\", \"goldfinch\", \"house finch\", \"junco\", \"indigo bunting\", \"American robin\", \"bulbul\", \"jay\", \"magpie\", \"chickadee\", \"American dipper\", \"kite (bird of prey)\", \"bald eagle\", \"vulture\", \"great grey owl\", \"fire salamander\", \"smooth newt\", \"newt\", \"spotted salamander\", \"axolotl\", \"American bullfrog\", \"tree frog\", \"tailed frog\", \"loggerhead sea turtle\", \"leatherback sea turtle\", \"mud turtle\", \"terrapin\", \"box turtle\", \"banded gecko\", \"green iguana\", \"Carolina anole\", \"desert grassland whiptail lizard\", \"agama\", \"frilled-necked lizard\", \"alligator lizard\", \"Gila monster\", \"European green lizard\", \"chameleon\", \"Komodo dragon\", \"Nile crocodile\", \"American alligator\", \"triceratops\", \"worm snake\", \"ring-necked snake\", \"eastern hog-nosed snake\", \"smooth green snake\", \"kingsnake\", \"garter snake\", \"water snake\", \"vine snake\", \"night snake\", \"boa constrictor\", \"African rock python\", \"Indian cobra\", \"green mamba\", \"sea snake\", \"Saharan horned viper\", \"eastern diamondback rattlesnake\", \"sidewinder rattlesnake\", \"trilobite\", \"harvestman\", \"scorpion\", \"yellow garden spider\", \"barn spider\", \"European garden spider\", \"southern black widow\", \"tarantula\", \"wolf spider\", \"tick\", \"centipede\", \"black grouse\", \"ptarmigan\", \"ruffed grouse\", \"prairie grouse\", \"peafowl\", \"quail\", \"partridge\", \"african grey parrot\", \"macaw\", \"sulphur-crested cockatoo\", \"lorikeet\", \"coucal\", \"bee eater\", \"hornbill\", \"hummingbird\", \"jacamar\", \"toucan\", \"duck\", \"red-breasted merganser\", \"goose\", \"black swan\", \"tusker\", \"echidna\", \"platypus\", \"wallaby\", \"koala\", \"wombat\", \"jellyfish\", \"sea anemone\", \"brain coral\", \"flatworm\", 
\"nematode\", \"conch\", \"snail\", \"slug\", \"sea slug\", \"chiton\", \"chambered nautilus\", \"Dungeness crab\", \"rock crab\", \"fiddler crab\", \"red king crab\", \"American lobster\", \"spiny lobster\", \"crayfish\", \"hermit crab\", \"isopod\", \"white stork\", \"black stork\", \"spoonbill\", \"flamingo\", \"little blue heron\", \"great egret\", \"bittern bird\", \"crane bird\", \"limpkin\", \"common gallinule\", \"American coot\", \"bustard\", \"ruddy turnstone\", \"dunlin\", \"common redshank\", \"dowitcher\", \"oystercatcher\", \"pelican\", \"king penguin\", \"albatross\", \"grey whale\", \"killer whale\", \"dugong\", \"sea lion\", \"Chihuahua\", \"Japanese Chin\", \"Maltese\", \"Pekingese\", \"Shih Tzu\", \"King Charles Spaniel\", \"Papillon\", \"toy terrier\", \"Rhodesian Ridgeback\", \"Afghan Hound\", \"Basset Hound\", \"Beagle\", \"Bloodhound\", \"Bluetick Coonhound\", \"Black and Tan Coonhound\", \"Treeing Walker Coonhound\", \"English foxhound\", \"Redbone Coonhound\", \"borzoi\", \"Irish Wolfhound\", \"Italian Greyhound\", \"Whippet\", \"Ibizan Hound\", \"Norwegian Elkhound\", \"Otterhound\", \"Saluki\", \"Scottish Deerhound\", \"Weimaraner\", \"Staffordshire Bull Terrier\", \"American Staffordshire Terrier\", \"Bedlington Terrier\", \"Border Terrier\", \"Kerry Blue Terrier\", \"Irish Terrier\", \"Norfolk Terrier\", \"Norwich Terrier\", \"Yorkshire Terrier\", \"Wire Fox Terrier\", \"Lakeland Terrier\", \"Sealyham Terrier\", \"Airedale Terrier\", \"Cairn Terrier\", \"Australian Terrier\", \"Dandie Dinmont Terrier\", \"Boston Terrier\", \"Miniature Schnauzer\", \"Giant Schnauzer\", \"Standard Schnauzer\", \"Scottish Terrier\", \"Tibetan Terrier\", \"Australian Silky Terrier\", \"Soft-coated Wheaten Terrier\", \"West Highland White Terrier\", \"Lhasa Apso\", \"Flat-Coated Retriever\", \"Curly-coated Retriever\", \"Golden Retriever\", \"Labrador Retriever\", \"Chesapeake Bay Retriever\", \"German Shorthaired Pointer\", \"Vizsla\", \"English Setter\", 
\"Irish Setter\", \"Gordon Setter\", \"Brittany dog\", \"Clumber Spaniel\", \"English Springer Spaniel\", \"Welsh Springer Spaniel\", \"Cocker Spaniel\", \"Sussex Spaniel\", \"Irish Water Spaniel\", \"Kuvasz\", \"Schipperke\", \"Groenendael dog\", \"Malinois\", \"Briard\", \"Australian Kelpie\", \"Komondor\", \"Old English Sheepdog\", \"Shetland Sheepdog\", \"collie\", \"Border Collie\", \"Bouvier des Flandres dog\", \"Rottweiler\", \"German Shepherd Dog\", \"Dobermann\", \"Miniature Pinscher\", \"Greater Swiss Mountain Dog\", \"Bernese Mountain Dog\", \"Appenzeller Sennenhund\", \"Entlebucher Sennenhund\", \"Boxer\", \"Bullmastiff\", \"Tibetan Mastiff\", \"French Bulldog\", \"Great Dane\", \"St. Bernard\", \"husky\", \"Alaskan Malamute\", \"Siberian Husky\", \"Dalmatian\", \"Affenpinscher\", \"Basenji\", \"pug\", \"Leonberger\", \"Newfoundland dog\", \"Great Pyrenees dog\", \"Samoyed\", \"Pomeranian\", \"Chow Chow\", \"Keeshond\", \"brussels griffon\", \"Pembroke Welsh Corgi\", \"Cardigan Welsh Corgi\", \"Toy Poodle\", \"Miniature Poodle\", \"Standard Poodle\", \"Mexican hairless dog (xoloitzcuintli)\", \"grey wolf\", \"Alaskan tundra wolf\", \"red wolf or maned wolf\", \"coyote\", \"dingo\", \"dhole\", \"African wild dog\", \"hyena\", \"red fox\", \"kit fox\", \"Arctic fox\", \"grey fox\", \"tabby cat\", \"tiger cat\", \"Persian cat\", \"Siamese cat\", \"Egyptian Mau\", \"cougar\", \"lynx\", \"leopard\", \"snow leopard\", \"jaguar\", \"lion\", \"tiger\", \"cheetah\", \"brown bear\", \"American black bear\", \"polar bear\", \"sloth bear\", \"mongoose\", \"meerkat\", \"tiger beetle\", \"ladybug\", \"ground beetle\", \"longhorn beetle\", \"leaf beetle\", \"dung beetle\", \"rhinoceros beetle\", \"weevil\", \"fly\", \"bee\", \"ant\", \"grasshopper\", \"cricket insect\", \"stick insect\", \"cockroach\", \"praying mantis\", \"cicada\", \"leafhopper\", \"lacewing\", \"dragonfly\", \"damselfly\", \"red admiral butterfly\", \"ringlet butterfly\", \"monarch butterfly\", 
\"small white butterfly\", \"sulphur butterfly\", \"gossamer-winged butterfly\", \"starfish\", \"sea urchin\", \"sea cucumber\", \"cottontail rabbit\", \"hare\", \"Angora rabbit\", \"hamster\", \"porcupine\", \"fox squirrel\", \"marmot\", \"beaver\", \"guinea pig\", \"common sorrel horse\", \"zebra\", \"pig\", \"wild boar\", \"warthog\", \"hippopotamus\", \"ox\", \"water buffalo\", \"bison\", \"ram (adult male sheep)\", \"bighorn sheep\", \"Alpine ibex\", \"hartebeest\", \"impala (antelope)\", \"gazelle\", \"arabian camel\", \"llama\", \"weasel\", \"mink\", \"European polecat\", \"black-footed ferret\", \"otter\", \"skunk\", \"badger\", \"armadillo\", \"three-toed sloth\", \"orangutan\", \"gorilla\", \"chimpanzee\", \"gibbon\", \"siamang\", \"guenon\", \"patas monkey\", \"baboon\", \"macaque\", \"langur\", \"black-and-white colobus\", \"proboscis monkey\", \"marmoset\", \"white-headed capuchin\", \"howler monkey\", \"titi monkey\", \"Geoffroy's spider monkey\", \"common squirrel monkey\", \"ring-tailed lemur\", \"indri\", \"Asian elephant\", \"African bush elephant\", \"red panda\", \"giant panda\", \"snoek fish\", \"eel\", \"silver salmon\", \"rock beauty fish\", \"clownfish\", \"sturgeon\", \"gar fish\", \"lionfish\", \"pufferfish\", \"abacus\", \"abaya\", \"academic gown\", \"accordion\", \"acoustic guitar\", \"aircraft carrier\", \"airliner\", \"airship\", \"altar\", \"ambulance\", \"amphibious vehicle\", \"analog clock\", \"apiary\", \"apron\", \"trash can\", \"assault rifle\", \"backpack\", \"bakery\", \"balance beam\", \"balloon\", \"ballpoint pen\", \"Band-Aid\", \"banjo\", \"baluster / handrail\", \"barbell\", \"barber chair\", \"barbershop\", \"barn\", \"barometer\", \"barrel\", \"wheelbarrow\", \"baseball\", \"basketball\", \"bassinet\", \"bassoon\", \"swimming cap\", \"bath towel\", \"bathtub\", \"station wagon\", \"lighthouse\", \"beaker\", \"military hat (bearskin or shako)\", \"beer bottle\", \"beer glass\", \"bell tower\", \"baby bib\", \"tandem 
bicycle\", \"bikini\", \"ring binder\", \"binoculars\", \"birdhouse\", \"boathouse\", \"bobsleigh\", \"bolo tie\", \"poke bonnet\", \"bookcase\", \"bookstore\", \"bottle cap\", \"hunting bow\", \"bow tie\", \"brass memorial plaque\", \"bra\", \"breakwater\", \"breastplate\", \"broom\", \"bucket\", \"buckle\", \"bulletproof vest\", \"high-speed train\", \"butcher shop\", \"taxicab\", \"cauldron\", \"candle\", \"cannon\", \"canoe\", \"can opener\", \"cardigan\", \"car mirror\", \"carousel\", \"tool kit\", \"cardboard box / carton\", \"car wheel\", \"automated teller machine\", \"cassette\", \"cassette player\", \"castle\", \"catamaran\", \"CD player\", \"cello\", \"mobile phone\", \"chain\", \"chain-link fence\", \"chain mail\", \"chainsaw\", \"storage chest\", \"chiffonier\", \"bell or wind chime\", \"china cabinet\", \"Christmas stocking\", \"church\", \"movie theater\", \"cleaver\", \"cliff dwelling\", \"cloak\", \"clogs\", \"cocktail shaker\", \"coffee mug\", \"coffeemaker\", \"spiral or coil\", \"combination lock\", \"computer keyboard\", \"candy store\", \"container ship\", \"convertible\", \"corkscrew\", \"cornet\", \"cowboy boot\", \"cowboy hat\", \"cradle\", \"construction crane\", \"crash helmet\", \"crate\", \"infant bed\", \"Crock Pot\", \"croquet ball\", \"crutch\", \"cuirass\", \"dam\", \"desk\", \"desktop computer\", \"rotary dial telephone\", \"diaper\", \"digital clock\", \"digital watch\", \"dining table\", \"dishcloth\", \"dishwasher\", \"disc brake\", \"dock\", \"dog sled\", \"dome\", \"doormat\", \"drilling rig\", \"drum\", \"drumstick\", \"dumbbell\", \"Dutch oven\", \"electric fan\", \"electric guitar\", \"electric locomotive\", \"entertainment center\", \"envelope\", \"espresso machine\", \"face powder\", \"feather boa\", \"filing cabinet\", \"fireboat\", \"fire truck\", \"fire screen\", \"flagpole\", \"flute\", \"folding chair\", \"football helmet\", \"forklift\", \"fountain\", \"fountain pen\", \"four-poster bed\", \"freight car\", \"French 
horn\", \"frying pan\", \"fur coat\", \"garbage truck\", \"gas mask or respirator\", \"gas pump\", \"goblet\", \"go-kart\", \"golf ball\", \"golf cart\", \"gondola\", \"gong\", \"gown\", \"grand piano\", \"greenhouse\", \"radiator grille\", \"grocery store\", \"guillotine\", \"hair clip\", \"hair spray\", \"half-track\", \"hammer\", \"hamper\", \"hair dryer\", \"hand-held computer\", \"handkerchief\", \"hard disk drive\", \"harmonica\", \"harp\", \"combine harvester\", \"hatchet\", \"holster\", \"home theater\", \"honeycomb\", \"hook\", \"hoop skirt\", \"gymnastic horizontal bar\", \"horse-drawn vehicle\", \"hourglass\", \"iPod\", \"clothes iron\", \"carved pumpkin\", \"jeans\", \"jeep\", \"T-shirt\", \"jigsaw puzzle\", \"rickshaw\", \"joystick\", \"kimono\", \"knee pad\", \"knot\", \"lab coat\", \"ladle\", \"lampshade\", \"laptop computer\", \"lawn mower\", \"lens cap\", \"letter opener\", \"library\", \"lifeboat\", \"lighter\", \"limousine\", \"ocean liner\", \"lipstick\", \"slip-on shoe\", \"lotion\", \"music speaker\", \"loupe magnifying glass\", \"sawmill\", \"magnetic compass\", \"messenger bag\", \"mailbox\", \"tights\", \"one-piece bathing suit\", \"manhole cover\", \"maraca\", \"marimba\", \"mask\", \"matchstick\", \"maypole\", \"maze\", \"measuring cup\", \"medicine cabinet\", \"megalith\", \"microphone\", \"microwave oven\", \"military uniform\", \"milk can\", \"minibus\", \"miniskirt\", \"minivan\", \"missile\", \"mitten\", \"mixing bowl\", \"mobile home\", \"ford model t\", \"modem\", \"monastery\", \"monitor\", \"moped\", \"mortar and pestle\", \"graduation cap\", \"mosque\", \"mosquito net\", \"vespa\", \"mountain bike\", \"tent\", \"computer mouse\", \"mousetrap\", \"moving van\", \"muzzle\", \"metal nail\", \"neck brace\", \"necklace\", \"baby pacifier\", \"notebook computer\", \"obelisk\", \"oboe\", \"ocarina\", \"odometer\", \"oil filter\", \"pipe organ\", \"oscilloscope\", \"overskirt\", \"bullock cart\", \"oxygen mask\", \"product packet / 
packaging\", \"paddle\", \"paddle wheel\", \"padlock\", \"paintbrush\", \"pajamas\", \"palace\", \"pan flute\", \"paper towel\", \"parachute\", \"parallel bars\", \"park bench\", \"parking meter\", \"railroad car\", \"patio\", \"payphone\", \"pedestal\", \"pencil case\", \"pencil sharpener\", \"perfume\", \"Petri dish\", \"photocopier\", \"plectrum\", \"Pickelhaube\", \"picket fence\", \"pickup truck\", \"pier\", \"piggy bank\", \"pill bottle\", \"pillow\", \"ping-pong ball\", \"pinwheel\", \"pirate ship\", \"drink pitcher\", \"block plane\", \"planetarium\", \"plastic bag\", \"plate rack\", \"farm plow\", \"plunger\", \"Polaroid camera\", \"pole\", \"police van\", \"poncho\", \"pool table\", \"soda bottle\", \"plant pot\", \"potter's wheel\", \"power drill\", \"prayer rug\", \"printer\", \"prison\", \"missile\", \"projector\", \"hockey puck\", \"punching bag\", \"purse\", \"quill\", \"quilt\", \"race car\", \"racket\", \"radiator\", \"radio\", \"radio telescope\", \"rain barrel\", \"recreational vehicle\", \"fishing casting reel\", \"reflex camera\", \"refrigerator\", \"remote control\", \"restaurant\", \"revolver\", \"rifle\", \"rocking chair\", \"rotisserie\", \"eraser\", \"rugby ball\", \"ruler measuring stick\", \"sneaker\", \"safe\", \"safety pin\", \"salt shaker\", \"sandal\", \"sarong\", \"saxophone\", \"scabbard\", \"weighing scale\", \"school bus\", \"schooner\", \"scoreboard\", \"CRT monitor\", \"screw\", \"screwdriver\", \"seat belt\", \"sewing machine\", \"shield\", \"shoe store\", \"shoji screen / room divider\", \"shopping basket\", \"shopping cart\", \"shovel\", \"shower cap\", \"shower curtain\", \"ski\", \"balaclava ski mask\", \"sleeping bag\", \"slide rule\", \"sliding door\", \"slot machine\", \"snorkel\", \"snowmobile\", \"snowplow\", \"soap dispenser\", \"soccer ball\", \"sock\", \"solar thermal collector\", \"sombrero\", \"soup bowl\", \"keyboard space bar\", \"space heater\", \"space shuttle\", \"spatula\", \"motorboat\", \"spider web\", 
\"spindle\", \"sports car\", \"spotlight\", \"stage\", \"steam locomotive\", \"through arch bridge\", \"steel drum\", \"stethoscope\", \"scarf\", \"stone wall\", \"stopwatch\", \"stove\", \"strainer\", \"tram\", \"stretcher\", \"couch\", \"stupa\", \"submarine\", \"suit\", \"sundial\", \"sunglasses\", \"sunglasses\", \"sunscreen\", \"suspension bridge\", \"mop\", \"sweatshirt\", \"swim trunks / shorts\", \"swing\", \"electrical switch\", \"syringe\", \"table lamp\", \"tank\", \"tape player\", \"teapot\", \"teddy bear\", \"television\", \"tennis ball\", \"thatched roof\", \"front curtain\", \"thimble\", \"threshing machine\", \"throne\", \"tile roof\", \"toaster\", \"tobacco shop\", \"toilet seat\", \"torch\", \"totem pole\", \"tow truck\", \"toy store\", \"tractor\", \"semi-trailer truck\", \"tray\", \"trench coat\", \"tricycle\", \"trimaran\", \"tripod\", \"triumphal arch\", \"trolleybus\", \"trombone\", \"hot tub\", \"turnstile\", \"typewriter keyboard\", \"umbrella\", \"unicycle\", \"upright piano\", \"vacuum cleaner\", \"vase\", \"vaulted or arched ceiling\", \"velvet fabric\", \"vending machine\", \"vestment\", \"viaduct\", \"violin\", \"volleyball\", \"waffle iron\", \"wall clock\", \"wallet\", \"wardrobe\", \"military aircraft\", \"sink\", \"washing machine\", \"water bottle\", \"water jug\", \"water tower\", \"whiskey jug\", \"whistle\", \"hair wig\", \"window screen\", \"window shade\", \"Windsor tie\", \"wine bottle\", \"airplane wing\", \"wok\", \"wooden spoon\", \"wool\", \"split-rail fence\", \"shipwreck\", \"sailboat\", \"yurt\", \"website\", \"comic book\", \"crossword\", \"traffic or street sign\", \"traffic light\", \"dust jacket\", \"menu\", \"plate\", \"guacamole\", \"consomme\", \"hot pot\", \"trifle\", \"ice cream\", \"popsicle\", \"baguette\", \"bagel\", \"pretzel\", \"cheeseburger\", \"hot dog\", \"mashed potatoes\", \"cabbage\", \"broccoli\", \"cauliflower\", \"zucchini\", \"spaghetti squash\", \"acorn squash\", \"butternut squash\", 
\"cucumber\", \"artichoke\", \"bell pepper\", \"cardoon\", \"mushroom\", \"Granny Smith apple\", \"strawberry\", \"orange\", \"lemon\", \"fig\", \"pineapple\", \"banana\", \"jackfruit\", \"cherimoya (custard apple)\", \"pomegranate\", \"hay\", \"carbonara\", \"chocolate syrup\", \"dough\", \"meatloaf\", \"pizza\", \"pot pie\", \"burrito\", \"red wine\", \"espresso\", \"tea cup\", \"eggnog\", \"mountain\", \"bubble\", \"cliff\", \"coral reef\", \"geyser\", \"lakeshore\", \"promontory\", \"sandbar\", \"beach\", \"valley\", \"volcano\", \"baseball player\", \"bridegroom\", \"scuba diver\", \"rapeseed\", \"daisy\", \"yellow lady's slipper\", \"corn\", \"acorn\", \"rose hip\", \"horse chestnut seed\", \"coral fungus\", \"agaric\", \"gyromitra\", \"stinkhorn mushroom\", \"earth star fungus\", \"hen of the woods mushroom\", \"bolete\", \"corn cob\", \"toilet paper\"]\n", + "label_int = dict(\n", + " zip(imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,\n", + " range(len(imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES))))[label]" + ], + "metadata": { + "cellView": "form", + "id": "_CiyXD_6nQbu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%%capture --no-display\n", + "batch_size = 8\n", + "\n", + "target_labels = jnp.full((batch_size,), label_int, jnp.int32)\n", + "\n", + "if is_ar:\n", + " ar_generation_config = dict(cfg_inference_weight=cfg_w, temp=temp)\n", + " masked_generation_config = None\n", + "else:\n", + " ar_generation_config = {}\n", + " masked_generation_config = parallel_decode.MaskedGenerationConfig(\n", + " cfg_inference_weight=cfg_w,\n", + " choice_temperature = temp,\n", + " num_steps = 16,\n", + " ordering = \"maskgit\",\n", + " schedule = \"cosine\",\n", + " )\n", + "\n", + "# Sample from GIVT and decode\n", + "sampled_codes, _ = sample(\n", + " target_labels, jax.random.PRNGKey(rng),\n", + " tuple(ar_generation_config.items()), masked_generation_config)\n", + "\n", + "generated_images = 
vae_decode(sampled_codes)" + ], + "metadata": { + "id": "sCcGB0m1oQY1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@markdown Visualize images\n", + "ncols = 4\n", + "nrows = generated_images.shape[0] // ncols\n", + "fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows))\n", + "\n", + "for idx, (ax, img) in enumerate(zip(axes.flat, generated_images)):\n", + " ax.imshow(img * .5 + .5)\n", + " if idx == 0:\n", + " ax.set_title(f'Label: {label} ({label_int})', fontsize=10, ha='left', loc='left')\n", + " ax.set_axis_off()" + ], + "metadata": { + "id": "4FWgfAghuh8P", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@markdown Visualize latent codes\n", + "nrows = sampled_codes.shape[0]\n", + "ncols = sampled_codes.shape[-1] + 1\n", + "fig, axes = plt.subplots(nrows, ncols, figsize=(ncols, nrows))\n", + "\n", + "for r, (row_ax, code) in enumerate(zip(axes, sampled_codes)):\n", + " code_norm = (code - code.min()) / (code.max() - code.min())\n", + " for c, ax in enumerate(row_ax):\n", + " if c == 0:\n", + " cc = generated_images[r] * .5 + .5\n", + " else:\n", + " cc = code_norm[..., c - 1].reshape(res // 16, res // 16)\n", + " ax.imshow(cc)\n", + " ax.set_axis_off()" + ], + "metadata": { + "id": "zGPPeXONy0Am", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/big_vision/configs/proj/givt/givt_imagenet2012.py b/big_vision/configs/proj/givt/givt_imagenet2012.py new file mode 100644 index 0000000000000000000000000000000000000000..84b0fd352487edbc66475ebde5883b087cf0f6ba --- /dev/null +++ b/big_vision/configs/proj/givt/givt_imagenet2012.py @@ -0,0 +1,222 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train Generative Infinite Vocabulary Transformer (GIVT) on ImageNet. + +Example launch command (local; see main README for launching on TPU servers): + + python -m big_vision.trainers.proj.givt.generative \ + --config big_vision/configs/proj/givt/givt_imagenet2012.py \ + --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'` + +Add the suffix `:key1=value1,key2=value2,...` to the config path in the launch +command to modify the config with the arguments below. For example: +`--config big_vision/configs/proj/givt/givt_imagenet2012.py:model_size=large` +""" + +import big_vision.configs.common as bvcc +import ml_collections + + +RES = 256 +PATCH_SIZE = 16 + +GIVT_MODELS = { + 'base': dict(num_decoder_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768, dec_dropout_rate=0.1), + 'default': dict(num_decoder_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024, dec_dropout_rate=0.2), + 'large': dict(num_decoder_layers=48, num_heads=16, mlp_dim=8192, emb_dim=1536, dec_dropout_rate=0.3) +} + + +def get_config(arg=None): + """Returns the GIVT training config for ImageNet-2012 (Imagenette by default); `arg` is an optional `key1=value1,...` override string.""" + arg = bvcc.parse_arg(arg, res=RES, patch_size=PATCH_SIZE, style='ar', # 'ar' or 'masked' + model_size='default', runlocal=False, singlehost=False, + adaptor=False) + config = ml_collections.ConfigDict() + + config.input = {} + ### Using Imagenette here to ensure this config is runnable without manual + ### download of ImageNet. This is only meant for testing and will overfit + ### immediately.
Please download ImageNet to reproduce the paper results. + # config.input.data = dict(name='imagenet2012', split='train[4096:]') + config.input.data = dict(name='imagenette', split='train') + + config.input.batch_size = 8 * 1024 if not arg.runlocal else 8 + config.input.shuffle_buffer_size = 25_000 if not arg.runlocal else 10 + + config.total_epochs = 500 + + config.input.pp = ( + f'decode_jpeg_and_inception_crop({arg.res},' + f'area_min=80, area_max=100, ratio_min=1.0, ratio_max=1.0,' + f'method="bicubic", antialias=True)' + f'|flip_lr' + f'|value_range(-1, 1, key="image")' + f'|copy("label", "labels")' + f'|keep("image", "labels")') + + pp_eval = ( + f'decode' + f'|resize_small({arg.res}, inkey="image", outkey="image",' + f'method="bicubic", antialias=True)' + f'|central_crop({arg.res})' + f'|value_range(-1, 1, key="image")' + f'|copy("label", "labels")' + f'|keep("image", "labels")') + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = None + + # Flags for AR model. + config.ar_generation_config = ml_collections.ConfigDict() + config.ar_generation_config.temp = 0.95 + config.ar_generation_config.temp_probs = 1.0 + config.ar_generation_config.beam_size = 1 + config.ar_generation_config.fan_size = 1 + config.ar_generation_config.rand_top_k = False + config.ar_generation_config.rand_top_k_temp = 1.0 + config.ar_generation_config.cfg_inference_weight = 0.4 + + # Flags for masked model. + config.masked_generation_config = ml_collections.ConfigDict() + config.masked_generation_config.choice_temperature = 35.0 + config.masked_generation_config.ordering = 'maskgit' + config.masked_generation_config.cfg_inference_weight = 0.0 + config.masked_generation_config.schedule = 'cosine' + + # Used for eval sweep.
+ config.eval_only = False + + # VAE section + config.vae = {} + config.vae.model = ml_collections.ConfigDict() + config.vae.model.code_len = (arg.res // arg.patch_size) ** 2 + config.vae.model_name = 'proj.givt.cnn' + config.vae.model.codeword_dim = 16 + config.vae.model.filters = 128 + config.vae.model.num_res_blocks = 2 + config.vae.model.channel_multipliers = (1, 1, 2, 2, 4) + config.vae.model.conv_downsample = False + config.vae.model.activation_fn = 'swish' + config.vae.model.norm_type = 'GN' + if arg.model_size == 'large': + config.vae.model_init = 'gs://big_vision/givt/vae_imagenet_2012_beta_1e-5_params' + else: + config.vae.model_init = 'gs://big_vision/givt/vae_imagenet_2012_beta_5e-5_params' + config.vae.model.malib_ckpt = True + config.vae.model_load = {} + config.vae.model_load.malib_ckpt = config.vae.model.malib_ckpt + config.vae.model_load.use_ema_params = True + + # GIVT section + config.model_name = 'proj.givt.givt' + config.model_init = '' + assert arg.model_size in GIVT_MODELS, f'Unknown model size: {arg.model_size}' + config.model = ml_collections.ConfigDict(GIVT_MODELS[arg.model_size]) + config.model.num_layers = 0 + config.model.num_labels = 1000 # None + config.model.seq_len = config.vae.model.code_len + config.model.out_dim = config.vae.model.codeword_dim + config.model.num_mixtures = 16 + config.model.posemb_type = 'learn' + config.model.scale_tol = 1e-6 + config.model.style = arg.style + config.model.min_masking_rate_training = 0.3 + config.model.mask_style = 'concat' + config.model.drop_labels_probability = 0.1 + config.model.fix_square_plus = True + config.model.per_channel_mixtures = False + config.model_init = '' # NOTE: repeats the assignment made at the top of this section. + # Required for model sharding + config.model.scan = True + config.model.remat_policy = 'nothing_saveable' + + # Adaptor section + config.adaptor_name = 'proj.givt.adaptor' if arg.adaptor else '' + config.adaptor = {} + config.adaptor.model = ml_collections.ConfigDict() + config.adaptor.model.num_blocks = 8 +
config.adaptor.model.num_channels_bottleneck = 4 * config.model.out_dim + + config.optax_name = 'scale_by_adam' + config.optax = dict(b2=0.95) + config.grad_clip_norm = 1.0 + + # FSDP training by default + config.sharding_strategy = [('.*', 'fsdp(axis="data")')] + config.sharding_rules = [('act_batch', ('data',))] + + # Standard schedule + config.lr = 0.001 + config.wd = 0.0001 + config.schedule = dict(decay_type='cosine', warmup_percent=0.1) + + # MaskGIT-specific parameters + if arg.style == 'masked': + config.model.dec_dropout_rate = 0.4 + config.wd = 0.0 + if arg.res == 512: + config.masked_generation_config.choice_temperature = 140 + # GIVT-Causal 512px specific parameters + elif arg.res == 512 and arg.model_size == 'large': + config.model.dec_dropout_rate = 0.1 + # Set up space-to-depth/pixel shuffle + config.vae.model.code_len //= 2 + config.vae.model.pixel_shuffle_patch_size = (1, 2) + config.model.seq_len //= 2 + config.model.out_dim = config.vae.model.codeword_dim * 2 + config.model.num_mixtures = 32 + config.adaptor.model.num_channels_bottleneck = 8 * config.model.out_dim + config.adaptor.model.pixel_shuffle_patch_size = (1, 2) + # Update sampling config + config.ar_generation_config.temp = 0.9 + config.ar_generation_config.cfg_inference_weight = 0.9 + + ### Evaluation section + config.evals = {} + config.evals.val = ml_collections.ConfigDict() + config.evals.val.type = 'mean' + config.evals.val.pred = 'validation' + config.evals.val.data = {**config.input.data} + config.evals.val.data.split = f'train[:{4096 if not arg.runlocal else 8}]' + config.evals.val.pp_fn = pp_eval + config.evals.val.log_steps = 1_000 if not arg.runlocal else 20 + + config.evals.save_pred_sampling = dict( + type='proj.givt.save_predictions', + pp_fn=pp_eval, + log_steps=10_000, + pred='sample', + batch_size=512, + data=dict(name=config.input.data.name, split='validation[:512]'), + outfile='inference_sampled.npz', + ) + + config.seed = 0 + + config.ckpt_timeout = 30 + + if
arg.runlocal: # Final local-run overrides; these supersede the runlocal-conditional values set earlier (batch size 8, val split 'train[:8]'). + config.input.batch_size = 4 + config.input.shuffle_buffer_size = 10 + config.log_training_steps = 5 + config.model.num_decoder_layers = 2 + + config.evals.val.data.split = 'validation[:16]' + config.evals.val.log_steps = 20 + + return config diff --git a/big_vision/configs/proj/givt/givt_nyu_depth.py b/big_vision/configs/proj/givt/givt_nyu_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..31c03499498c34f8b1a947d9e491ee3c0e8164b8 --- /dev/null +++ b/big_vision/configs/proj/givt/givt_nyu_depth.py @@ -0,0 +1,198 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +# pylint: disable=line-too-long +r"""Train a GIVT encoder-decoder model for NYU depth prediction.""" + +import itertools +import big_vision.configs.common as bvcc +import ml_collections + +ConfigDict = ml_collections.ConfigDict + +VTT_MODELS = { + 'base': dict(num_layers=12, num_decoder_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768), + 'large': dict(num_layers=24, num_decoder_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024), +} + +RES = 512 +PATCH_SIZE = 16 +LABEL_RES = 512 +LABEL_PATCH_SIZE = 16 +QUANTIZATION_BINS = 256 +MIN_DEPTH = 0.001 +MAX_DEPTH = 10.0 + + +def get_config(arg='split=sweep'): + """Config for training.""" + arg = bvcc.parse_arg(arg, split='sweep', runlocal=False, singlehost=False) + config = ConfigDict() + + config.input = {} + config.input.pp = ( + f'decode|nyu_depth|' + f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|' + f'inception_box|crop_box(key="image")|crop_box(key="labels")|' + f'resize({RES})|' + f'resize({LABEL_RES},key="labels",method="nearest")|' + f'bin_nyu_depth(min_depth={MIN_DEPTH}, max_depth={MAX_DEPTH}, num_bins={QUANTIZATION_BINS})|' + f'value_range(-1,1)|' + f'copy("image", "cond_image")|copy("labels", "image")|' + f'keep("image", "cond_image")' + ) + pp_eval = ( + f'decode|nyu_depth|' + f'nyu_eval_crop|' + f'resize({RES})|' + f'resize({LABEL_RES},key="labels",method="nearest")|' + f'bin_nyu_depth(min_depth={MIN_DEPTH}, max_depth={MAX_DEPTH}, num_bins={QUANTIZATION_BINS})|' + f'value_range(-1,1)|' + f'copy("image", "cond_image")|copy("labels", "image")|' + f'keep("image", "cond_image")' + ) + pp_predict = ( + f'decode|nyu_depth|' + f'nyu_eval_crop|copy("labels","ground_truth")|' + f'resize({RES})|' + f'value_range(-1,1)|' + f'copy("image", "cond_image")|' + f'strong_hash(inkey="tfds_id", outkey="image/id")|' + f'keep("cond_image", "ground_truth", "image/id")' + ) + + config.input.data = dict(name='nyu_depth_v2', split='train') + config.input.batch_size = 512 + config.input.shuffle_buffer_size = 
50_000 + + config.total_epochs = 50 + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = None + config.prefetch_to_device = 2 + config.seed = 0 + + # Optimizer section + config.optax_name = 'big_vision.scale_by_adafactor' + config.optax = dict(beta2_cap=0.95) + + config.ar_generation_config = ConfigDict() + config.ar_generation_config.temp = 0.9 + config.ar_generation_config.temp_probs = 1.0 + config.ar_generation_config.beam_size = 2 + config.ar_generation_config.fan_size = 8 + config.ar_generation_config.rand_top_k = False + config.ar_generation_config.rand_top_k_temp = 1.0 + + config.lr = 0.001 + config.wd = 0.000001 + config.lr_mults = [ + ('pos_embedding_encoder.*', 0.1), + ('EmbedPatches.*', 0.1), + ('encoder.*', 0.1), + ('decoder.*', 1.0) + ] + config.schedule = dict(decay_type='cosine', warmup_percent=0.1) + + # Oracle section + config.min_depth = MIN_DEPTH + config.max_depth = MAX_DEPTH + config.vae = ConfigDict() + config.vae.model_name = 'proj.givt.vit' + config.vae.model = ConfigDict() + config.vae.model.input_size = (RES, RES) + config.vae.model.patch_size = (PATCH_SIZE, PATCH_SIZE) + config.vae.model.code_len = 256 + config.vae.model.width = 768 + config.vae.model.enc_depth = 6 + config.vae.model.dec_depth = 12 + config.vae.model.mlp_dim = 3072 + config.vae.model.num_heads = 12 + config.vae.model.codeword_dim = 16 + config.vae.model.code_dropout = 'none' + config.vae.model.bottleneck_resize = True + # values: (channel index in source image, number of classes) + config.vae.model.inout_specs = { + 'depth': (0, QUANTIZATION_BINS), + } + config.vae.model_init = 'gs://big_vision/givt/vae_nyu_depth_params.npz' + + # Model section + config.model_name = 'proj.givt.givt' + # # Base model (for exploration) + # config.model_init = {'encoder': 'howto-i21k-B/16'} + # config.model = ConfigDict(VTT_MODELS['base']) + # Large model + config.model_init = {'encoder': 'howto-i21k-L/16'} + config.model_load = dict(dont_load=('cls', 
'head/bias', 'head/kernel')) + config.model = ConfigDict(VTT_MODELS['large']) + config.model.patches = (PATCH_SIZE, PATCH_SIZE) + config.model.input_size = (RES, RES) + config.model.posemb_type = 'learn' + config.model.seq_len = config.vae.model.code_len + config.model.num_labels = None + config.model.num_mixtures = 1 + config.model.fix_square_plus = True + config.model.out_dim = config.vae.model.codeword_dim + config.model.scale_tol = 1e-6 + config.model.dec_dropout_rate = 0.0 + + # Evaluation section + config.evals = {} + config.evals.val = ConfigDict() + config.evals.val.type = 'mean' + config.evals.val.pred = 'validation' + config.evals.val.data = {**config.input.data} + config.evals.val.data.split = 'validation' + config.evals.val.pp_fn = pp_eval + config.evals.val.log_steps = 250 + + base = { + 'type': 'proj.givt.nyu_depth', + 'data': {**config.input.data}, + 'pp_fn': pp_predict, + 'pred': 'sample_depth', + 'log_steps': 2000, + 'min_depth': MIN_DEPTH, + 'max_depth': MAX_DEPTH, + } + + config.evals.nyu_depth_val = dict(base) + config.evals.nyu_depth_val.data.split = 'validation' + + config.evals.save_pred = dict(base) + config.evals.save_pred.type = 'proj.givt.save_predictions' + del config.evals.save_pred.min_depth, config.evals.save_pred.max_depth + config.evals.save_pred.log_steps = 100_000 + config.evals.save_pred.data.split = 'validation[:128]' + config.evals.save_pred.outfile = 'inference.npz' + + config.eval_only = False + config.seed = 0 + + if arg.runlocal: + config.input.batch_size = 4 + config.input.shuffle_buffer_size = 10 + config.evals.val.log_steps = 20 + config.evals.val.data.split = 'validation[:4]' + config.evals.nyu_depth_val.data.split = 'validation[:4]' + config.evals.save_pred.data.split = 'validation[:4]' + config.model.update(VTT_MODELS['base']) + del config.model_init + for k in list(config.evals.keys()): # Snapshot the keys: entries are deleted inside the loop. + if k not in ['val', 'nyu_depth_val', 'save_pred']: + del config.evals[k] + + return config diff --git
a/big_vision/configs/proj/givt/givt_overview.png b/big_vision/configs/proj/givt/givt_overview.png new file mode 100644 index 0000000000000000000000000000000000000000..a05da7e9468a77f36eab1c2eced78c6229ceeb6e Binary files /dev/null and b/big_vision/configs/proj/givt/givt_overview.png differ diff --git a/big_vision/configs/proj/givt/vae_coco_panoptic.py b/big_vision/configs/proj/givt/vae_coco_panoptic.py new file mode 100644 index 0000000000000000000000000000000000000000..db29d6e657d955f92e38313a201f34ed682eee5d --- /dev/null +++ b/big_vision/configs/proj/givt/vae_coco_panoptic.py @@ -0,0 +1,136 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train VAE for GIVT-based UViM COCO panoptic task. 
+""" + +import big_vision.configs.common as bvcc +import ml_collections as mlc + + +def get_config(arg='res=512,patch_size=16'): + """Config for training label compression on COCO-panoptic.""" + arg = bvcc.parse_arg(arg, res=512, patch_size=16, + runlocal=False, singlehost=False) + config = mlc.ConfigDict() + + config.input = {} + config.input.data = dict(name='coco/2017_panoptic', split='train[4096:]') + + config.input.batch_size = 1024 + config.input.shuffle_buffer_size = 25_000 + + config.total_epochs = 500 + + config.input.pp = ( + f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|' + f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|' + f'inception_box|crop_box(key="image")|crop_box(key="labels")|' + f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|' + f'value_range(-1, 1)|make_canonical|copy("labels","image")|keep("image")' + ) + pp_eval = ( + f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|' + f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|' + f'value_range(-1, 1)|make_canonical|copy("labels","image")|keep("image", "image/id")' + ) + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = None + + # Model section + config.model_name = 'proj.givt.vit' + config.model = mlc.ConfigDict() + config.model.input_size = (arg.res, arg.res) + config.model.patch_size = (arg.patch_size, arg.patch_size) + config.model.code_len = 256 + config.model.width = 768 + config.model.enc_depth = 6 + config.model.dec_depth = 12 + config.model.mlp_dim = 3072 + config.model.num_heads = 12 + config.model.codeword_dim = 32 + config.model.code_dropout = 'none' + config.model.bottleneck_resize = True + config.model.scan = True + config.model.remat_policy = 'nothing_saveable' + + config.rec_loss_fn = 'xent' # xent, l2 + # values: (index in source image, number of classes) + config.model.inout_specs = { + 'semantics': (0, 133 + 1), # +1 for void label + 'instances': 
(1, 100), # COCO: actually 98 train/78 validation. + } + + config.beta = 2.5e-4 + config.beta_percept = 0.0 + + config.optax_name = 'scale_by_adam' + config.optax = dict(b2=0.95) + config.grad_clip_norm = 1.0 + + # FSDP training by default + config.sharding_strategy = [('.*', 'fsdp(axis="data")')] + config.sharding_rules = [('act_batch', ('data',))] + + config.lr = 1e-3 + config.wd = 1e-4 + config.schedule = dict(decay_type='cosine', warmup_percent=0.1) # 0.1 is a fraction, not a step count. + config.grad_clip_norm = 1.0 + + # Evaluation section + config.evals = {} + config.evals.val = mlc.ConfigDict() + config.evals.val.type = 'mean' + config.evals.val.pred = 'validation' + config.evals.val.data = {**config.input.data} + config.evals.val.data.split = 'train[:4096]' + config.evals.val.pp_fn = pp_eval + config.evals.val.log_steps = 250 + + base = { + 'type': 'proj.givt.coco_panoptic', + 'pp_fn': pp_eval, + 'log_steps': 5_000, + 'pred': 'predict_panoptic', + # Filters objects that occupy less than 0.03^2 fraction of all pixels. + # 'pred_kw': {'min_fraction': 0.03 ** 2}, + } + config.evals.coco_panoptic_train = dict(**base, data={'split': 'train[4096:8192]'}) + config.evals.coco_panoptic_holdout = dict(**base, data={'split': 'train[:4096]'}) + config.evals.coco_panoptic = dict(**base, data={'split': 'validation'}) + + config.evals.save_pred = dict(type='proj.givt.save_predictions') + config.evals.save_pred.pp_fn = pp_eval + config.evals.save_pred.log_steps = 100_000 + config.evals.save_pred.pred = 'predict_panoptic' + config.evals.save_pred.data = {**config.input.data} + config.evals.save_pred.data.split = 'validation[:1024]' + config.evals.save_pred.outfile = 'inference.npz' + + config.seed = 0 + + if arg.singlehost: + config.input.batch_size = 128 + config.total_epochs = 100 # Shorten the epoch budget that this config sets via total_epochs. + elif arg.runlocal: + config.input.batch_size = 16 + config.input.shuffle_buffer_size = 10 + config.log_training_steps = 5 + config.model.enc_depth = 1 + config.model.dec_depth = 1 + + return config \ No newline at end of file diff
--git a/big_vision/configs/proj/givt/vae_nyu_depth.py b/big_vision/configs/proj/givt/vae_nyu_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..1f28cf5feaab113f91fd002d03157d953b377384 --- /dev/null +++ b/big_vision/configs/proj/givt/vae_nyu_depth.py @@ -0,0 +1,158 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train VAE on NYU depth data for GIVT-based UViM. +""" + +import big_vision.configs.common as bvcc +import ml_collections as mlc + + +QUANTIZATION_BINS = 256 +MIN_DEPTH = 0.001 +MAX_DEPTH = 10.0 + + +def get_config(arg='res=512,patch_size=16'): + """Config for training label compression on NYU depth.""" + arg = bvcc.parse_arg(arg, res=512, patch_size=16, + runlocal=False, singlehost=False) + config = mlc.ConfigDict() + + config.input = {} + config.input.data = dict(name='nyu_depth_v2', split='train') + + config.input.batch_size = 1024 + config.input.shuffle_buffer_size = 25_000 + + config.total_epochs = 200 + + config.input.pp = ( + f'decode|nyu_depth|' + f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|' + f'inception_box|crop_box(key="image")|crop_box(key="labels")|' + f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|' + f'bin_nyu_depth(min_depth={MIN_DEPTH}, max_depth={MAX_DEPTH}, num_bins={QUANTIZATION_BINS})|' + f'value_range(-1, 1)|copy("labels", "image")|keep("image")' + ) + pp_eval = ( + 
f'decode|nyu_depth|nyu_eval_crop|' + f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|' + f'bin_nyu_depth(min_depth={MIN_DEPTH}, max_depth={MAX_DEPTH}, num_bins={QUANTIZATION_BINS})|' + f'value_range(-1, 1)|copy("labels", "image")|keep("image")' + ) + pp_pred = ( + f'decode|nyu_depth|nyu_eval_crop|copy("labels","ground_truth")|' + f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|' + f'bin_nyu_depth(min_depth={MIN_DEPTH}, max_depth={MAX_DEPTH}, num_bins={QUANTIZATION_BINS})|' + f'value_range(-1, 1)|copy("labels", "image")|' + f'keep("image", "ground_truth")' + ) + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = None + + # Model section + config.min_depth = MIN_DEPTH + config.max_depth = MAX_DEPTH + config.model_name = 'proj.givt.vit' + config.model = mlc.ConfigDict() + config.model.input_size = (arg.res, arg.res) + config.model.patch_size = (arg.patch_size, arg.patch_size) + config.model.code_len = 256 + config.model.width = 768 + config.model.enc_depth = 6 + config.model.dec_depth = 12 + config.model.mlp_dim = 3072 + config.model.num_heads = 12 + config.model.codeword_dim = 16 + config.model.code_dropout = 'none' + config.model.bottleneck_resize = True + config.model.scan = True + config.model.remat_policy = 'nothing_saveable' + config.model_init = '' + + config.rec_loss_fn = 'xent' # xent, l2 + config.mask_zero_target = True + # values: (index in source image, number of classes) + config.model.inout_specs = { + 'depth': (0, QUANTIZATION_BINS), + } + + config.beta = 2e-4 + config.beta_percept = 0.0 + + # Optimizer section + config.optax_name = 'scale_by_adam' + config.optax = dict(b2=0.95) + + # FSDP training by default + config.sharding_strategy = [('.*', 'fsdp(axis="data")')] + config.sharding_rules = [('act_batch', ('data',))] + + config.lr = 1e-3 + config.wd = 1e-4 + config.schedule = dict(decay_type='cosine', warmup_percent=0.1) # 0.1 is a fraction, not a step count. + config.grad_clip_norm = 1.0 + + # Evaluation section +
config.evals = {} + config.evals.val = mlc.ConfigDict() + config.evals.val.type = 'mean' + config.evals.val.pred = 'validation' + config.evals.val.data = {**config.input.data} + config.evals.val.data.split = 'validation' + config.evals.val.pp_fn = pp_eval + config.evals.val.log_steps = 250 + + base = { + 'type': 'proj.givt.nyu_depth', + 'data': {**config.input.data}, + 'pp_fn': pp_pred, + 'pred': 'predict_depth', + 'log_steps': 2000, + 'min_depth': MIN_DEPTH, + 'max_depth': MAX_DEPTH, + } + config.evals.nyu_depth_val = {**base} + config.evals.nyu_depth_val.data.split = 'validation' + + # ### Uses a lot of memory + # config.evals.save_pred = dict(type='proj.givt.save_predictions') + # config.evals.save_pred.pp_fn = pp_eval + # config.evals.save_pred.log_steps = 100_000 + # config.evals.save_pred.data = {**config.input.data} + # config.evals.save_pred.data.split = 'validation[:64]' + # config.evals.save_pred.batch_size = 64 + # config.evals.save_pred.outfile = 'inference.npz' + + config.eval_only = False + config.seed = 0 + + if arg.singlehost: + config.input.batch_size = 128 + config.total_epochs = 50 # Shorten the epoch budget that this config sets via total_epochs. + elif arg.runlocal: + config.input.batch_size = 16 + config.input.shuffle_buffer_size = 10 + config.log_training_steps = 5 + config.model.enc_depth = 1 + config.model.dec_depth = 1 + config.evals.val.data.split = 'validation[:16]' + config.evals.val.log_steps = 20 + config.evals.nyu_depth_val.data.split = 'validation[:16]' + + return config \ No newline at end of file diff --git a/big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py b/big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..90ac3ee1e0c9786af4e708aa5b68a31a869c22a9 --- /dev/null +++ b/big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py @@ -0,0 +1,134 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Pre-training ViT on ILSVRC-2012 with GSAM in https://arxiv.org/abs/2203.08065 + +Run training of a B/32 model: + +big_vision.trainers.proj.gsam.train \ + --config big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` + +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.common_fewshot import get_fewshot_lsr +import ml_collections as mlc + +def get_config(arg=None): + """Config for training.""" + arg = bvcc.parse_arg(arg, variant='B/32', runlocal=False) + config = mlc.ConfigDict() + + config.dataset = 'imagenet2012' + config.train_split = 'train[:99%]' + config.cache_raw = not arg.runlocal # Needs up to 120GB of RAM! + config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. + config.num_classes = 1000 + config.loss = 'sigmoid_xent' + config.batch_size = 4096 + config.num_epochs = 300 + + pp_common = ( + '|value_range(-1, 1)' + '|onehot(1000, key="{lbl}", key_result="labels")' + '|keep("image", "labels")' + ) + config.pp_train = ( + 'decode_jpeg_and_inception_crop(224)|flip_lr|' + + pp_common.format(lbl='label') + ) + pp = 'decode|resize_small(256)|central_crop(224)' + pp_common + + # Aggressive pre-fetching because our models here are small, so we not only + # can afford it, but we also need it for the smallest models to not be + # bottle-necked by the input pipeline. Play around with it for -L models tho. 
+ config.prefetch_to_host = 8 + config.prefetch_to_device = 4 + + config.log_training_steps = 50 + config.checkpoint_steps = 1000 + + # Model section + config.model_name = 'vit' + config.model = dict( + variant=arg.variant, + rep_size=False, + pool_type='gap', + ) + config.init_head_bias = -10.0 + + # Optimizer section + config.grad_clip_norm = 1.0 + config.optax_name = 'scale_by_adam' + config.optax = dict(mu_dtype='float32') + # The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560 + # almost always behaves exactly like adam, but at a fraction of the memory + # cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a + # good idea to try it when you are memory-bound! + # config.optax_name = 'big_vision.scale_by_adafactor' + # A good flag to play with when hitting instabilities, is the following: + # config.optax = dict(beta2_cap=0.95) + + config.lr = 0.003 + config.wd = 0.001 # default is 0.0001; paper used 0.3, effective wd=0.3*lr + config.schedule = dict( + warmup_steps=10_000, + decay_type='linear', + linear_end=0.01, + ) + + # GSAM settings. + # Note: when rho_max=rho_min and alpha=0, GSAM reduces to SAM. + config.gsam = dict( + rho_max=0.6, + rho_min=0.1, + alpha=0.6, + lr_max=config.get_ref('lr'), + lr_min=config.schedule.get_ref('linear_end') * config.get_ref('lr'), + ) + + # Eval section + eval_common = dict( + type='classification', + dataset='imagenet2012', + pp_fn=pp.format(lbl='label'), + loss_name=config.loss, + log_steps=2500, # Very fast O(seconds) so it's fine to run it often. 
+ ) + config.evals = {} + config.evals.train = {**eval_common, 'split': 'train[:2%]'} + config.evals.minival = {**eval_common, 'split': 'train[99%:]'} + config.evals.val = {**eval_common, 'split': 'validation'} + config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'} + + config.evals.real = {**eval_common} + config.evals.real.dataset = 'imagenet2012_real' + config.evals.real.split = 'validation' + config.evals.real.pp_fn = pp.format(lbl='real_label') + + config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal) + config.fewshot.log_steps = 10_000 + + # Make a few things much smaller for quick local debugging testruns. + if arg.runlocal: + config.shuffle_buffer_size = 10 + config.batch_size = 8 + config.evals.minival.split = 'train[:16]' # The evaluators are registered under config.evals. + config.evals.val.split = 'validation[:16]' + config.evals.real.split = 'validation[:16]' + config.evals.v2.split = 'test[:16]' + + return config diff --git a/big_vision/configs/proj/image_text/README.md b/big_vision/configs/proj/image_text/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ec1381db4fb7b49e7179624bb448a34a2c7524ba --- /dev/null +++ b/big_vision/configs/proj/image_text/README.md @@ -0,0 +1,65 @@ +# Image/text models + +## LiT: Zero-Shot Transfer with Locked-image text Tuning + +*by Xiaohua Zhai, Xiao Wang, Basil Mustafa, Andreas Steiner, Daniel Keysers, Alexander Kolesnikov, Lucas Beyer* + +https://arxiv.org/abs/2111.07991 + +``` +@article{zhai2022lit, + title={LiT: Zero-Shot Transfer with Locked-image Text Tuning}, + author={Zhai, Xiaohua and Wang, Xiao and Mustafa, Basil and Steiner, Andreas and Keysers, Daniel and Kolesnikov, Alexander and Beyer, Lucas}, + journal={CVPR}, + year={2022} +} +``` + +Model card: +https://github.com/google-research/vision_transformer/blob/main/model_cards/lit.md + +Colabs: + +- https://colab.research.google.com/github/google-research/vision_transformer/blob/main/lit.ipynb +-
https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/lit.ipynb + +### Results + +| Model | Download link | ImageNet 0-shot | MS-COCO I→T | MS-COCO T→I | Config `arg` | +| :--- | :---: | :---: | :---: | :---: | :--- | +| mixed_L16L | [link](https://storage.googleapis.com/vit_models/lit/LiT-L16L.npz) | 75.7 | 48.5 | 31.2 | `txt=bert_large,img=L/16` | +| mixed_B16B | [link](https://storage.googleapis.com/vit_models/lit/LiT-B16B.npz) | 72.1 | 49.4 | 31.1 | `txt=bert_base,img=B/16,img_head` | +| mixed_B16B_2 | [link](https://storage.googleapis.com/vit_models/lit/LiT-B16B_2.npz) | 73.9 | 51.5 | 31.8 | `txt=bert_base,img=B/16` | +| coco_B16B | [link](https://storage.googleapis.com/vit_models/lit/big_vision/coco_B16B/checkpoint.npz) | 20.7 | 47.2 | 32.1 | `txt=bert_base,img=B/16` | + +The first three rows are the best available models trained on open source data, +originally published in the [`google-research/vision_transformer`] repository. +These models were re-evaluated with this codebase using the following commands: + +```bash +big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_base,img=B/16,img_head,init=gs://vit_models/lit/LiT-B16B.npz + +big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_base,img=B/16_2,init=gs://vit_models/lit/LiT-B16B_2.npz + +big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_large,img=L/16,init=gs://vit_models/lit/LiT-L16L.npz +``` + +Unfortunately, the public multi-modal datasets [`CC12M`] and [`YFCC100M`] are +not yet available in [`tfds`], so these models cannot be reproduced with the +codebase.
For this reason we provide the much weaker model `coco_B16B` in the +fourth row, which was trained on the small `tfds` dataset [`coco_captions`], and +can be used to verify correctness of the codebase +([workdir](https://console.cloud.google.com/storage/browser/vit_models/lit/big_vision/coco_B16B/)). + +[`google-research/vision_transformer`]: https://github.com/google-research/vision_transformer +[`CC12M`]: https://arxiv.org/abs/2102.08981 +[`YFCC100M`]: https://arxiv.org/abs/1503.01817 +[`tfds`]: https://www.tensorflow.org/datasets/api_docs/python/tfds +[`coco_captions`]: https://www.tensorflow.org/datasets/catalog/coco_captions + + +### Changelog + +- 2022-08-18: Added LiT-B16B_2 model that was trained for 60k steps + (LiT_B16B: 30k) without linear head on the image side (LiT_B16B: 768) and has + better performance. diff --git a/big_vision/configs/proj/image_text/SigLIP_demo.ipynb b/big_vision/configs/proj/image_text/SigLIP_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..120f447f27a662e747fafe74e6015bb0c2500d9b --- /dev/null +++ b/big_vision/configs/proj/image_text/SigLIP_demo.ipynb @@ -0,0 +1,1022 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# General information\n", + "\n", + "Example colab for SigLIP models described in [the SigLIP paper](https://arxiv.org/abs/2303.15343).\n", + "\n", + "**These models are not official Google products and were trained and released for research purposes.**\n", + "\n", + "If you find our model(s) useful for your research, consider citing\n", + "\n", + "```\n", + "@article{zhai2023sigmoid,\n", + " title={Sigmoid loss for language image pre-training},\n", + " author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander
and Beyer, Lucas},\n", + " journal={International Conference on Computer Vision ({ICCV})},\n", + " year={2023}\n", + "}\n", + "```\n", + "\n", + "If you use our released models in your products, we will appreciate any direct feedback. We are reachable by xzhai@google.com, basilm@google.com, akolesnikov@google.com and lbeyer@google.com.\n", + "\n", + "\n", + "Only the models explicitly marked with `i18n` in the name are expected to perform reasonably well on non-english data." + ], + "metadata": { + "id": "wR53lePHuiP-" + } + }, + { + "cell_type": "code", + "source": [ + "#@markdown # Environment setup\n", + "#@markdown **IMPORTANT NOTE**: Modern jax (>0.4) does not support the Colab TPU\n", + "#@markdown anymore, so don't select TPU runtime here. CPU and GPU work and are both fast enough.\n", + "\n", + "# Install the right jax version for TPU/GPU/CPU\n", + "import os\n", + "if 'COLAB_TPU_ADDR' in os.environ:\n", + " raise RuntimeError(\"TPU colab not supported.\")\n", + "elif 'NVIDIA_PRODUCT_NAME' in os.environ:\n", + " !nvidia-smi\n", + "import jax\n", + "jax.devices()\n", + "\n", + "\n", + "# Get latest version of big_vision codebase.\n", + "!git clone --quiet --branch=main --depth=1 https://github.com/google-research/big_vision\n", + "!cd big_vision && git pull --rebase --quiet\n", + "!pip -q install -r big_vision/big_vision/requirements.txt\n", + "# Gives us ~2x faster gsutil cp to get the model checkpoints.\n", + "!pip3 -q install --no-cache-dir -U crcmod\n", + "\n", + "%cd big_vision\n", + "\n", + "\n", + "import numpy as np\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "\n", + "%matplotlib inline\n", + "%config InlineBackend.figure_format = 'retina'\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import ml_collections\n", + "\n", + "from google.colab.output import _publish as publish" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kXSdSXVg2PAI", + "outputId":
"ba908946-0cd3-4468-9034-cd108529986f", + "cellView": "form" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Thu Sep 28 09:08:47 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", + "| N/A 75C P8 14W / 70W | 0MiB / 15360MiB | 0% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n", + "fatal: destination path 'big_vision' already exists and is not an empty directory.\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", + "/content/big_vision\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Choose and load model, perform inference" + ], + "metadata": { + "id": "byHpmgAO6inM" + } + }, + { + "cell_type": "code", + "source": [ + "# Pick your hero: (WHEN CHANGING THIS, RERUN IMAGE/TEXT EMBEDDING CELLS)\n", + "# Give this cell 1-3mins.\n", + "\n", + "# VARIANT, RES = 'B/16', 224\n", + "# VARIANT, RES = 'B/16', 256\n", + "# VARIANT, RES = 'B/16', 384\n", + "# VARIANT, RES = 'B/16', 512\n", + "# VARIANT, RES = 'L/16', 256\n", + "VARIANT, RES = 'L/16', 384\n", + "# VARIANT, RES = 'So400m/14', 224\n", + "# VARIANT, RES = 'So400m/14', 384\n", + "# VARIANT, RES = 'B/16-i18n', 256\n", + "\n", + "CKPT, TXTVARIANT, EMBDIM, SEQLEN, VOCAB = {\n", + " ('B/16', 224): ('webli_en_b16_224_63724782.npz', 'B', 768, 64, 32_000),\n", + " ('B/16', 256): ('webli_en_b16_256_60500360.npz', 'B', 768, 64, 32_000),\n", + " ('B/16', 384): ('webli_en_b16_384_68578854.npz', 'B', 768, 64, 32_000),\n", + " ('B/16', 512): ('webli_en_b16_512_68580893.npz', 'B', 768, 64, 32_000),\n", + " ('L/16', 256): ('webli_en_l16_256_60552751.npz', 'L', 1024, 64, 32_000),\n", + " ('L/16', 384): ('webli_en_l16_384_63634585.npz', 'L', 1024, 64, 32_000),\n", + " ('So400m/14', 224): ('webli_en_so400m_224_57633886.npz', 'So400m', 1152, 16, 32_000),\n", + " ('So400m/14', 384): ('webli_en_so400m_384_58765454.npz', 'So400m', 1152, 64, 32_000),\n", + " ('B/16-i18n', 256): ('webli_i18n_b16_256_66117334.npz', 'B', 768, 64, 250_000),\n", + " ('So400m/16', 256): ('webli_i18n_so400m_16_256_78061115.npz', 'So400m', 1152, 64, 250_000),\n", + "}[VARIANT, RES]\n", + "\n", + "# It is significantly faster to first copy the checkpoint (30s vs 8m30 for B and 1m vs ??? 
for L)\n", + "!test -f /tmp/{CKPT} || gsutil cp gs://big_vision/siglip/{CKPT} /tmp/\n", + "\n", + "if VARIANT.endswith('-i18n'):\n", + " VARIANT = VARIANT[:-len('-i18n')]\n", + "\n", + "import big_vision.models.proj.image_text.two_towers as model_mod\n", + "\n", + "model_cfg = ml_collections.ConfigDict()\n", + "model_cfg.image_model = 'vit' # TODO(lbeyer): remove later, default\n", + "model_cfg.text_model = 'proj.image_text.text_transformer' # TODO(lbeyer): remove later, default\n", + "model_cfg.image = dict(variant=VARIANT, pool_type='map')\n", + "model_cfg.text = dict(variant=TXTVARIANT, vocab_size=VOCAB)\n", + "model_cfg.out_dim = (None, EMBDIM) # (image_out_dim, text_out_dim)\n", + "model_cfg.bias_init = -10.0\n", + "model_cfg.temperature_init = 10.0\n", + "\n", + "model = model_mod.Model(**model_cfg)\n", + "\n", + "# Using `init_params` is slower but will lead to `load` below performing sanity-checks.\n", + "# init_params = jax.jit(model.init, backend=\"cpu\")(jax.random.PRNGKey(42), jnp.zeros([1, RES, RES, 3], jnp.float32), jnp.zeros([1, SEQLEN], jnp.int32))['params']\n", + "init_params = None # Faster but bypasses loading sanity-checks.\n", + "\n", + "params = model_mod.load(init_params, f'/tmp/{CKPT}', model_cfg)" + ], + "metadata": { + "id": "0DsOabGD7MRG", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "5afc9f52-7eb4-4a0d-b681-3ab5945ce9b4" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Copying gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz...\n", + "- [1 files][ 1.3 GiB/ 1.3 GiB] 45.3 MiB/s \n", + "Operation completed over 1 objects/1.3 GiB. 
\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Load and embed images\n", + "\n", + "import big_vision.pp.builder as pp_builder\n", + "import big_vision.pp.ops_general\n", + "import big_vision.pp.ops_image\n", + "import big_vision.pp.ops_text\n", + "import PIL\n", + "\n", + "!wget -q https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg\n", + "!wget -q https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg\n", + "!wget -q 'https://images.unsplash.com/photo-1566467021888-b03548769dd1?ixlib=rb-4.0.3&q=85&fm=jpg&crop=entropy&cs=srgb&dl=svetlana-gumerova-hQHm2D1fH70-unsplash.jpg&w=640' -O cold_drink.jpg\n", + "!wget -q 'https://images.rawpixel.com/image_1300/czNmcy1wcml2YXRlL3Jhd3BpeGVsX2ltYWdlcy93ZWJzaXRlX2NvbnRlbnQvbHIvdXB3azU4ODU5NzY1LXdpa2ltZWRpYS1pbWFnZS1rb3diMmhkeC5qcGc.jpg' -O hot_drink.jpg\n", + "!wget -q https://storage.googleapis.com/big_vision/siglip/authors.jpg\n", + "!wget -q https://storage.googleapis.com/big_vision/siglip/siglip.jpg\n", + "!wget -q https://storage.googleapis.com/big_vision/siglip/caffeine.jpg\n", + "!wget -q https://storage.googleapis.com/big_vision/siglip/robosign.jpg\n", + "!wget -q https://storage.googleapis.com/big_vision/siglip/fried_fish.jpeg\n", + "!wget -q 'https://pbs.twimg.com/media/FTyEyxyXsAAyKPc?format=jpg&name=small' -O cow_beach.jpg\n", + "!wget -q 'https://storage.googleapis.com/big_vision/siglip/cow_beach2.jpg' -O cow_beach2.jpg\n", + "!wget -q 'https://pbs.twimg.com/media/Frb6NIEXwAA8-fI?format=jpg&name=medium' -O mountain_view.jpg\n", + "\n", + "\n", + "images = [PIL.Image.open(fname) for fname in [\n", + " 'apple-ipod.jpg',\n", + " 'apple-blank.jpg',\n", + " 'cold_drink.jpg',\n", + " 'hot_drink.jpg',\n", + " 'caffeine.jpg',\n", + " 'siglip.jpg',\n", + " 'authors.jpg',\n", + " 'robosign.jpg',\n", + " 'cow_beach.jpg',\n", + " 'cow_beach2.jpg',\n", + " 'mountain_view.jpg',\n", + "]]\n", + "\n", + "pp_img = 
pp_builder.get_preprocess_fn(f'resize({RES})|value_range(-1, 1)')\n", + "imgs = np.array([pp_img({'image': np.array(image)})['image'] for image in images])\n", + "zimg, _, out = model.apply({'params': params}, imgs, None)\n", + "\n", + "print(imgs.shape, zimg.shape)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xmuXfCfBjgeF", + "outputId": "3627819b-007e-4107-e1f4-06b7ad3ac03a" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(11, 384, 384, 3) (11, 1024)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Tokenize and embed texts\n", + "\n", + "texts = [\n", + " 'an apple',\n", + " 'a picture of an apple',\n", + " 'an ipod',\n", + " 'granny smith',\n", + " 'an apple with a note saying \"ipod\"',\n", + " 'a cold drink on a hot day',\n", + " 'a hot drink on a cold day',\n", + " 'a photo of a cold drink on a hot day',\n", + " 'a photo of a hot drink on a cold day',\n", + " #\n", + " 'a photo of two guys in need of caffeine',\n", + " 'a photo of two guys in need of water',\n", + " 'a photo of the SigLIP authors',\n", + " 'a photo of a rock band',\n", + " 'a photo of researchers at Google Brain',\n", + " 'a photo of researchers at OpenAI',\n", + " #\n", + " 'a robot on a sign',\n", + " 'a photo of a robot on a sign',\n", + " 'an empty street',\n", + " 'autumn in Toronto',\n", + " 'a photo of autumn in Toronto',\n", + " 'a photo of Toronto in autumn',\n", + " 'a photo of Toronto in summer',\n", + " 'autumn in Singapore',\n", + " #\n", + " 'cow',\n", + " 'a cow in a tuxedo',\n", + " 'a cow on the beach',\n", + " 'a cow in the prairie',\n", + " #\n", + " 'the real mountain view',\n", + " 'Zürich',\n", + " 'San Francisco',\n", + " 'a picture of a laptop with the lockscreen on, a cup of cappucino, salt and pepper grinders. 
The view through the window reveals lake Zürich and the Alps in the background of the city.',\n", + "]\n", + "\n", + "TOKENIZERS = {\n", + " 32_000: 'c4_en',\n", + " 250_000: 'mc4',\n", + "}\n", + "pp_txt = pp_builder.get_preprocess_fn(f'tokenize(max_len={SEQLEN}, model=\"{TOKENIZERS[VOCAB]}\", eos=\"sticky\", pad_value=1, inkey=\"text\")')\n", + "txts = np.array([pp_txt({'text': text})['labels'] for text in texts])\n", + "_, ztxt, out = model.apply({'params': params}, None, txts)\n", + "\n", + "print(txts.shape, ztxt.shape)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KGrpkRTtjU-L", + "outputId": "7c43b56e-cd53-4801-b1e3-66774368a1d2" + }, + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(31, 64) (31, 1024)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# This is how to get all probabilities:\n", + "print(f\"Learned temperature {out['t'].item():.1f}, learned bias: {out['b'].item():.1f}\")\n", + "probs = jax.nn.sigmoid(zimg @ ztxt.T * out['t'] + out['b'])\n", + "print(f\"{probs[0][0]:.1%} that image 0 is '{texts[0]}'\")\n", + "print(f\"{probs[0][1]:.1%} that image 0 is '{texts[1]}'\")" + ], + "metadata": { + "id": "TIdAVw9VGEAw", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "22fc0d9a-8986-4679-ca89-6e4330a55c6e" + }, + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Learned temperature 118.2, learned bias: -12.7\n", + "10.4% that image 0 is 'an apple'\n", + "42.8% that image 0 is 'a picture of an apple'\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Pretty demo (code)\n", + "from IPython.display import Javascript\n", + "\n", + "DEMO_IMG_SIZE = 96\n", + "\n", + "import base64\n", + "import io\n", + "\n", + "def bv2rgb(bv_img):\n", + " return (bv_img * 127.5 + 127.5).astype(np.uint8)\n", + "\n", + "def html_img(*, enc_img=None, pixels=None, 
id=None, size=100, max_size=None, max_height=None, style=\"\"):\n", + " if enc_img is None and pixels is not None:\n", + " with io.BytesIO() as buf:\n", + " PIL.Image.fromarray(np.asarray(pixels)).save(buf, format=\"JPEG\")\n", + " enc_img = buf.getvalue()\n", + "\n", + " img_data = base64.b64encode(np.ascontiguousarray(enc_img)).decode('ascii')\n", + "\n", + " id_spec = f'id={id}' if id else ''\n", + " if size is not None:\n", + " style_spec = f'style=\"{style}; width: {size}px; height: {size}px\"'\n", + " elif max_size is not None:\n", + " style_spec = f'style=\"{style}; width: auto; height: auto; max-width: {max_size}px; max-height: {max_size}px;\"'\n", + " elif max_height is not None:\n", + " style_spec = f'style=\"{style}; object-fit: cover; width: auto; height: {max_height}px;\"'\n", + " else: style_spec = ''\n", + "\n", + " return f''\n", + "\n", + "\n", + "def make_table(zimg, ztxt, out):\n", + " # The default learnable bias is a little conservative. Play around with it!\n", + " t, b = out['t'].item(), out['b'].item()\n", + " tempered_logits = zimg @ ztxt.T * t\n", + " probs = 1 / (1 + np.exp(-tempered_logits - b))\n", + " publish.javascript(f\"var logits = {tempered_logits.tolist()};\")\n", + "\n", + " def color(p):\n", + " return mpl.colors.rgb2hex(mpl.cm.Greens(p / 2)) if p >= 0.01 else \"transparent\"\n", + "\n", + " publish.javascript(f\"var cmap = {[color(x) for x in np.linspace(0, 1, 50)]};\")\n", + " def cell(x, iimg, itxt):\n", + " return f\"
{x * 100:>4.0f}%
\"\n", + "\n", + " html = f'''\n", + "

\n", + " \n", + " \n", + " \n", + "

\n", + " '''\n", + "\n", + " html += \"\\n\"\n", + " html += \"\"\n", + " html += \"\".join([f\"\" + \"\".join([cell(probs[iimg, itxt], iimg, itxt) for iimg in range(len(imgs))]) + f\"
\" + html_img(pixels=bv2rgb(img), size=DEMO_IMG_SIZE) for img in imgs])\n", + " html += \"\"\n", + " for itxt, txt in enumerate(texts):\n", + " html += f\"
{txt}\"\n", + "\n", + " publish.css(r\"\"\"\n", + " table {\n", + " border-collapse: collapse;\n", + " }\n", + "\n", + " tr {\n", + " border: 1px transparent;\n", + " }\n", + "\n", + " tr:nth-child(odd) {\n", + " background-color: #F5F5F5;\n", + " }\n", + "\n", + " tr:hover {\n", + " background-color: lightyellow;\n", + " border: 1px solid black;\n", + " }\n", + "\n", + " td.pct {\n", + " text-align: center;\n", + " }\n", + " \"\"\")\n", + " publish.html(html)\n", + "\n", + " # JS code to compute and write all probs from the logits.\n", + " display(Javascript('''\n", + " function update(b) {\n", + " for(var iimg = 0; iimg < logits.length; iimg++) {\n", + " for(var itxt = 0; itxt < logits[iimg].length; itxt++) {\n", + " const el = document.getElementById(`p_${iimg}_${itxt}`);\n", + " const p = Math.round(100 / (1 + Math.exp(-logits[iimg][itxt] - b)));\n", + " const pad = p < 10.0 ? ' ' : p < 100.0 ? ' ' : ''\n", + " el.innerHTML = pad + (p).toFixed(0) + '%';\n", + "\n", + " const td = document.getElementById(`td_${iimg}_${itxt}`);\n", + " const c = cmap[Math.round(p / 100 * (cmap.length - 1))];\n", + " td.style.backgroundColor = c;\n", + " }\n", + " }\n", + " }\n", + " '''))\n", + "\n", + " # JS code to connect the bias value slider\n", + " display(Javascript('''\n", + " const value = document.querySelector(\"#value\");\n", + " const input = document.querySelector(\"#b\");\n", + " value.textContent = input.value;\n", + " input.addEventListener(\"input\", (event) => {\n", + " value.textContent = event.target.value;\n", + " update(event.target.value);\n", + " });\n", + " '''))\n", + "\n", + " # Make the cell output as large as the table to avoid annoying scrollbars.\n", + " display(Javascript(f'update({b})'))\n", + " display(Javascript('google.colab.output.resizeIframeToContent()'))" + ], + "metadata": { + "cellView": "form", + "id": "eolOc7vd_ZSj" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "make_table(zimg, ztxt, 
out)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 767 + }, + "id": "mt5BIywzzA6c", + "outputId": "3b06cfb9-a3da-42d7-8caf-d5366d058f8b" + }, + "execution_count": 14, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "var logits = [[10.509522438049316, 12.372017860412598, 13.07434368133545, 9.578202247619629, 21.19094467163086, 1.310517430305481, 1.2763848304748535, 3.0990359783172607, 2.360225200653076, -3.670855760574341, -4.780072212219238, -1.4530967473983765, -3.3108861446380615, -3.8945610523223877, -4.378420829772949, 0.35140618681907654, 2.7228779792785645, -6.806382656097412, -3.9012961387634277, -1.7843879461288452, -4.578653812408447, -7.306142807006836, -1.253274917602539, -1.8402824401855469, -6.329799175262451, -9.506726264953613, -5.78713846206665, -1.6370103359222412, -9.404793739318848, -4.342881202697754, -13.128281593322754], [12.365941047668457, 13.45022964477539, 0.9843839406967163, 12.809731483459473, 6.767915725708008, 2.808335304260254, 1.050551414489746, 3.6161491870880127, 1.152547001838684, -7.214369297027588, -5.146897792816162, -6.283102035522461, -11.463550567626953, -7.751645565032959, -11.252680778503418, -9.319047927856445, -8.11094856262207, -8.898587226867676, -2.15217661857605, -0.10237424820661545, -3.6214966773986816, -12.085700035095215, -1.599789023399353, -1.7422595024108887, -7.456813335418701, -8.457598686218262, -5.5325212478637695, -2.4997880458831787, -8.217476844787598, -8.986675262451172, -10.336335182189941], [-1.1052173376083374, -1.3570473194122314, -3.8713269233703613, 2.3654367923736572, -9.037796020507812, 11.620930671691895, 2.1417031288146973, 13.036051750183105, -0.11228565871715546, 0.33224615454673767, 3.9813454151153564, -6.005640029907227, -5.856462001800537, -7.669452667236328, -9.974565505981445, -11.242084503173828, -12.130292892456055, -5.630223274230957, -5.570030689239502, 
-6.117311000823975, -7.32966423034668, -5.952571392059326, 0.4303727447986603, -0.5507297515869141, -7.554576873779297, -3.3274905681610107, -3.4397053718566895, 0.9088093638420105, -4.845495700836182, -7.663942337036133, -10.328642845153809], [1.1323682069778442, 1.3157405853271484, 0.828519880771637, -1.6223008632659912, -7.967062950134277, 4.090002059936523, 14.007913589477539, 6.785359859466553, 16.369604110717773, 1.524818778038025, -4.911859035491943, -9.018620491027832, -9.306066513061523, -8.402979850769043, -11.57016658782959, -9.890503883361816, -10.68331527709961, -5.442021369934082, 4.999141216278076, 5.106411933898926, 4.015860080718994, -12.08991527557373, 6.171087741851807, -1.0262863636016846, -8.962656021118164, -6.404715538024902, -4.912563323974609, -2.5522496700286865, -6.039242744445801, -10.613517761230469, -6.997122287750244], [-3.4062156677246094, -3.2604005336761475, -4.109685897827148, -4.58593225479126, -9.489058494567871, 1.6483688354492188, 2.376404047012329, 0.7108156681060791, 0.5808579921722412, 17.98756980895996, 9.364227294921875, 1.8207945823669434, -6.545583724975586, 3.3331942558288574, 2.5704448223114014, -7.702937602996826, -9.870623588562012, -1.303507924079895, -5.957301616668701, -6.226568222045898, -6.917541980743408, -7.621560573577881, -0.5124773979187012, -2.2896718978881836, -12.721405029296875, -6.885163307189941, -9.90884780883789, -1.4125298261642456, 2.3772332668304443, -5.4370293617248535, -1.6405099630355835], [-3.2013378143310547, -3.3440065383911133, -1.2165169715881348, -4.172476291656494, -5.278318881988525, -2.3818702697753906, -3.210822582244873, -3.580622911453247, -5.1373138427734375, -1.7848750352859497, -1.4050911664962769, 16.463136672973633, -1.4766411781311035, 16.46843147277832, 11.259382247924805, -1.0086976289749146, -1.908290982246399, -4.666292667388916, -2.9601247310638428, -2.0503976345062256, -1.600439190864563, -1.4223682880401611, -2.251126289367676, -4.444605827331543, -9.10830020904541, 
-10.853714942932129, -11.52085018157959, -1.6640691757202148, 2.193969964981079, 2.127061367034912, -4.728240013122559], [-0.5153040289878845, -1.290441632270813, -1.3887863159179688, -2.88513445854187, -8.828889846801758, 1.3482768535614014, 0.010438825935125351, -0.6988681554794312, -2.9927048683166504, 2.8313045501708984, 2.5383071899414062, 6.094320297241211, -1.2357840538024902, 19.095901489257812, 12.049205780029297, -2.1667087078094482, -3.2871627807617188, -4.000303268432617, -2.7362473011016846, -1.7782089710235596, -1.643406629562378, -4.0933918952941895, -2.1210238933563232, -3.1019272804260254, -8.912919998168945, -8.04006290435791, -10.427931785583496, 0.8204227089881897, -1.7909467220306396, -0.8497583270072937, -5.065787315368652], [-1.4752472639083862, -0.13337232172489166, 1.7657679319381714, -2.7154576778411865, -2.644958257675171, -1.401767373085022, 0.21228086948394775, -0.5131799578666687, 1.4820858240127563, -2.5781843662261963, 3.075222969055176, -2.9382081031799316, -7.704923152923584, -3.6199238300323486, -3.213698625564575, 10.677529335021973, 12.515663146972656, 3.690605401992798, 10.979350090026855, 12.963836669921875, 11.986873626708984, 4.023745059967041, 0.9576215744018555, -4.142323970794678, -7.46238374710083, -9.735015869140625, -8.231826782226562, -1.0106267929077148, -2.2898473739624023, -2.2792820930480957, -6.5174055099487305], [-0.3335295617580414, 1.2584013938903809, -1.2919337749481201, -2.0686888694763184, -11.050207138061523, 5.148484706878662, 0.46310505270957947, 4.050027847290039, -1.6178984642028809, -6.791775703430176, -2.2926063537597656, -7.568892002105713, -10.240560531616211, -7.8912248611450195, -11.374415397644043, -7.808314323425293, -7.384036540985107, -5.577442646026611, -4.582977771759033, -4.019510746002197, -5.569993019104004, -2.2238216400146484, -0.21682055294513702, 12.080615043640137, 6.551390647888184, 17.416383743286133, 8.308161735534668, -0.3994586169719696, -1.8691462278366089, -2.187755823135376, 
-4.866983413696289], [-2.294813394546509, -1.4864670038223267, -1.4635752439498901, -2.9900710582733154, -14.971826553344727, 4.747520446777344, -0.9042328000068665, 3.1032114028930664, -3.679764747619629, -5.160387992858887, -1.1286523342132568, -7.035560607910156, -6.664344787597656, -7.769715309143066, -10.94699478149414, -6.526098251342773, -6.273430347442627, -6.723901271820068, -5.448723316192627, -5.721604824066162, -7.575157165527344, -4.370161056518555, -1.393196702003479, 11.913715362548828, 17.861845016479492, 15.086359024047852, 6.581197261810303, -0.31534600257873535, -2.1320040225982666, -4.305175304412842, -7.700469970703125], [-2.552478790283203, -1.305349349975586, 0.03923465311527252, -5.891383647918701, -7.833784580230713, 1.2974026203155518, 5.689708709716797, 2.8017938137054443, 7.800131320953369, -0.12797383964061737, -4.34028434753418, -4.815661430358887, -8.476018905639648, -1.2871994972229004, -1.1152652502059937, -6.992332458496094, -7.258864402770996, 0.09565334022045135, -6.82894229888916, -5.026597023010254, -3.2372162342071533, -7.9831085205078125, -3.8290252685546875, -0.595430850982666, -5.086977005004883, -4.143807888031006, -5.033395290374756, 4.200597763061523, 6.196822166442871, -4.807774066925049, 23.876855850219727]];\n", + "//# sourceURL=js_5e545691b3" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "var cmap = ['transparent', '#f6fcf4', '#f4fbf2', '#f3faf0', '#f1faee', '#f0f9ec', '#eff9eb', '#edf8e9', '#ecf8e8', '#eaf7e6', '#e8f6e4', '#e7f6e3', '#e5f5e1', '#e4f5df', '#e1f3dc', '#def2d9', '#dcf2d7', '#daf0d4', '#d7efd1', '#d5efcf', '#d2edcc', '#d0edca', '#cdecc7', '#cbeac4', '#c9eac2', '#c6e8bf', '#c3e7bc', '#c0e6b9', '#bce4b5', '#bae3b3', '#b6e2af', '#b4e1ad', '#b0dfaa', '#acdea6', '#aadda4', '#a7dba0', '#a3da9d', '#a0d99b', '#9cd797', '#99d595', '#95d391', '#91d28e', '#8ed08b', '#8ace88', '#87cd86', '#83cb82', '#7fc97f', '#7cc87c', 
'#78c679', '#73c476'];\n", + "//# sourceURL=js_b212ab59e1" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " \n", + "

\n", + " \n", + "
  10%
  43%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
an apple
  43%
  69%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
a picture of an apple
  60%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
an ipod
   4%
  54%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
granny smith
 100%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
an apple with a note saying \"ipod\"
   0%
   0%
  26%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
a cold drink on a hot day
   0%
   0%
   0%
  79%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
a hot drink on a cold day
   0%
   0%
  59%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
a photo of a cold drink on a hot day
   0%
   0%
   0%
  98%
   0%
   0%
   0%
   0%
   0%
   0%
   1%
a photo of a hot drink on a cold day
   0%
   0%
   0%
   0%
 100%
   0%
   0%
   0%
   0%
   0%
   0%
a photo of two guys in need of caffeine
   0%
   0%
   0%
   0%
   4%
   0%
   0%
   0%
   0%
   0%
   0%
a photo of two guys in need of water
   0%
   0%
   0%
   0%
   0%
  98%
   0%
   0%
   0%
   0%
   0%
a photo of the SigLIP authors
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
a photo of a rock band
   0%
   0%
   0%
   0%
   0%
  98%
 100%
   0%
   0%
   0%
   0%
a photo of researchers at Google Brain
   0%
   0%
   0%
   0%
   0%
  20%
  35%
   0%
   0%
   0%
   0%
a photo of researchers at OpenAI
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  12%
   0%
   0%
   0%
a robot on a sign
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  46%
   0%
   0%
   0%
a photo of a robot on a sign
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
an empty street
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  16%
   0%
   0%
   0%
autumn in Toronto
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  57%
   0%
   0%
   0%
a photo of autumn in Toronto
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  34%
   0%
   0%
   0%
a photo of Toronto in autumn
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
a photo of Toronto in summer
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
autumn in Singapore
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  36%
  32%
   0%
cow
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  99%
   0%
a cow in a tuxedo
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  99%
  92%
   0%
a cow on the beach
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   1%
   0%
   0%
a cow in the prairie
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
the real mountain view
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
Zürich
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
San Francisco
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
 100%
a picture of a laptop with the lockscreen on, a cup of cappucino, salt and pepper grinders. The view through the window reveals lake Zürich and the Alps in the background of the city." + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "\n", + " function update(b) {\n", + " for(var iimg = 0; iimg < logits.length; iimg++) {\n", + " for(var itxt = 0; itxt < logits[iimg].length; itxt++) {\n", + " const el = document.getElementById(`p_${iimg}_${itxt}`);\n", + " const p = Math.round(100 / (1 + Math.exp(-logits[iimg][itxt] - b)));\n", + " const pad = p < 10.0 ? ' ' : p < 100.0 ? ' ' : ''\n", + " el.innerHTML = pad + (p).toFixed(0) + '%';\n", + "\n", + " const td = document.getElementById(`td_${iimg}_${itxt}`);\n", + " const c = cmap[Math.round(p / 100 * (cmap.length - 1))];\n", + " td.style.backgroundColor = c;\n", + " }\n", + " }\n", + " }\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "\n", + " const value = document.querySelector(\"#value\");\n", + " const input = document.querySelector(\"#b\");\n", + " value.textContent = input.value;\n", + " input.addEventListener(\"input\", (event) => {\n", + " value.textContent = event.target.value;\n", + " update(event.target.value);\n", + " });\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "update(-12.661874771118164)" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "google.colab.output.resizeIframeToContent()" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# More international examples (choose i18n model for this)" + ], + "metadata": { + "id": "f5lIiaD700UK" + } + }, + { + "cell_type": "code", + "source": [ 
+ "#@title Load and embed images\n", + "\n", + "import big_vision.pp.builder as pp_builder\n", + "import big_vision.pp.ops_general\n", + "import big_vision.pp.ops_image\n", + "import big_vision.pp.ops_text\n", + "import PIL\n", + "\n", + "!wget -q 'https://live.staticflickr.com/4152/5189547658_3b2a7126cb_b.jpg' -O ants_climbing_a_tree_food.jpg\n", + "!wget -q 'https://storage.googleapis.com/big_vision/siglip/pexels-poranimm-athithawatthee-842401.jpg' -O ants_climbing_tree.jpg\n", + "!wget -q 'https://images.rawpixel.com/image_1300/cHJpdmF0ZS9zdGF0aWMvaW1hZ2Uvd2Vic2l0ZS8yMDIyLTA0L2xyL3B4OTE3NDYyLWltYWdlLWt3eW8ydmxrLmpwZw.jpg' -O lion_head.jpg\n", + "!wget -q 'https://images.rawpixel.com/image_1300/cHJpdmF0ZS9sci9pbWFnZXMvd2Vic2l0ZS8yMDIzLTA5L3Jhd3BpeGVsX29mZmljZV8yN19taW5pbWFsX3NpbXBsZV9fbGlvbl9fcGFwZXJfY29sbGFnZV9taW5pbWFsX183OGRlOGU3OS02ZTE3LTQ2YzAtYTUyOS02ZDAxM2YzNDg0OWVfMi5qcGc.jpg' -O lion_head_red.jpg\n", + "!wget -q https://live.staticflickr.com/232/551040940_87299a85ec_h.jpg -O meat_ball.jpg\n", + "!wget -q https://storage.googleapis.com/big_vision/siglip/squirrel_fish.jpg -O squirrel_fish.jpg\n", + "!wget -q 'https://ideogram.ai/api/images/direct/F3lMxBprSk6ligq5Vy3XSw' -O squirrel_fish2.jpg\n", + "!wget -q 'https://pbs.twimg.com/media/FTyEyxyXsAAyKPc?format=jpg&name=small' -O cow_beach.jpg\n", + "!wget -q 'https://storage.googleapis.com/big_vision/siglip/cow_beach2.jpg' -O cow_beach2.jpg\n", + "\n", + "\n", + "images = [PIL.Image.open(fname) for fname in [\n", + " 'ants_climbing_a_tree_food.jpg',\n", + " 'ants_climbing_tree.jpg',\n", + " 'meat_ball.jpg',\n", + " 'lion_head.jpg',\n", + " 'lion_head_red.jpg',\n", + " 'fried_fish.jpeg',\n", + " 'squirrel_fish.jpg',\n", + " 'squirrel_fish2.jpg',\n", + " 'cow_beach.jpg',\n", + " 'cow_beach2.jpg',\n", + "]]\n", + "\n", + "pp_img = pp_builder.get_preprocess_fn(f'resize({RES})|value_range(-1, 1)')\n", + "imgs = np.array([pp_img({'image': np.array(image)})['image'] for image in images])\n", + "zimg, _, out = 
model.apply({'params': params}, imgs, None)\n", + "\n", + "print(imgs.shape, zimg.shape)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YsK74v2J04Xp", + "outputId": "63f024ad-205c-4dd3-a5af-4dfd5ff198ca" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tensorflow_addons/utils/tfa_eol_msg.py:23: UserWarning: \n", + "\n", + "TensorFlow Addons (TFA) has ended development and introduction of new features.\n", + "TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.\n", + "Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). \n", + "\n", + "For more information see: https://github.com/tensorflow/addons/issues/2807 \n", + "\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(10, 256, 256, 3) (10, 768)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Tokenize and embed texts\n", + "\n", + "texts = [\n", + " '蚂蚁上树',\n", + " '肉末粉丝',\n", + " 'ants climbing a tree',\n", + " 'minced pork rice noodle',\n", + " #\n", + " '红烧狮子头',\n", + " 'red burned lion head',\n", + " 'lion head',\n", + " 'meat ball with soy sauce',\n", + " #\n", + " '松鼠鳜鱼',\n", + " 'squirrel',\n", + " 'squirrel and fish',\n", + " 'squirrel mandarinfish',\n", + " 'squirrel mandarin fish',\n", + " 'sweet and sour mandarin fish',\n", + " #\n", + " 'cow',\n", + " 'a cow in a tuxedo',\n", + " 'a cow on the beach',\n", + " 'a cow in the prairie',\n", + " 'une vache sur la plage',\n", + " 'eine Kuh am Strand',\n", + " 'วัวอยู่ที่ชายหาด',\n", + " '一只躺在沙滩上的牛',\n", + " '一只沙滩上的牛',\n", + " 'корова на пляже',\n", + " 'بقرة على الشاطئ',\n", + "]\n", + "\n", + "TOKENIZERS = {\n", + " 32_000: 'c4_en',\n", + " 250_000: 'mc4',\n", + "}\n", + "pp_txt = 
pp_builder.get_preprocess_fn(f'tokenize(max_len={SEQLEN}, model=\"{TOKENIZERS[VOCAB]}\", eos=\"sticky\", pad_value=1, inkey=\"text\")')\n", + "txts = np.array([pp_txt({'text': text})['labels'] for text in texts])\n", + "_, ztxt, out = model.apply({'params': params}, None, txts)\n", + "\n", + "print(txts.shape, ztxt.shape)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dAzAuYJh1eQ3", + "outputId": "6c07c1a2-c236-4b68-b7e3-f92dcc070fcc" + }, + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(25, 64) (25, 768)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "make_table(zimg, ztxt, out)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 633 + }, + "id": "JlMwn6K1-62i", + "outputId": "6b8fa113-06f3-492c-ffa7-942d4799cae3" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "var logits = [[15.194855690002441, 14.548081398010254, 4.362802505493164, 8.915352821350098, 0.12249733507633209, -1.8669313192367554, -2.1026358604431152, 4.83571195602417, -1.48772132396698, -2.885380744934082, -3.757584571838379, -9.74438190460205, -6.739628791809082, 1.0982742309570312, -1.8383992910385132, -8.639388084411621, -8.514564514160156, -8.664950370788574, -9.010446548461914, -8.695591926574707, -0.29446348547935486, -2.3145699501037598, 0.3301776945590973, -9.183826446533203, -7.548545837402344], [3.1235272884368896, -2.662849187850952, 15.499628067016602, -5.6270432472229, -8.800381660461426, -5.2857537269592285, -4.901862621307373, -8.64078426361084, -8.457619667053223, -0.7642378211021423, -6.292320251464844, -6.919025421142578, -5.699285984039307, -6.146625518798828, -1.7575650215148926, -9.384129524230957, -6.215198040008545, -6.763903617858887, -6.789668560028076, -6.646523952484131, 2.078498125076294, 0.1571565568447113, 
1.2640687227249146, -4.958133697509766, -4.504084587097168], [2.4513118267059326, 3.711794853210449, -2.7506296634674072, 6.2139153480529785, 12.623679161071777, -2.242187261581421, -0.873506486415863, 12.75291633605957, 5.779244422912598, -3.411043405532837, -2.7684485912323, 0.8032691478729248, 2.4132730960845947, 10.139656066894531, -1.5548374652862549, -7.363276481628418, -10.937602043151855, -10.354545593261719, -12.12853717803955, -11.330802917480469, -3.7032158374786377, -4.167450428009033, -2.857227087020874, -12.429163932800293, -10.023411750793457], [-7.848373889923096, -8.82786750793457, -4.246535301208496, -11.672212600708008, -4.754408836364746, 5.023717403411865, 10.245930671691895, -9.671830177307129, -5.305540561676025, 0.939210832118988, -3.7660276889801025, -6.9834089279174805, -5.540616512298584, -7.520627498626709, 0.6897578239440918, -4.008193016052246, -3.137038230895996, -2.492392063140869, -3.349771022796631, -2.571514129638672, -0.5961494445800781, 1.920261025428772, -0.5972135066986084, -3.192373275756836, -2.797152280807495], [-7.591951370239258, -9.57149887084961, -7.410569667816162, -10.887884140014648, -2.1018383502960205, 10.839365005493164, 12.306414604187012, -8.755990028381348, -6.4970011711120605, 1.732677698135376, -1.484777808189392, -3.788830280303955, -2.954533338546753, -4.137475967407227, 1.2805907726287842, -4.848579406738281, -4.63262939453125, -4.869859218597412, -4.654362201690674, -4.7860589027404785, -0.6505587697029114, -0.741170346736908, -1.2220640182495117, -5.068485260009766, -4.302990913391113], [0.38381102681159973, -0.5291793346405029, -4.558042049407959, -0.798613965511322, 1.3992505073547363, -3.269932508468628, -2.243269205093384, 3.4091484546661377, 13.690838813781738, -3.199730396270752, 2.4068713188171387, 4.793602466583252, 6.522286415100098, 12.24045467376709, -0.973887026309967, -5.842926025390625, -8.813263893127441, -10.347548484802246, -10.193572044372559, -9.09493350982666, 0.17290785908699036, 
-2.690534830093384, 0.4429348409175873, -10.299919128417969, -7.2381591796875], [-11.066581726074219, -10.138232231140137, -5.7180986404418945, -11.073030471801758, -9.701227188110352, 1.2774648666381836, 0.6818075776100159, -11.766871452331543, 7.582111358642578, 6.539462089538574, 13.692913055419922, 11.608633041381836, 12.523263931274414, 2.838015556335449, 0.06712919473648071, -8.434947967529297, -5.371018409729004, -7.046348571777344, -5.160297393798828, -4.178375244140625, -1.4383944272994995, -1.4511940479278564, -0.826172947883606, -4.657361030578613, -4.185240745544434], [-3.598116874694824, -6.576178073883057, -2.7102479934692383, -8.999201774597168, -6.829661846160889, -5.066120147705078, -1.7694122791290283, -7.724926471710205, 0.23896828293800354, 11.48562240600586, 18.98163414001465, 10.054450035095215, 10.879026412963867, -0.23405185341835022, 1.1370410919189453, -4.135552406311035, -0.34031882882118225, -1.2078852653503418, -1.5318009853363037, -3.0245869159698486, -0.7356898188591003, 2.346902847290039, 1.158348560333252, -1.281561017036438, -1.2338509559631348], [-9.843914985656738, -9.799589157104492, -6.7716383934021, -9.883660316467285, -12.059309005737305, -6.143594264984131, -3.1696691513061523, -7.953651428222656, -14.6300048828125, -5.153632164001465, -9.101214408874512, -8.86422061920166, -7.411843299865723, -9.261401176452637, 12.271851539611816, 7.439639091491699, 19.08420181274414, 9.05471420288086, 18.37834930419922, 18.505441665649414, 14.171286582946777, 12.338602066040039, 14.924001693725586, 17.368127822875977, 17.931604385375977], [-9.439372062683105, -8.37105941772461, -9.730523109436035, -9.263359069824219, -7.634936809539795, -5.775638580322266, -0.2548319399356842, -6.097734451293945, -12.719864845275879, -5.2038702964782715, -8.733600616455078, -8.040817260742188, -6.40618896484375, -8.534762382507324, 11.509172439575195, 18.91118049621582, 14.150744438171387, 6.8233747482299805, 13.563973426818848, 13.099942207336426, 
10.563776016235352, 10.233851432800293, 11.005309104919434, 15.13718032836914, 14.48193359375]];\n", + "//# sourceURL=js_ca0f68d49c" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "var cmap = ['transparent', '#f6fcf4', '#f4fbf2', '#f3faf0', '#f1faee', '#f0f9ec', '#eff9eb', '#edf8e9', '#ecf8e8', '#eaf7e6', '#e8f6e4', '#e7f6e3', '#e5f5e1', '#e4f5df', '#e1f3dc', '#def2d9', '#dcf2d7', '#daf0d4', '#d7efd1', '#d5efcf', '#d2edcc', '#d0edca', '#cdecc7', '#cbeac4', '#c9eac2', '#c6e8bf', '#c3e7bc', '#c0e6b9', '#bce4b5', '#bae3b3', '#b6e2af', '#b4e1ad', '#b0dfaa', '#acdea6', '#aadda4', '#a7dba0', '#a3da9d', '#a0d99b', '#9cd797', '#99d595', '#95d391', '#91d28e', '#8ed08b', '#8ace88', '#87cd86', '#83cb82', '#7fc97f', '#7cc87c', '#78c679', '#73c476'];\n", + "//# sourceURL=js_b212ab59e1" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " \n", + "

\n", + " \n", + "
  91%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
蚂蚁上树
  84%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
肉末粉丝
   0%
  93%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
ants climbing a tree
   2%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
minced pork rice noodle
   0%
   0%
  43%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
红烧狮子头
   0%
   0%
   0%
   0%
  11%
   0%
   0%
   0%
   0%
   0%
red burned lion head
   0%
   0%
   0%
   7%
  36%
   0%
   0%
   0%
   0%
   0%
lion head
   0%
   0%
  47%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
meat ball with soy sauce
   0%
   0%
   0%
   0%
   0%
  69%
   0%
   0%
   0%
   0%
松鼠鳜鱼
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  20%
   0%
   0%
squirrel
   0%
   0%
   0%
   0%
   0%
   0%
  69%
 100%
   0%
   0%
squirrel and fish
   0%
   0%
   0%
   0%
   0%
   0%
  22%
   6%
   0%
   0%
squirrel mandarinfish
   0%
   0%
   0%
   0%
   0%
   0%
  41%
  12%
   0%
   0%
squirrel mandarin fish
   0%
   0%
   6%
   0%
   0%
  34%
   0%
   0%
   0%
   0%
sweet and sour mandarin fish
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  35%
  20%
cow
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
 100%
a cow in a tuxedo
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
 100%
  78%
a cow on the beach
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   2%
   0%
a cow in the prairie
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
 100%
  66%
une vache sur la plage
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
 100%
  55%
eine Kuh am Strand
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  78%
   9%
วัวอยู่ที่ชายหาด
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  37%
   7%
一只躺在沙滩上的牛
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  88%
  13%
一只沙滩上的牛
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  99%
  90%
корова на пляже
   0%
   0%
   0%
   0%
   0%
   0%
   0%
   0%
  99%
  83%
بقرة على الشاطئ" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "\n", + " function update(b) {\n", + " for(var iimg = 0; iimg < logits.length; iimg++) {\n", + " for(var itxt = 0; itxt < logits[iimg].length; itxt++) {\n", + " const el = document.getElementById(`p_${iimg}_${itxt}`);\n", + " const p = Math.round(100 / (1 + Math.exp(-logits[iimg][itxt] - b)));\n", + " const pad = p < 10.0 ? ' ' : p < 100.0 ? ' ' : ''\n", + " el.innerHTML = pad + (p).toFixed(0) + '%';\n", + "\n", + " const td = document.getElementById(`td_${iimg}_${itxt}`);\n", + " const c = cmap[Math.round(p / 100 * (cmap.length - 1))];\n", + " td.style.backgroundColor = c;\n", + " }\n", + " }\n", + " }\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "\n", + " const value = document.querySelector(\"#value\");\n", + " const input = document.querySelector(\"#b\");\n", + " value.textContent = input.value;\n", + " input.addEventListener(\"input\", (event) => {\n", + " value.textContent = event.target.value;\n", + " update(event.target.value);\n", + " });\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "update(-12.885268211364746)" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "google.colab.output.resizeIframeToContent()" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Explanation for non-Chinese speakers:\n", + "\n", + "- The first dish is literally called \"ants climbing a tree\" in Chinese.\n", + "- The second dish is literally called \"red burned lion head\" in Chinese.\n", + "- The third dish is literally called \"squirrel mandarinfish\" in Chinese.\n", + "\n", + "We are looking 
for more interesting examples that highlight culture-language aspects and where a non-EN model should \"get it\" while an EN-only model does not." + ], + "metadata": { + "id": "bNGoftU3y4UQ" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Example image credits\n", + "\n", + "- The apple and apple + iPod images are from OpenAI.\n", + "- [Cold drink on hot day](https://unsplash.com/fr/photos/hQHm2D1fH70).\n", + "- [Hot drink on cold day](https://www.rawpixel.com/image/3282934).\n", + "- The cows-on-beach images were created by Chitwan Saharia using the Imagen model and are shared with permission.\n", + "- [\"ant climbing tree\" noodles](https://www.flickr.com/photos/avlxyz/5189547658)\n", + "- [actual ants climbing on a tree](https://www.pexels.com/photo/macro-photo-of-five-orange-ants-842401/)\n", + "- [real lion head](https://www.rawpixel.com/image/5941715/free-public-domain-cc0-photo)\n", + "- [cartoon red lion head](https://www.rawpixel.com/image/12447997/image-texture-paper-png)\n", + "- Collaged [squirrel](https://www.pexels.com/photo/brown-squirrel-47547/) and [fish](https://zh.wikipedia.org/zh-hans/%E9%B3%9C%E9%B1%BC) images.\n", + "- cartoon [squirrel and fish](https://ideogram.ai/g/zgoma01ASS21U1YwIC7MrA/2) generated by [ideogram.ai](http://ideogram.ai) [with permission](https://x.com/ideogram_ai/status/1697428471184515316?s=20).\n", + "- The remaining pictures are personal photos taken by the authors, long after the models were trained." + ], + "metadata": { + "id": "etDZ3sl4kZ_q" + } + } + ] +} diff --git a/big_vision/configs/proj/image_text/common.py b/big_vision/configs/proj/image_text/common.py new file mode 100644 index 0000000000000000000000000000000000000000..96e5235c557a19aaf8500f03759c0655ea28956d --- /dev/null +++ b/big_vision/configs/proj/image_text/common.py @@ -0,0 +1,127 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Snippets and constants used a lot in image-text configs.""" + +import ml_collections + + +# pylint: disable=line-too-long +inits = { + # Downloaded & extracted from original repo: + # https://github.com/google-research/bert + 'bert_base': ('base', 'gs://vit_models/lit/bert/uncased_L-12_H-768_A-12'), + 'bert_large': ('large', 'gs://vit_models/lit/bert/uncased_L-uncased_L-24_H-1024_A-16'), + # Recommended "How to train your ViT..." checkpoints from + # https://github.com/google-research/vision_transformer#available-vit-models + 'B/32': ('B/32', 'gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz'), + 'B/16': ('B/16', 'gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz'), + 'L/16': ('L/16', 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz'), +} +# pylint: enable=line-too-long + + +def _square875(sz): + return f'resize({int(sz/0.875)})|central_crop({sz})|value_range(-1,1)' + + +def _aspect75(sz): + return f'resize_small({int(sz/0.75)})|central_crop({sz})|value_range(-1,1)' + + +def _drop_no_real_label(f): + return len(f['real_label']) > 0 + + +def _drop_no_imagenet(f): + return len(f['labels_imagenet']) > 0 + + +DISCLF_DATASET_OVERRIDES = { + 'imagenet2012': {'class_names': 'clip', 'split': 'validation'}, + 'imagenet2012_minival': { + 'dataset_name': 'imagenet2012', + 'class_names': 'clip', + 'split': 'train[99%:]', + }, + 'imagenet2012_real': { + 'split': 'validation', + 'class_names': 'clip', + 'class_names_dataset_name': 'imagenet2012', + 'pp_img': 
lambda sz: ( + _square875(sz) + '|pad_to_shape(inkey="real_label", outkey="label", shape=[10], pad_value=-1)|keep("label", "image")'), # pylint: disable=line-too-long + 'pre_filter_fn': _drop_no_real_label, + }, + 'imagenet_v2': {'class_names': 'clip'}, + 'imagenet_a': { + 'class_names': 'clip', + 'pp_img': lambda sz: _aspect75(sz) + '|map("i1k_i1ka")', + }, + 'imagenet_r': { + 'class_names': 'clip', + 'pp_img': lambda sz: _square875(sz) + '|map("i1k_i1kr")', + }, +} + + +def get_disclf(sz, *, pp_txt=None, dataset_names=('imagenet2012',), **kw): + """Returns config for discriminative_classifier of specified datasets.""" + config = ml_collections.ConfigDict(dict( + dataset_names=list(dataset_names), + type='proj.image_text.discriminative_classifier', + prefix='z/0shot/', + pp_img=_square875(sz), + dataset_overrides={}, + cache_final=True, + **kw, + )) + if pp_txt: + config.pp_txt = pp_txt + for name in dataset_names: + if name in DISCLF_DATASET_OVERRIDES: + config.dataset_overrides[name] = {**DISCLF_DATASET_OVERRIDES[name]} + d = config.dataset_overrides[name] + if 'pp_img' in d and callable(d['pp_img']): + with d.ignore_type(): + d['pp_img'] = d['pp_img'](sz) + return config + + +def get_coco( + *, + pp_img='resize(224)|value_range(-1, 1)', + pp_txt='tokenize(max_len=16, inkey="texts", eos="sticky", pad_value=1)', + prefix='z/retr/coco_', + **kw): + """Returns config for mscoco retrieval zero-shot. + + Args: + pp_img: Pre-processing string for "image" feature. + pp_txt: Pre-processing string for texts (expected to tokenize "texts" to + "labels"). + prefix: Prefix to use for metrics. + **kw: Other config settings, most notably log_{steps,percent,...}. + + Returns: + `ConfigDict` that can be used as a retrieval evaluator configuration. 
+ """ + return ml_collections.ConfigDict({ + 'type': 'proj.image_text.retrieval', + 'pp_txt': pp_txt, + 'pp_img': pp_img, + 'prefix': prefix, + 'dataset': 'coco_captions', + 'txt_name': ('captions', 'text'), + **kw, + }) diff --git a/big_vision/configs/proj/image_text/lit.ipynb b/big_vision/configs/proj/image_text/lit.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..db3c9536c216102e8d0343d0283a79517eb0302c --- /dev/null +++ b/big_vision/configs/proj/image_text/lit.ipynb @@ -0,0 +1,1903 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "This Colab shows some example code how to make use of the\n", + "[LiT: Zero-Shot Transfer with Locked-image text Tuning](https://arxiv.org/abs/2111.07991)\n", + "models in the `big_vision` codebase.\n", + "\n", + "For more information refer to\n", + "\n", + "https://github.com/google-research/big_vision/blob/main/README.md\n", + "\n", + "https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/README.md" + ], + "metadata": { + "id": "3OCq_g6vBiWX" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Initialize" + ], + "metadata": { + "id": "hg1hy3ER9LHT" + } + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "i0ws1pPjl6nY", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "175aa7e3-4c88-47aa-f0c9-6b3a36b87478" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "fatal: destination path 'big_vision' already exists and is not an empty directory.\n", + "Already up to date.\n" + ] + } + ], + "source": [ + "!git clone --branch=main --depth=1 https://github.com/google-research/big_vision\n", + "!cd big_vision \u0026\u0026 git pull" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "xa2C2eTej-XX", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f5b50c10-8ecb-4873-841f-3846bca1dc18" + }, + "outputs": [ + { + 
"output_type": "stream", + "name": "stdout", + "text": [ + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "!pip install -qr big_vision/big_vision/requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "f09dX0cekPRa" + }, + "outputs": [], + "source": [ + "import sys\n", + "bv_path = './big_vision'\n", + "if bv_path not in sys.path:\n", + " sys.path.insert(0, bv_path)\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "source": [ + "from absl import flags\n", + "from absl import logging\n", + "import tensorflow_datasets as tfds\n", + "from google.colab import files\n", + "\n", + "logging.set_verbosity(logging.INFO)\n", + "\n", + "def set_max_height(max_height):\n", + " \"\"\"Limits scrollable area of output cell to `max_height` pixels.\"\"\"\n", + " import IPython.display\n", + " IPython.display.display(IPython.display.Javascript('''\n", + " google.colab.output.setIframeHeight(0, true, {maxHeight: %d})\n", + " ''' % max_height))" + ], + "metadata": { + "id": "3J0Ilcu6LczM" + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "ij41ZsIkmVQB", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "03058240-b85b-46e2-e944-92a980077bcf" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mon Mar 13 10:56:53 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. 
ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", + "| N/A 68C P0 30W / 70W | 0MiB / 15360MiB | 0% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]" + ] + }, + "metadata": {}, + "execution_count": 5 + } + ], + "source": [ + "# Set up Colab TPUs (if available).\n", + "import os\n", + "if 'COLAB_TPU_ADDR' in os.environ:\n", + " import jax.tools.colab_tpu\n", + " jax.tools.colab_tpu.setup_tpu()\n", + "else:\n", + " !nvidia-smi\n", + "import jax\n", + "jax.devices()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "cCqkcNAqz7xr" + }, + "outputs": [], + "source": [ + "# Uncomment this snippet to access a private GCS bucket with prepared\n", + "# TFDS datasets.\n", + "\n", + "# import tensorflow_datasets as tfds\n", + "# from google.colab import auth\n", + "# auth.authenticate_user() # Required to access protected GCS buckets.\n", + "# import os\n", + "# os.environ['TFDS_DATA_DIR'] = 'gs://tensorflow-datasets/datasets'\n", + "# builder = tfds.builder('coco_captions')\n", + "# b = next(iter(builder.as_dataset('val')))" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Load training run data" + ], + "metadata": { + 
"id": "TiJsesb1Bds8" + } + }, + { + "cell_type": "code", + "source": [ + "import json\n", + "import re\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "from tensorflow.io import gfile\n", + "\n", + "def plot_metrics(workdir, regexes, cols=4, cmp={}):\n", + " \"\"\"Plots metrics matching `regexes` from `workdir`.\"\"\"\n", + " df = pd.DataFrame([json.loads(line) for line in gfile.GFile(f'{workdir}/big_vision_metrics.txt')])\n", + " df = df.set_index('step')\n", + " ms = []\n", + " for regex in regexes:\n", + " for col in df.columns:\n", + " if col not in ms and re.match(regex, col):\n", + " ms.append(col)\n", + " rows = int(np.ceil(len(ms) / cols))\n", + " _, axs = plt.subplots(rows, cols, figsize=(4*cols, 3*rows))\n", + " if rows == 1: axs = [axs]\n", + " for i, m in enumerate(ms):\n", + " ax = axs[i // cols][i % cols]\n", + " df[m].dropna().plot(ax=ax)\n", + " if m in cmp: cmp[m].dropna().plot(ax=ax)\n", + " ax.set_title(m)\n", + " plt.tight_layout()\n", + " return df\n", + "\n", + "# Reference run using the tiny 80k \"coco-captions\" TFDS dataset.\n", + "df = plot_metrics('gs://vit_models/lit/big_vision/coco_B16B', [\n", + " 'training', 'val/loss', 'img/', '.*net2012',\n", + " '.*cifar100', '.*pet', '.*@1$',\n", + "])" + ], + "metadata": { + "id": "EGSsrwtnBfD5", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 306 + }, + "outputId": "96985539-74f1-41cd-b4fe-89c1901b447d" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u003cFigure size 1152x432 with 8 Axes\u003e" + ], + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABHYAAAGoCAYAAAAjPWJ4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAADRKUlEQVR4nOzdd5ycVb348c932vaSsptseicJJBAITQERFCkieG1gw4rlYrnqVfFn7+XqtRcuKioCghURCyooSg0QEiCBFALZtE3blm1Tvr8/zpnZZyfbkuzuM7P5vl+vec0zT5sz7czzfJ9zvkdUFWOMMcYYY4wxxhhTfCJhF8AYY4wxxhhjjDHGHB4L7BhjjDHGGGOMMcYUKQvsGGOMMcYYY4wxxhQpC+wYY4wxxhhjjDHGFCkL7BhjjDHGGGOMMcYUKQvsGGOMMcYYY4wxxhQpC+wc5UTk+yLysZFe9wjKoyKyYDSfwxhjjDHFT0QeF5Gzwy6HMWb0ichdIvKWUdr3LBFpF5HoaOzfmLFggZ0iJyJbROQFh7u9qr5dVT8z0usaY8xoEJGzRaRxkOVHVCcaY4qHqh6rqneNxr5F5HIRuWE09u33P22wusyY8UZEnhSRReJ8SUT2+tuXRERG+bn7PTbI/s5V9VlVrVTV9GiWYyyJyCdF5Pq8ef8jIhtEpE1E1ovI6/OWnyAiD4lIh78/IbDs+SJyp4i0iMiWvO3qReRGEdnul/9bRE4dzddnDmaBnXFMRGJhl8EYY0abnSAZY0bBRcDto7j/C4E/jdTO7JjPFDIRmQ9EVfUp4ErgUuB4YDlwMfC2kIo22r/zQnMA937XAFcA3xCR5wCISAL4HXA9MAH4CfA7Pz+77Y+A/+5nv5XAg8BJwES/7R9EpHL0XsqRG2/1pgV2ipiI/AyYBfzeNx/8oO/K9GYReRb4u1/vFhHZ6SOo/xSRYwP7uE5EPuunzxaRRhF5v4g0icgOEXnjYa47SUR+LyKtIvKgiHxWRP51iK+vRkR+KiK7ReQZEfmoiET8sgUi8g//mvaIyC/8fBGR//VlahWRtSJy3GG/ycaYYjCiJ0jGmOKQvQrvr0zfIiLX+yvRa33LgKv98cBWETkvsN1cfzzUJiJ/FZHvBK9s+2ONFwJ/EpFSv9+9ItLsj2mm+PVqROSH/hhomz/WiQb281YRWeef5wkROTFQ/AvxJ5QiMlNEfu2Pd/aKyLez5fDHPs/41/FTEanxy+YMcMz3Jv+c+0XkzyIye9Q+AGMCRORV/nwke+sWkbv84mAA5Qrgq6raqKrbgK8Cb/D7GPD35s0W1xqkTUT+IiKTA8//EnHdM5vFddta4ucfdL7k5wd/59nfU8wvu8v/nu/x2/zen9v8PHBuMyfw3N/w9UyruJYuZwaWlYnIT/xvcp2487XGwPJpIvIr//t/WkTeHVj2SRG52f/22/zrWznUtiJyPvARIPuZPAqgqp9Q1fWqmlHV+4G7gdP97s4GYsDXVbVbVb8JCHCO3/YBVf0ZsDn/s1fVzar6NVXdoappVb0GSADH9PNVyRGR+SLyd/957/Hvb21geb91o1/Wb/0qeWk9pP/z1w+JyE7gxyIyQURu88+x30/PCGw/UUR+LK410n4R+a2f/5iIXBxYL+5fw4rBXvNossBOEVPV1wHPAheraiVws1/0PGAJ8CL/+I/AQqAeeBj4+SC7nYqL4k4H3gx8R0QmHMa638FFdqfiKvArDvX1Ad/y+5/nX9PrgWzw6DPAX3AR5Rl+XYDzgLOARX7bVwJ7D+O5jTGjyP+p/jJv3jdE5Jsi8sbAn/VmERnqSl7uBClvfyUi8nX/Z7zdT5f4ZZP9n3eziOwTkbulN3D8IXEnaW3imo6fO1Kv2xgzai4GfoY7LngE+DPuOHc68GngB4F1bwAeACYBnwRel7evU4DNqroHd/xSA8z0678d6PTrXQekgAXACtwxyFsAROQVft+vB6qBl+CPR0QkjjtWuUNcIOg24Bl
gji/vTX7/b/C35+OOhSqB3ImNlzvmE5FLcCdz/wHU4U7abhzkPTNmxKjqL3x3pkpgGi4AkP3+XQj8wU8fCzwa2PRRPw8G/70BvBp3LlCPCxx8AEBEFvnnei/uu387LpCTyD9fUtUv+30Ff+f9uQxXN0wH5gP3Aj/GtUhZB3wisO6DwAl+2Q3ALSJS6pd9AvfbnocLJL02u5E/7vi9fw+mA+cC7xWR7DkcuLrjJqAWuBVfBwy2rar+Cfg8kP1Mjs9/cSJSBpwMPO5nHQusUVUNrLaG3s9m2MR14UoAG4daFfgC7vuyBPe5f9LvY8C6cbD6dRim4j6n2bjWYxHc5zobFwDspG89+zOgHPc+1AP/6+f/lMBnifuO71DVR4ZZjpGnqnYr4huwBXiBn54DKDBvkPVr/To1/vF1wGf99Nm4L3MssH4TcNqhrAtEgSRwTGDZZ4F/DeP1KO4AKQr0AEsDy94G3OWnfwpcA8zI2/4c4ClfjkjYn4/d7Ga3/m+4P9AOoMo/jgI7/G/3ItxBlOBOWjqAE/16ZwONgf3EgT2B/QTrxE8D9+H+iOuAe4DP+GVfAL7vt48DZ/rnOwbYCkzz680B5of9ftnNbnY7+Jb9veMO8O8IzL8YaMd1/QCo8scXtbgD9xRQHlj/euD6wOPPAB/z02/ydcfyvOeeAnQDZYF5lwN3+uk/A+8ZoNznAn/z06cDu4PHU4H1/ga8M/D4GH98FaOfYz7chbw3Bx5HfP05O+zPym5Hz81/724Dvucfl+NOukv84zSwOLD+Qv9dloF+b369u4CPBh6/E/iTn/4YcHNeGbYBZ/vHuWODwDrB33n29xQLPNf/C6z7VeCPgccXA6sHeQ/2A8f76c3AiwLL3oI/jgFOBZ7N2/Zq4Md++pPAXwPLlgKdh7Dt9YOU8Se41s4SeA9vylvn58An8+a9ANgyyH6rgbXA1Yfx3bkUeMRPD1Y3Dla/KrAg8Pg6+p6/9gClg5ThBGC/n24AMsCEftabBrQB1f7xL4EPjtbvajg3a7EzPm3NTohIVES+KCKbRKQVV7EBTO53S9irqqnA4w7cFaJDWbcOd9CxNbAsOD0ck3EnW88E5j2Di9YCfBD3B/CAb5b4JgBV/TsuyvodoElErhGR6kN8bmPMKFPVZ3AtCF/qZ50DdKjqfar6B1XdpM4/cK3zzhxgV2cBj6pqWz/LXgN8WlWbVHU38Cl6r8wncX/Ys1U1qap3q/tnTgMlwFIRiavqFlXdNBKv2RgzqnYFpjuBPdqbCDV7xT/bkmCfqnYE1s8/Rgm2AvwZ7iTiJt/y78u+xc1s3HHKDt/yrxnXKqjebzcTGKjuCO5/JvBM3vFU1jQOPg6K4YJK/ZV9Ni5nRrY8+3DHStMxZux8DhdMzXYpOhe4R1W7/eN23Ml/VjXQ7v+DB/q9Ze0MTAfPUfr8VlQ1g/ttDPbd77e1b0B+nZL/OHd+JCIf8C2NW/xvr4bec61pDHxONBuYlv3N+m0/Qt/feP5rLhXXZWw42/ZLRL4CHAe80r/vcPDngn/c3/HVQPstw7Uiuk9VvzCM9aeIyE2+lXQrLsiefd8GqxsHq1+HsltVuwJlKBeRH4jr8toK/BOo9S2GZuL+L/bn70RVtwP/Bl7mu49dwOC9YkadBXaKnw4x79XAJbjoag0uIg3uj3607MZdDZsRmDfzEPexB3fiFewbPgsXfUdVd6rqW1V1Gq4lz3ez/SlV9ZuqehIuqr2I/pN8GWPCdwPuCje4uuoGABG5QETu812kmnEHXwMFowc7MOvvpGian/4KronwX3x3rw8DqOpGXFPuT+KCwzeJyDSMMePFDmCiiJQH5uWOUURkKi7o+zCAD/x+SlWXAs8BXoxr/r8V12JnsqrW+lu1qma7LWzFtTzsT7De2grMkv6TeG7n4OOgFH1PMIPHfFuBtwXKU6uqZap6zwDlMGZEichluP/1l6tq0s/O/59+HJc4Oet4P2+w39tQ+vxWRERwv+t
tflaf86X83/mREJdP54O49A8TVLUWaKH3XGsHA58TbQWezvvNVqnqhcN46qG27e8cERH5FC4IcZ6qtgYWPQ4s9+9d1nJ6u2oNynd1/y3QyPCTYX/el3OZqlbjujZln3+wunGw+rUD10osa2re8vz35f241pCn+jKc5eeLf56JEsj7k+cnvsyvAO5VlzMqNBbYKX67cH02B1KFO/DYi/uSf360C+SvkP0a+KSPgi5meJVy/j5uBj4nIlXikv+9DxfJRUReIb2JrfbjfqQZETlZRE710f0DQBeuCZ0xpvDcApztf8svBW7wBwa/Av4HmOIPkG5n4GD0YIGd/k6KtgOoapuqvl9V5+H6Zr8vm0tHVW9Q1TP8tgp86fBfojGmkPjWgqtwxygJETkd160i6wJc9w6F3BC/y/zV21bcRaeMqu7AtSb8qohUi0t0PF9Enuf3cy3wARE5SZwFIjJbRObiuqSs8+s9gDvx+6KIVIhLHvtcv+xG4L/EJXuupDdnRn9XsMF1L71a/CAZ4pI7v+JI3zNjhkNc0thvAZf6VrJZF9CbXwdcOoX3ich0f+Hk/bjuMgP+3obx9DcDF4nIuf4c4P24859sUDP/fKnP7/wIVeECrruBmIh8nL4tX27G/S4niMh04KrAsgeANnG5/cp8T4vjROTkYTzvUNvuAuaIzx8IICJX4y6kvUBV83PS3IVrtfxucTkKs+XMJmaPiMsbFHcPpVT8iFn+Pf8lriXTFb7F1HBU4VoKtfj3JngxfrC6sd/61S9bDbzavx/n47r0D1WGTqBZRCYSyJ3k6/k/4hoQTBCXIPmswLa/BU4E3oP7XofKAjvF7wvAR/1V7Zf3s/ynuKvU24AncPkmxsJVuBZCO3HNKm/EVbCH4l244Mxm4F+4q/k/8stOBu4XkXZcIrH3qOpmXEX6f7hgzzO4gNZXjuiVGGNGhT/wuwuXtO5pf6KTwHWF2g2kROQCXELSg/RzgpTvRlz9WCdu5IyP0xscfrE/EBDclbU0Ljh8jIic4wNMXbg/ewsOGzO+vAaXv2EvLgfgL+g9Rskf/ngq7oSlFZcw9R+44xpwF60SuOOr/X69BgBVvQXXJeUGXFeG3+ISdvbZv7+QdTEuv+CzuKvdr/KLf+Sf65/A07g66V0DvShV/Q0uEH2TuC4Fj+FOYI0ZC5fgkpf/S3pHxroT183q2cB6P8B111mL+47+gd7k5oP93gakqk/iWk58C9fq/2JcsuQev0rufElEPsDIDnP+Z1yumqdw5x5d9O1u9Wnc7/pp4K+419fty53GtUo6wS/fgwta1Az1pMPY9hZ/v1dEsi2TPo+7yLUx8Bl9xO+vB5fj5vVAMy7f0aWB9/As3DHR7fQmGf6LX5ZtXXUeLkCS3fdA3eizPoULjLTgvge/znt9/daNg9Sv4IIsF/vX8Bq/bDBfB8pw7999HDzK6utwAcb1uHyy7w2UsRN3MXJusOxhkZEJVBozOBH5EjBVVa8IuyzGmMIhIq/DBaA/qKpf8fP+ExeEKcEd/MWBjar6URE5G5cMcIa/mrRYVa8K7G8L8BZV/au/svRlXBNZcAc5H1TVLhH5L9yffx3uhOwHqvoZEVmOOzBagvsjvwe40velNsaMQyLyC9xB+2dwF6Tm5XVRGMnnuh34tqqO1EmlMQVL3NDik1X1g2GXJUtc155R/Z0P8fzvAC5T1aFakpgi4FtoLVLV1w658miXxQI7ZjSI636VwEXjT8ZFd9+iqr8Ns1zGmPHDTpCMMYfDd1XYh7vKfR7uiu7puNbNL1PV743ic38Q+Ja/0mvMuCYirwTWDtKydsyJSD2j/DvPe74GXDewe3EjgP0Bd+zy9bF4fjN6fNetR4DXqeo/Qy+PBXbMaPAHTTfiEpXuwg1N/kXgDFxfxYOo6kCjbxljzEHsBMkYczhE5GLgu8AkXPP+L6jqj8MtlTFmPPK5X/6A667TDNyEGwq8Z7DtxgsR+T6um1y+61X17WNdnpE
iIm/FdeP6WaG8DgvsGGOMMcYYY4wxxhQpS55sjDHGGGOMMcYYU6T6Gxd+TEyePFnnzJkT1tMbYw7DQw89tEdV68Iux0iwOsiY4jNe6iCrf4wpPuOl/gGrg4wpRkPVQaEFdubMmcOqVavCenpjzGEQkWfCLsNIsTrImOIzXuogq3+MKT7jpf4Bq4OMKUZD1UHWFcsYY4wxxhhjjDGmSFlgxxhjjDHGGGOMMaZIWWDHGGOMMcYYY4wxpkhZYMcYY4wxxhhjjDGmSIWWPHm4fv/odmrL45y5cFwkoTfGjBIR+S/gLYACa4E3qmrXkexzy54D/G19E689bRYlsehIFNMYY8aNve3dTChPEIlI2EUZli17DlBVGmNSZUnYRRmWprYu2rpSzK+rDLsow9KVTLOxqZ3jpteEXZRhaelM8ttHtnHuknpmTCgPuzjGmAKlqrR0JkmmlXRGSWUy/t4/Hmh+RklnMoHlfeenMzCxIsH5x00dkXIWfGDn6399isVTqy2wY4wZkIhMB94NLFXVThG5GbgMuO5I9rt2Wwufue0JTps3kWOnFceBqjHGjIWWjiRnfOlOPn3Jsbxi5cywizMsb7/+IRZPreLrl60IuyjD8tnb1rF6azP//ODzwy7KsNz4wLN89g/ruO/qc6mrKvzg2bodrXzi1seZO7nCAjvGHOUyGWVXWxdP7znAM3s72LL3AM/s8fd7O+hMpkfleVfMqj16Ajul8Shdo/RGGmPGlRhQJiJJoBzYfqQ7XNJQBcC6HW0W2DHGmICnmtroTKZ5ZGtz0QR22rpSPLGjNexiDNu6Ha08u6+Dtq4kVaXxsIszpCe2t5LOKBt2tRVFYGdjUzsA8+uLo0WUMebIpDPKjpbO3sDN3g4fyHHT3alMbt1ENMLMiWXMmVTBc+ZPZlptKSXxKLGIEI1I4D7S+zg6wPz+tvHrlsRGLjNOwQd2yuLRUYuQGWPGB1XdJiL/AzwLdAJ/UdW/HOl+50yqoCQWYX0RnQgYY8xYyJ4Ub9zVHnJJhi+VyfD0ni5S6QyxaGGnmUylM2zZewCATbsPcMLM2nALNAwbd7fn7p+zYHLIpRnapt3tlCeiTKspDbsoiMj5wDeAKHCtqn4xb/n7cN3NU8Bu4E2q+oxfdgXwUb/qZ1X1J2NWcGMKTCqdYUdLVy5gs2VvR+7+2b0d9KQDwZtYhNkTy5k9qYLnLapj9qQK5kyqYPakcqbVlhEtkm7GWQUf2CmNR+noSYVdDGNMARORCcAlwFygGbhFRF6rqtfnrXclcCXArFmzhtxvLBrhmKlVrNtpgR1jjAnKBnaeampDVREp/APgdEZJppVn93Uwr8Dz1jy7r4NkWgH3Xhd6YEdVe4N9TcUR7NvY1M78usrQv7siEgW+A7wQaAQeFJFbVfWJwGqPACtVtUNE3gF8GXiViEwEPgGsxOUYfMhvu39sX4Uxo0tV2d+RZFdrF7tau2hq7XbTbV3sau2mqdXd727vJp3R3Hal8QhzJlUwv66CcxfXu+DN5HLmTKpganVp0eSIG44RC+yMRuJScIGdvQd6jnQ3xpjx7QXA06q6G0BEfg08B+gT2FHVa4BrAFauXKn5O+nP4qlV/HVdU9GcuBhjzFjInrw3dyTZe6CHyUWQkDgYKCn0wE4wOLKhqS3EkgzP7rZu2rrchdhiCexs3n2Ak+dMCLsYAKcAG1V1M4CI3IS7WJUL7KjqnYH17wNe66dfBNyhqvv8tncA5wM3jkG5jTliqkpbdyoXmNkVuG9qC0y3dvdpbZNVWx5nSlUp9dUlLJxSxZTqEmb5VjhzJlVQX1UyroI3gxmRwM5oJS4FKEtE6bauWMaYwT0LnCYi5biuWOcCq0Zix0saqrl5VSO727qprw6/ubYxxhSCjU3tTK4sYU97Nxt2tRdFYCd7FXfj7nbOC7ksQ8l2a5peW1YU3d2ywZxpNaVsKILAzoHuFNuaO7m8viDyQ00HtgYeNwKnDrL+m4E/DrLt9P4
2OtRWy2Z0dPSkePiZZvZ39FBTFqemLE5tubuvKo0XXfefgaTSGfYe6GF3W3fvrb27z+Ns4Ka/tCtVJTHqq0uYUl3KyXMmuumqUqZUlzLFz6+rKqE0bqPWZo1kV6wRT1wKUBqLWI4dY8ygVPV+Efkl8DCu//kj+JY5R2pJQzUAT+xotcCOMcbgTky2NXfymlNn8fP7n2VDUxunz58UdrGGlPRXe4uhRcnGpnamVJewYlYtaxpbwi7OkLKBqBcdN5Uf/3sLLZ1JasoKN+Hz5t0uf1GxDCWfJSKvxXW7et6hbns4rZbNkTvQnWLVM/u5f/Ne7tu8lzWNLaQy/b/9IlBZEssFemrK4tSWJaguix8UBOpzK49TVRIb9Zbl2WG/BwrUBB/v6+hB+3mZVaUx6qpKqKss4bjpNbxgiQvWZIM4U6pLqa8qoaKk4DPGFJwReceGm7j0cCLFZQlLnmyMGZqqfgLXz3xELZnqAjvrdrRx9jH1I717Y4wpOtmT4ucumMytq7ezoQhalECgxU4RBHY2NbWzoL6ShfVV/GHtDrqS6YK+Mr2xqZ3KkhjPnT+ZH/97Cxub2jlpdkF0c+rXJh+IWlAYI2JtA4JNh2b4eX2IyAuA/wc8T1W7A9uenbftXaNSSjMsbV1JH8jZx32b97J2WwvpjBKLCMtn1PDWs+Zx2rxJNNSU0tKZpKUj6e47kzR3JmnNTnf00NKZZGdLKy2dKVo6e3LdSfsTEaguixOPRhAgIoJI7312OiKCgJ8nRPx8/DaRCAhuPn55OqPs8YGb/sqQiEWoqyyhrqqEmRPLOXH2hNzj3M0/LuR6rNiNVFesYSUuPZxIsQ13bowJU015nGk1payzkbGMMQboe1K8YEplUeSAUdXcVfJNTe0FnTdNVdm0+wAvO3E6C+orUXXv+bHTasIu2oA2NrUzv76ShVNcoGRTgQd2Nja1E40IsydVhF0UgAeBhSIyFxeouQx4dXAFEVkB/AA4X1WbAov+DHzen4sBnAdcPfpFNlmtXUlWbdmXC+Q8tr2VdEaJR4XjZ9Ty9ue5QM5JsydQnjj8U29VpTOZ9kGf3mBQfnAolVFUFVXIqJJRUNxj9Y8zqqjfZ5/1/DpKcJ4SEWFhfVW/gZq6qhKqS0e/tZAZ2ki1cRpW4tLD4QI7mYL+AzbGjG9LGqpZbyNjGWMM0HtSPGdSBQvrK/n7+qahNwpZtrXOlOoSdrV2s6Oli2m1ZSGXqn+7Wrtp7065Fjs+ULKxqfADO2curGPGhHISsUjBB/s2NrUze6Ira9hUNSUiV+GCNFHgR6r6uIh8GlilqrcCXwEqcRfPAZ5V1Zeo6j4R+QwuOATw6WwiZTM6WjqTPPj0Pu5/ei/3bd7H49tbyCgkohFOmFnLO8+ez2nzJnHirAmUJUaudYqIUJ6IUZ6I0VBTmHWXCddIBXZGLXFpadxVuN2pjDXdMsaEYklDNXc9tbvgm8IbY8xYCJ4UL5pSxc2rGtl3oIeJFYmwizagbGudxVOr2dW6m41N7QUb2Ml2FZtfX8mcSRVEI1LQ3d1au5I0tXWzcEol0Ygwb3JFwXd327TbtTAqFKp6O3B73ryPB6ZfMMi2PwJ+NHqlO7o1d/TwwNP7uP9p1yLniR2tqLruRytm1vKucxZy6ryJnDhrgh0jmlCNVI6dUUtcWuZ/IJ09dkJljAnH4oYq0hllY1M7x00v3CumxhgzFrLdbqA3R8mGXW2cOq9wEyjnAjsNVfzjKRfYOWtRXcil6t9G39plQX0liViEOZPKC7oFTDaIs8AnIl44pYrVW/eHWaRBpdIZtuw9wLlLpoRdFBMSVaW1M8Xu9i6a2vJHauqbDHjfgR4ASmIRTpw1gfecu5DT5k3ihJm1dm5qCsqIpZsercSl2cBOV8ry7BhjwpEdGWvdjlYL7Bhjjmr5J8ULp1QBsKGpvaADO2mf8HNKVSm15fHcKE6FaOPudqpLY9T5IeQX1FcWdAuYXGAnG+yrq+S2Ndv
p7EmPaFeUkfLsvg6SaS2UxMlmBHWn0uxp7+kznPZAAZueVOag7ROxCPU+b8zsSeWsnDOBabVlnDxnIsfPrKEkVnjfZ2OyCn4csdJAix1jjAnDnEkVlMYjrNtRuFdMjTHjzy8efJbWzhSKkohGqC1PUO5HCz3QnWZ/Rw/7D/QwqbKEeXUVpDPK03sOsP9ADyXxCNNqy3j1KbMOylH41K427tm4h/buFMm0Ul0WpywepbXLjcgCEI0IquQScUYikhttJXhSPK2mlIpElF88uJVELEJDTSkbm9ppautmUkWCiRUJulMZ2rtSdCbTdCbTlMai1FeXUBqP0NKRpK0rRTKjpNIZupIZOpMpdrf10NTWRWksytJp1UyrLaW9O01nTwpw+SaS6QzJdIZXrpzJ8hm1fV5jOqP8dd0uzls6BREhlXEncfGosKCub6Akk1FuW7uDHc2dtHQmc8OiR0SIRYWyeJTa8gQVJVHaulK0daWYWJFganUpnck0O1q62Lqvg2f3ddDenWLWxHJmTCijIhGjJB7Jvf/dyTTt3ancCeWcyRW8cmXvYEiPPLufDbvauW/zPhbUV+a2W1hfxV/XNbG3vZtJlSWoKresamTT7vbcyLFl8ShliSiVJTFKYhFau1K0dCbpSqZJpjNEI0JJLEo6o7R2JmnrTtHR48pSUxZnYkWCsniUeDRCtX8MsLOli7auJBMqEtSWxTnQk6atK8W02lIWT62mM5nm9rU7SMQizJxYDpBL+Hz1r9cwpbqUkniU0niERDRCIhYhk1E/+k+KnnSaZKp3PJXaijgN1aWUJ2J0p9J09KRzowXtbe9mf0eSWRPLWTa9hsk+8JXKZOhOue9COqNkMupykZRESaYzHOhO09GToqMnTTqjbGvuzJXTjA+Pb2/hP3/+MFv2dvS7fGJFIhewmTe5grrqvsl/66tKLQmwKXrFE9ixkbGMMSGJRoRjplbbyFjGmDH1rb9vpHF/56DrlMYjdCX7XnkuT0TpTrmT3DMX1DFrUnmf5R/77WPc/3T/+VVjETc0birjRkKJRtzQuKrQk+59nuUzXOtFEeHKs+bzs/u28MFfrsktjwhk+hn/NBaRXLeofPGoUBqLUpqIMqkiQX11KR3dKW5ZtZUDPb0BDHCjvMQjEdp9cOLLL6/ts697N+3lbT97iFuvei7LZ9TmnjMaiTB9Qhmrtzbn1l27rYV33/hIrnzxaATxQ/ymMppLvDyYikSUWZMqqCyJcveG3exq7R5w3VhEciPOXHLCNEpiUdq7U7zse/fk3rN3n7Mgt/4Fy6byf3dv5q0/XcXP33IaX/zjOn5y7zOUxCK5FjGdPWm681oglMYjlMajxCIRVJWuZJpIRKgpi1NZEqOyJEYsKmxr7uKxba10pdL0pDJ0BC6mRiNCRSJKa1eqT/nzP8Pzj51KNOJOiE+eM4EF9ZX846nddPRTrqxsECoWFQT3njR3JPt8z7LPV1MWZ1JlgpqyOHeub+KXDzUO+P4Ox8SKBAstsDMuPLhlH2/68YNUlcZ4/wsXHRSsmVSZIB4NP0m2MaOtCAI77oeYf9BijDFjacnUKv70+E4boc+YcU5EZgI/BaYAClyjqt8QkYnAL4A5wBbglaq6X1yF8A3gQqADeIOqPjwSZfnTe8/K1Tk9qQz7O3pyOQcrSqJMKE9Q6lvaPL37ANGIMHdyBRUlMe7esJvX/fABdrR0HhTYaetKcfYxdVzzupVEI0J7V4qOZIrq0jjlieiAdVw6o3443Qz1VaW5+e95wULedc4C1u9so7mzhwX1ldRVltDamWJfRw+l8QgVJTHK41Fi0QjdqTR723voSqapKYtTVRonHpUBnzeTUdp7UlQkYrngQdb5X/8n+zuSB23T3u3mZYMU2UBELCLEIhFS6d7ARHadH7/xZM5eVHdQObpTaZo7khzoTlHtgyL7DvSwo6WL8kSUhppSasrifbZLpl0rkq5k2g0hjLrPzb+Ga+/ezGf/sI6uZMY
FdrpSZBQ+fMFirjh9Tp8uTMdOq+HrrzqBd97wMC/42j/Y1tzJW8+cy0cuXNLnOdMZ5UBPiq5kmurS+GHn/+hJZXJ5ReqqSohGXOuo1s4kFT4Ys7O1i/U72iiJR1jaUE1teW/i7PrqUv76vuflHqsq3akMPekMPakMEXGBmvzPEtxnva+jJzdYQVk8etB3UlXZ2dpFmw82RUQoibnWQC4wKbkWOvFohPKE20dZPEo0IqQz7jfV3/Ob4nLXk028/fqHmFZbxvVvPrVgE6IbMxYKPrCTy7FjLXaMMSFa0lDNTQ9uZVdrN1NrSofewBhTrFLA+1X1YRGpAh4SkTuANwB/U9UvisiHgQ8DHwIuABb626nA9/z9EassCRymlTDgqFPVpXGOn1nbZ97UaldP7WztOmj9zmSaqtJ4bqjnmvI4NcSHLE80IgOWIRIRlk6r7jOvpjxOTfnB+y2JRQ/pBCwSEapL+y9fTVmcln4CO9kLgtkATjbHTiwqxPzJfVZ2urKk/24YJbEoU6r7Bkmm1ZYN+hri0QjxaKTvZxiQDdxkg1vZlulTqkv6zUtzwbIGPnrRUj5z2xO85YyDgzrgPp/q0viA79VwJWKRg/7n4tEIk3zXJ4CGmrJhD7ksIpTGo8MKNEUikutiNdj+3PMPvM5gI7TFohbQGQ9uW7Od//rFahZNqeInbzplyO+NMeNdwQd2Si2wY4wpAMEEyhbYMWb8UtUdwA4/3SYi64DpwCXA2X61nwB34QI7lwA/VVUF7hORWhFp8PsJTb0P7OzqL7DTk6Z8nIzmMqE8waZ+EiF3+0E3kj63TvY+GnF5c7I5d/KXjZXSWN/j265k365m/XnzGXO5aFkDU6pLrOWoOard+MCzfOQ3a1k5ewI/fMPJRxzMNGY8KPgOh7m+wxbYMcaEaHGDG/nlCcuzY8xRQ0TmACuA+4EpgWDNTlxXLXBBn62BzRr9vPx9XSkiq0Rk1e7du0ev0F51aYyyeLTfXC8dPamCHK3ocNSWx2nuHEaLHd8qJ9uSJpgjJrtOPDJ2h8X5x7fZ+6FatUytKbWgjjmq/eAfm7j612t53qI6fvqmUy2oY4xX+IGdXIsdy7FjjAlPdWmc6bVllkDZmKOEiFQCvwLeq6p9fvi+dc7QGXX7bnONqq5U1ZV1dXUjWNL+iQhTa0r77YrVlcyMo8BOgpaOJO4j6ZVtsZPyiXizI11FIy63SjDHTjqEFjv5x7ddPcML7BhztFJVvvyn9Xzhj+t58fIGrnndynFTjxkzEgo+sFPikydbix1jTNiWNFSzfqcNeW7MeCcicVxQ5+eq+ms/e5eINPjlDUCTn78NmBnYfIafF7op1SXsaukb2EmlXRLbwbr8FJPa8jg96b4jOUFvwCSZ6dtiJ+a7YiUDIy8lsy12xjD3Su741pe7KzV0VyxjjlaZjPKx3z3Gd+/axOWnzOIbl63I5QgzxjgF/4vIXdHoscCOMSZcSxuq2Ly73XJ+GTOO+VGufgisU9WvBRbdClzhp68AfheY/3pxTgNaws6vkzWl+uAWO9kLZeXj5Er3BJ+cOb87VraezrbGyY2KFY0MmDw5lBY7PqDT2ePKaS12jOkrmc7wXzev5vr7nuXtz5vP5196nI1oZkw/Cj6wY8mTjTGFYnFDNRmFp3ZZqx1jxrHnAq8DzhGR1f52IfBF4IUisgF4gX8McDuwGdgI/B/wzhDK3K+p1aU0tXb36aaUbSEyXrow1JS50Y/2++G5s7pT2S5Y7rVnu17lhjvPaO59ybbeiUfH7rC4NO/C5XCSJxtztOlKpnnbzx7id6u388Hzj+HDFyy2HFPGDKDgR8WK+ysr1hXLGBO27MhY63e0sXxGbbiFMcaMClX9FzDQmcO5/ayvwH+OaqEO05TqUnrSGfZ3JHPDP3eOswBCtsVOywAtdrIBnewoWC6w4z7edEaJRSXUFjsHJ08u+GuuxoyJtq4kb/7JKh7cso/PXnocrz1tdthFMqagFXx
gB9xVDUuebIwJ2+yJ5ZQnojYyljGmKEytcUOe72zpygV2srloxktXrNpy32Kno/8WO9mATq7FTlSI+ZY5qYwSi/bm4YmNYY6d0vzkydnAzjj5XIw5Envbu7nixw+wfkcbX3/VCVxywkEDDRpj8hTFZYHSeNRa7BhjQheJCMdMrbKRsYwxRWFKdQkAuwJ5djrG2ehLtdkcOx39t9hJ5g13Ho1Eci12snl30tmuWGM53Hleix3rimWMs6Olk1f+4F427GrnmtefZEEdY4apKAI7ZYkI3RbYMcYMQESOCeTCWC0irSLy3tF4rsVTq1m3o/WgoXWNMabQTKl2LXaCgZ2uXPLkomi0PaSasmxgZ4AWO+m85Ml+VKz+lkVDGBWrK9AVKxqRMc3zY0yheXrPAV7+vXvZ1drNT990CucsnhJ2kYwpGkXx71EasxY7xpiBqeqTqnqCqp4AnAR0AL8Zjeda2lBFa1eKHXlDCBtjTKGpr/JdsfppsTNeWoaUxqOUxaMDttjJBm2yQZxYVA5qsZO9H8sWOyWxCCK95exKjp8h6I05HE9sb+UV37+XzmSaG996GqfOmxR2kYwpKkUR2ClLRG1ULGPMcJ0LbFLVZ0Zj59kEytYdyxhT6BKxCJMrE31a7OSSJ4+jXC4TyuMHDXfeOypWfoudSG+OnXTfoM9YJk8WEcri0T4tdixxsjlaPfTMPi675l7iUeHmt53Oshk1YRfJmKJTFP8g1mLHGHMILgNu7G+BiFwpIqtEZNXu3bsPa+fHTK0CLLBjjCkO9VWl7Grtzj3u7EkB4yuwU1OeOLgr1iCjYkVzLXYO7qY1loI5JLt60uMm75Exh+KfT+3mtdc+wMSKBLe8/XQW1FeGXSRjilJxBHYSUTptVCxjzBBEJAG8BLilv+Wqeo2qrlTVlXV1dYf1HFWlcWZOLGPdzrYjKKkxxoyNqTWl7Ax0He3Mjoo1joIIE8rj/XTF8i128kbFcnlspM+8VFqJiEuQP5bKAqO+dqXS1hXLHHVuX7uDN//kQeZMruCWtz+HGRPKwy6SMUWrOAI7MUuebIwZlguAh1V112g+yRKfQNkYYwrdlOrSvqNijcOuWLXl8X6GO+/bYic7KlY8GiEa6R3uPHsfCyFpcUk8kmux02ktdsxR5uYHt3LVDQ+zfEYtN115GnVVJWEXyZiiVhSBnbKEdcUyxgzL5QzQDWskLWmoZsueA7kr38YYU6imVpey90BPLtDR1ZNGxCXvHS9qyxO0dPbfYiebPyeZCbTYye+Klc6MeTcscC12ui15sjkKXXv3Zj74qzWcsbCOn735lNzodsaYw1cU/+rB5HLGGNMfEakAXgj8erSfa0lDFRmFJ3dZdyxjTGGbUu2ugu9uc3l2Onpclx+RsQ9kjJbaMtcVS1Vz87KBrGxAJ53uJ8dOOtBiJ4TATjDHTmcynRsC3Zjx7P/+uZnP/mEdFy1r4NrXr6Q8EQu7SMaMC0XxD1Iaj9qVcWPMoFT1gKpOUtWW0X6u7MhY6607ljGmwE2pcUOeZ7tjdSbTlI+jblgAE8oTpDJKe3cqNy/bYied7jukeSwqxKP5XbEyoXTFKgsc33YlLceOOTr87L5nOG3eRL55+QoS46jloDFhK4pfU2k8SlfKkicbYwrDzAnlVCSilmfHGFPwplS5wE6THxmrsyc9rvLrANSUu24c2QTKqhposXPwcOe9LXZ88Ce0FjuR3uTJScuxY8a//Qd6eHZfB89bVJ/7HRpjRkaRBHYi9KQyucR3xhgTpkhEWNxQzbod1hXLGFPYarNBD5+DpnMctgypLesb2EmmlewhY37y5FhUiGVHxfLzkunwumJ1BbpijbfPxQyfiJwvIk+KyEYR+XA/y88SkYdFJCUiL89blhaR1f5269iV+tCt2eYaVR8/oybkkhgz/hRFYCf7R2d5dowxhWLx1CrW7Wztk9PBGGMKTTYpaTa5cEdPmrJxltNiQkUCgOZONzJWV6r3eDGbIDn
pW+dERYhlR8UKBH3C6orVFUyePM5aUpnhEZEo8B3cyJ5LgctFZGneas8CbwBu6GcXnap6gr+9ZFQLe4TWNjYDcOx0C+wYM9KKIrBTaoEdY0yBWdJQTVtXim3NnWEXxRhjBlSeiBKPSi6w09mTpmycJenNttjZ71vsdCd7u+8nA8GbiLgWl70tdnqDPpY82YToFGCjqm5W1R7gJuCS4AqqukVV1wBFnZvi0cYW5k2usFGwjBkFRfEPkm2xY0OeG2MKRTaBsnXHMsYUMhGhxo8aBdnkyeOrxU5tuWux09LhW+wEjhdzw52ne1vlxPJGxXItdkIY7jwRpSvpUg30pGy486PYdGBr4HGjnzdcpSKySkTuE5FLB1pJRK70663avXv3YRb1yKxtbGGZdcMyZlQURWCnNJFtsVPUQWpjzDhyzNQqAEugbIwpeNVlcVpzXbFS4y6AUJPfYicw4EY2j04609sqJ9cVq0+OnbE/JC6NRehMpnOBKEuebA7TbFVdCbwa+LqIzO9vJVW9RlVXqurKurq6sS0h0NTaxc7WLpbPqB3z5zbmaFAcgR0/FJ51xTLGFIrKkhizJ5WzfqcFdowxha22LJ7rijUec7kkYhEqS2K5VknB48Vsbp1UYOSr/K5YbrjzELpi+c8h+9mMt4CbGbZtwMzA4xl+3rCo6jZ/vxm4C1gxkoUbKY82WuJkY0ZTUQR2yhKWY8cYU3iWTLWRsYwZb0TkRyLSJCKPBeb9IjDqzBYRWe3nzxGRzsCy74dW8EHUlMVziYXHY4sd8K/Rd8Xq02LHd7dK9dMVK50JdMUKI8dOzH0O+325x+PnYoblQWChiMwVkQRwGTCs0a1EZIKIlPjpycBzgSdGraRHYG1jMxGBpdOqwy6KMeNSUQR2Si3HjjGmAC1pqGbL3gN09KTCLooxZuRcB5wfnKGqr8qOOgP8Cvh1YPGmwIg0bx+7Yg5fTaDFjsuxM/4CCBMq4rkh3bv98WJJLELSB29SGSWa1xUrm1jZJU8OYVQs/zlkWxpZ8uSjk6qmgKuAPwPrgJtV9XER+bSIvARARE4WkUbgFcAPRORxv/kSYJWIPArcCXxRVQsysPNoYwuLplSNuxxfxhSKEftliUgtcC1wHKDAm1T13pHYdy55co8FdowxhWNxQxWq8OTONlbMmhB2cYwxI0BV/ykic/pbJiICvBI4Z0wLdYRqyxO0dCTJZJSuZGZc5nKpLUvkWr5kW+xUlsRyyZNT6QzxvK5Yad8VK51R4iEMd17qAzn7DliLnaOdqt4O3J437+OB6QdxXbTyt7sHWDbqBTxCqsqaxmZeuHRK2EUxZtwayX+xbwB/UtXFwPG4iPOIyP7xdaUsebIxpnAstZGxjDnanAnsUtUNgXlzReQREfmHiJw50IZhjkhTXRantSvFAd+6cDy22KkJJIjOdt2vLI31GfkqGu0b2OltsdPbmmcsZQM52S5k4zHgZgxA4/5O9nckWWaJk40ZNSMS2BGRGuAs4IcAqtqjqs0jsW/o/aPrshY7xpgCMmNCGZUlMRsZy5ijx+XAjYHHO4BZqroCeB9wg4j0m0AizBFpsqNGNbV1A4y75MkA1WUxWrtc4Kor5Y4XKxIxkpne5MnxSDbHjrsP5tgJp8VONseOT548Dj8XYwDWWOJkY0bdSP2LzQV2Az/2V62uFZGK/JUO92pV9opG9o/aGGMKgYiweGqVBXaMOQqISAz4D+AX2Xmq2q2qe/30Q8AmYFE4JRxYrQ/s7GzpAsZnl5/q0t4WO91J3xUr0GInlcn05tjJtdjJ5O7DaLHTG9jxLXZi4+9zMQZgzbZm4lHhmKlVYRfFmHFrpAI7MeBE4Hv+qtUB4MP5Kx3u1apSy7FjjClQSxqqWb+zDVUNuyjGmNH1AmC9qjZmZ4hInYhE/fQ8YCGwOaTyDSjbYmeHD+yMx+Sl1WVxulMZulPpXFesqpJYrlVOKh1Mnpwd7jzYYifMrljZFjuWPNmMT2u2trC
koZoSC14aM2pG6h+kEWhU1fv941/iAj0jItcVK2k5dowxhWVJQzXt3Ska93eGXRRjzAgQkRuBe4FjRKRRRN7sF11G325Y4Lqhr/HDn/8SeLuq7huzwg5TTXm2xY6rp8ZjAKG61AWr2rpSueTJFSWxXKucVKC7VX5XLDdiVphdsSzHjhm/MhnlsW0tLLduWMaMqhG5ZKOqO0Vkq4gco6pPAucCIzbUXjQiJKIRG+7cGFNwljS4ZsVP7Ghl5sTykEtjjDlSqnr5APPf0M+8X+GGPy9otXktdsri47PFDkBrZzJ3IbCiJJZrldN3uPO+XbFSmd4Rs8ZSWV6OHQvsmPHo6b0HaOtOsXx6bdhFMWZcG8nLE+8Cfi4ia4ATgM+P4L4pjUdyTWuNMaZQHDO1ChEsz44xpmDV5OfYGYdJeqt8i53WrhTdqTTxqFASi+SCN+lMJtfdKhIRIkK/3bTGUnbU1+yoWOMx95ExaxqbAVg+01rsGDOaRuySjaquBlaO1P7ylcajFtgxxgxIRGqBa4HjAAXepKr3jvbzlidizJlUwXob8twYU6CqD8qxM/4CCNWlfVvslMSixKOSS56cP6R5LBLJDXeeyiixMEbF8p/D/gPWFcuMX49ubaEsHmVBXWXYRTFmXCuatrhliah1xTLGDOYbwJ9U9eUikgDGrF/UkoYqHt9uLXaMMYWpNB6lNB5hZ+s4HhUr2xWrK0lXKk1pPEIsGiGVybbYURKB1x2LCunsUOjpTK571ljKjoLV2pUiEY2E0mrImNG2dlsLx06rDiV4aszRpGh+YWXWYscYMwARqcElMf0hgKr2qGrzWD3/kqnVPLO3g/bu1Fg9pTHGHJKasjj7fMuQ8dgVq7fFTorubIudiJBMK6rqW+X0Bk6ifhlw0LKxEo9KLpiT7ZZlzHiSSmd4fHsLy2fUhl0UY8a9ovkXKYlH6bRRsYwx/ZsL7AZ+LCKPiMi1IlKRv5KIXCkiq0Rk1e7du0fsyRc3VAPw5E7rjmWMKUzZPDswXlvsZHPsuBY7Jb7FDrjWOvmtcuLRSJ8cO/EQWhOISO6zsG5YZjza0NROVzLD8ZZfx5hRVzSBnTJLnmyMGVgMOBH4nqquAA4AH85fSVWvUdWVqrqyrq5uxJ48OzKWJVA2xhSq2rJEbno8BnbK4lFiEaG1M0l3MkNpLJprhZPKKOmM5oY5B9diJ9hNK6xuUNmWOuOxFZUx2cTJy6ZbYMeY0VY0gZ1JlSVsb+4MuxjGmMLUCDSq6v3+8S9xgZ4xMb22jKrSmAV2jDEFK5uDpiQWITIOc7mICNVlcVq7knRnW+wEhjVPpjNEA92t4pFAYuWQhjuH3pY62Xw7xownjza2UFXqBpkwxoyuognsLJ9eQ+P+Tva2d4ddFGNMgVHVncBWETnGzzoXeGKsnl9EWDK12gI7xpiCle2KNR5HxMqqLo3R1pXqbbHjW+ik0tkWO4EcO1EhlVEyGUUVopFwDolzgZ1x/LmYo9faxhaWTa8Zl8FkYwpN0QR2jp9ZC8CabS3hFsQYU6jeBfxcRNYAJwCfH8snX9JQxZM728j4nA3GGFNIsoGd8dgNK6u6LO6GO/ctduK+hU4yk3EJkgPBm3gkQiqjJH13rDCSJ0Pv51EaK5pDcmOGpTuVZv3OVkucbMwYKZp/keOm1yACa7ZaYMcYczBVXe3z5yxX1UtVdf9YPv+ShmoO9KTZur9jLJ/WGGOGpbbcB3bGccuQqtIYrcEWO9HeFjupdF6LnYiQSmdyCZTDGO4cLMeOGb/W7WgjmVaOn2H5dYwZC0UT2KksibGgrpJHfRIuY4wpJEv8yFjWHcsYU4h6u2LFQi7J6Kku7dtiJxusSWcOHu48FvUtdnyenVgIo2JBb1es8dySyhyd1mYTJ1tgx5gxUTSBHXDdsdY0NqNqXR2MMYVl0ZQqIuKuUBljTKE5KrpilbrkyV3JNKWxaG4I82Q6QyrTd7jzWMG02LHhzs349GhjC5MqEky
vLQu7KMYcFYorsDOjhj3tPWyz0bGMMQWmLBFlzuQKa7FjjClINUdBV6zqshitnSm6UxlK45G+w52ntU+C5JhPnpxKF0iOHQvsmHFmbWMLy2fUIGKJk40ZC8UV2MkmUG60PDvGmMKzpKGadTstsGOMKTxHS4udzmSaA90pSuK9o2K5FjuaS6YM2RY7SjLkFju9gZ2iOiQ3ZlAdPSk2NLWxzBInGzNmiupfZPHUahLRCI9ubQ67KMYYc5AlU6vYuq+Ttq5k2EUxxpg+jorhzv1rTKaV0ljvqFiptJLKZIj26YoVIZXJkM7m2AltuHOfPHkcB9zM0eexba1kFEucbMwYKqrATiIWYcm0akugbIwpSNkEyk/utDw7xpjCUuuDHqXjOrDTmxi6JB4YFSs73Hn04K5YYQ93nv08LLBjxpM1ljjZmDFXVIEdgBNn1fLwM83cs2lP2EUxxpg+bGQsY0yhyrZmKR/HAYTq0nhuuiQWIe5b6HSnMqjST/JkDSRPDqnFTsxy7JjxZ01jCw01pdRXlYZdFGOOGkUX2Lnq+QuYPamcN1+3igee3hd2cYwxJqehppSasjhP2MhYxpgCE49GuPyUmZy1qC7sooyaqmBgJ9BipyuZBujTFSsayQ53njlo2VjKJrMezy2pzNFnTWMzy621jjFjqugCO5MqS/j5W0+lobaUN133ILvbusMukjHGACAiLJ5axXpLoGxM0RKRH4lIk4g8Fpj3SRHZJiKr/e3CwLKrRWSjiDwpIi8Kp9TD84X/WD6uAzvBrlilsd5RsbqSLngTTJ4cj/Yd7jweVlesWKTPvTk6icj5vg7ZKCIf7mf5WSLysIikROTlecuuEJEN/nbF2JW6fy0dSbbs7WC5JU42ZkwV5b9IfVUpP3jtSbR3p/jNI41hF8cYY3KWNFTz5M42Mv5kwRhTdK4Dzu9n/v+q6gn+djuAiCwFLgOO9dt8V0Ss6UVIqvNa7MQj+S12eg97oxEhnVGS2eTJ0XAOibMtdsbzMPRmcL7O+A5wAbAUuNzXLUHPAm8AbsjbdiLwCeBU4BTgEyIyYbTLPJi129zoxdZix5ixVZSBHYCFU6o4cVYtN69qRNVOoIwxhWFpQzUdPWme2dcRdlGMMYdBVf8JDLev9yXATararapPAxtxJ1cmBNk8QuBawGS7V3X6wE7fFjsRkplMIMdOSC12ssOdxyywcxQ7BdioqptVtQe4CVe35KjqFlVdA2Tytn0RcIeq7lPV/cAd9B+YHjNrtjUDsHx6bZjFMOaoU7SBHYBXrpzJxqZ2Vtvw58aYArG4oQqwBMrGjENXicga31Ure0V8OrA1sE6jn2dCUJGIko3PlMajuUBOtitWNC95cjqtpHyOnbADO9Zi56h2JPVIwdVBa7a2MHtSOTXl8aFXNsaMmKIO7Fy0vIGyeJSbV1l3LGNMYVg0pYqIwHoL7BgznnwPmA+cAOwAvnqoOxCRK0VklYis2r179wgXz4DLc5ZttVMSixyUPLnPqFhRIZlRUtkWOyHl2JleW4YITK2x0YPM6BqrOsglTq4dtf0bY/pX1IGdqtI4Fy5r4PePbqezJx12cYwxhtJ4lHl1lTYyljHjiKruUtW0qmaA/6O3u9U2YGZg1Rl+Xn/7uEZVV6rqyrq68ZvAOGzZPDul8WgukNMb2Ok97I1FIqQzSiqTOWjZWDpueg0PffSFzK+rDOX5TUEYdj1yJNuORR20u62b7S1dLJ9u+XWMGWtFHdgBuOyUmbR3p/j+PzaFXRRjjAFcAmXrimXM+CEiDYGHLwWyI2bdClwmIiUiMhdYCDww1uUzvbIjY5XEI8TzW+xEg8OdC8l0hpRPnhzWcOcAEysSoT23KQgPAgtFZK6IJHAJ2W8d5rZ/Bs4TkQm+i+h5fl4o1mbz61jiZGPGXNEHdk6eM5H/WDGdb9+5kYef3R92cYwxhsVTq9jW3ElrVzLsohhjDpGI3AjcCxwjIo0i8mbgyyKyVkTWAM8H/gt
AVR8HbgaeAP4E/KeqWhPiEFWV+BY7sehBw50HW+XEo+Jb7GSHOy/6Q2JTpFQ1BVyFC8isA25W1cdF5NMi8hIAETlZRBqBVwA/EJHH/bb7gM/ggkMPAp/280Lx6NYWRFxLNGPM2IqFXYCR8MlLjuX+p/fxX79Yze3vPpOKknHxsowxRWppQzUA63e0ccrciSGXxhhzKFT18n5m/3CQ9T8HfG70SmQORZ8WOwcNdx5ssRMhle4N7ITZYscYVb0duD1v3scD0w/iuln1t+2PgB+NagGHae22FhbUVdq5mDEhGBeXJ6pL43ztlcfz7L4OvvjH9WEXxxgTAhHZ4q+orxaRVWGWZYkP7Fh3LGOMGVu5HDuBFjud/SRPjkeFZCaTGxUrHlLyZGPGC1W1xMnGhGhcBHYATp03iTc+Zy4/u+8Z7t+8N+ziGGPC8XxVPUFVV4ZZiCnVJUwoj1tgxxhjxlh2VKzSeD9dsfJy7KhCMn3wUOjGmEO3o6WLPe09HD/TumEZE4ZxE9gB+MCLFjFzYhkf/vXaXLNbY4wZayLC4qnVrNtpI2MZY8xYmjGhjJqyOIlYb1es7tTBo2L1JlbO9HlsjDk8axqbAVhm+XWMCcW4+hcrT8T44n8s5+k9B/jfvz4VdnGMMWNLgb+IyEMicmV/K4jIlSKySkRW7d69e1QLs6Shmid3tpL2+RuMMcaMvtecOpu/vu95RCNCJCJEBDp7+h8VC/rPv2OMOXSPNrYQi0iuO7oxZmyNq8AOwHMXTOaVK2dw7d1P88R26wZhzFHkDFU9EbgA+E8ROSt/BVW9RlVXqurKurq6US3MkoYqupIZtuw9MKrPY4wxplciFqGuqiT3OBaN0JU6OMdOLNK3m1Y8Mu4OiY0ZU2sbW1jcUEVpPBp2UYw5Ko3ov5iIREXkERG5bST3e6g+cuESasviXP2btXa13JijhKpu8/dNwG+AU8IsjyVQNsaY8MUjkmuxE+0vsOODPlFLnmzMYcsmTl42vTbsohhz1BrpyxPvAdaN8D4PWW15go+9eCmPbm3mZ/duCbs4xphRJiIVIlKVnQbOAx4Ls0wL6iuJRoT1OyzPjjHGhCUWjfSbRycW7TsUesy6Yhlz2Lbs7aC1K8XxMyy/jjFhGbHAjojMAC4Crh2pfR6JS06YxvMW1fG529dx94bRzaVhjAndFOBfIvIo8ADwB1X9U5gFKo1HmV9XYS12jDEmRPGo5JIn99tix5InG3PEcomTLbBjTGhG8l/s68AHgcwI7vOwiQjfvHwF8+sqedvPHuKRZ/eHXSRjzChR1c2qery/Hauqnwu7TOC6Y1lgxxhjwhMNdMWKB7pbZVvsdPsWO9Zgx5jDt6axhZJYhEVTqsIuijFHrREJ7IjIi4EmVX1oiPXGbEQagJqyOD990ylMrizhih89wMMW3DHGjKHFU6vZ3tJFS0cy7KIYY8xRKRaJ0JVy1xyjgQTJwRw78aggYpEdYw7X2sYWjp1WbS3fjAnRSP36ngu8RES2ADcB54jI9fkrjeWINFn11aX8/C2nMqEiwWuvvZ9/bdgzJs9rjDFLGtyVq3U7rdWOMcaEIR6V3EAafUbFivZ2xbKhzo05fOmM8tj2FpbPqA27KMYc1UYksKOqV6vqDFWdA1wG/F1VXzsS+x4JMyeWc8vbT2fWxHLedN2DPLatJewiGWOOAkttZCxjjAlVrE/C5INz7HT2pG2oc2OOwMamdjp60iy3/DrGhOqo+SerryrlxreeRm15nPfdvDo3CoIxxoyWuqoSJlYkLLBjjDEhCbbS6Zs82Y+KlUrbUOfGHIFs4mRrsWNMuEY8sKOqd6nqi0d6vyNhQkWCL798OU/taudrdzwVdnGMMeOciLCkoYr1O23Ic2OMCUMw50ewZU400BUrZi12jDlsaxpbqCyJMW9yRdhFMeaodtT9k519TD2vOXUW/3f3Zu56sins4hhjxrklU6t5cmcbqXRBDBhojDFHlWD3q2DLnGyQpzu
Z7tOqxxhzaNY0NnPc9Goi9jsyJlRHXWAH4P9dtITFU6u56oZHWG9JTY0xo2hJQzXdqQxb9h4IuyjGGHPUifczEhb0dsvqSqb7BH+MMcPXk8qwbkebdcMypgAclYGd8kSMH71hJRUlUd583Sp2tHSGXSRjzDi1ODsy1g7rjmWMMWOtb8LkQLesbFesVMZa7BhzmJ7c2UZPOmOJk40pAEdlYAegoaaMH15xMs0dPZz/9bu59dHtYRfJGDMOLaivJBYRS6BsjDEh6DMq1oAtdo7aw2Fjjsiabc0AHG8tdowJ3VH9T3bc9Bp+/64zmDu5gnff+Aifv31d2EUyxowzJbEoC+orLbBjjDEhiPsAjgh9coBkkyp3Wo4dYw7bmq0tTCiPM2NCWdhFMeaod1QHdgDm1VXyy7efzqtPncU1/9zMH9bsCLtIxphxZklDtXXFMqZIiMiPRKRJRB4LzPuKiKwXkTUi8hsRqfXz54hIp4is9rfvh1Zw069sV6x43shX2fmqWI4dYw7To43NLJtRi4j9howJ21Ef2AHXTPeTFx/Lilm1fOhXa9i8uz3sIhljxpHFU6vY2drF/gM9YRfFGDO064Dz8+bdARynqsuBp4CrA8s2qeoJ/vb2MSqjGaZsN6toXqucvt2y7HDYmEPV2ZNmQ1M7x1t+HWMKgv2TeYlYhG+/+kRiUeGN1z1owR1jzIhZ0lANwDobhc+Ygqeq/wT25c37i6qm/MP7gBljXjBzWLJdsfJb5fRJpGxdsYw5ZE/saCGdUZZNt8COMYXAAjsB02tdQuW2rhQv/e493LNpT9hFMsaMA7nAjnXHMmY8eBPwx8DjuSLyiIj8Q0TOHGgjEblSRFaJyKrdu3ePfikN0NsaJz+PTrSfRMrGmOF7dGsLAMfPrA23IMYYwAI7Bzlp9gR++87nUl9Vwut++AA/+tfTqGrYxTLGFLG6qhImVyZYbwmUjSlqIvL/gBTwcz9rBzBLVVcA7wNuEJHq/rZV1WtUdaWqrqyrqxubApvcsOb53a3i0Ui/08aY4Vm7rYUp1SVMqS4NuyjGGCyw069Zk8r51TufwzmL6/n0bU/w3l+spjuVDrtYxpgitqSh2rpiGVPEROQNwIuB16i/4qOq3aq6108/BGwCFoVWSHOQXPLk6MAtdix5sjGH7tHGZpZNrw27GMYYzwI7A6gujfOD157EB85bxO9Wb+etP32Izh4L7hhjDs+Shmqe2tVOKp0JuyjGmEMkIucDHwReoqodgfl1IhL10/OAhcDmcEpp+pPNpZPf3SoY6LHhzk3YROR8EXlSRDaKyIf7WV4iIr/wy+8XkTl+figj87V1Jdm8+4AlTjamgFhgZxCRiHDVOQv58suWc/eG3bzxugdo7UqGXSxjzABEJOpzXdwWdlnyLWmooieVYfOeA2EXxRgzCBG5EbgXOEZEGkXkzcC3gSrgjryTp7OANSKyGvgl8HZV3dfffk044rkWO30Pefu02LFRsUyIfHD4O8AFwFLgchFZmrfam4H9qroA+F/gS4FlYz4y39ptLr/OMgvsGFMwYmEXoBi88uSZlMQjvO/mR7nk2//me689kcVT++1Cb4wJ13uAdUDB/UCzdca6Ha0smlIVcmmMMQNR1cv7mf3DAdb9FfCr0S2RORIDDXceDPRErSuWCdcpwEZV3QwgIjcBlwBPBNa5BPikn/4l8G0RCe2Lu6bRBXaWz6gNqwjGmDx2iWKYLjlhOje+9TTau1O89Dv38LHfPsYdT+yiK2nds4wpBCIyA7gIuDbssvRnfl0l8ajYyFjGGDOGcsOdDzIqlg13bkI2HdgaeNzo5/W7jqqmgBZgkl825iPzrW1sYebEMiZWJI5oP8aYkWOBnUNwytyJ/OFdZ/D8xXX86uFG3vrTVVz6nX+zo6Uz7KIZY+DruBwYBZnEJhGLsKC+inU2MpYxxoyZbIud/ATJsT7DndvhsClaoYzM92hjM8stcbIxBcX+yQ5RfXUp333NSaz
++Hl8/7Un0ri/k5d+5x7u3rCbfQd6bGh0Y0IgIi8GmvyoNIOtN2JXqw7HkqkW2DHGmLEUG2C4cxHJtdrJHzHLmDG2DZgZeDzDz+t3HRGJATXA3jBG5tvb3k3j/k6WW34dYwqKBXYOUyIW4fzjGrjl7acD8LofPsCJn7mDM750J9+7axPNHT0hl9CYo8pzgZeIyBbgJuAcEbk+f6WRvFp1OJY0VNPU1s3e9u4xf25jjDkaxX1Ap7/uVtlWO/n5d4wZYw8CC0VkrogkgMuAW/PWuRW4wk+/HPi7qmoYI/NlEydbfh1jCosFdo7QkoZq/vzes/jhFSv56EVLmDO5nC/9aT1nfOlO7t+8N+ziGXNUUNWrVXWGqs7BHRD9XVVfG3KxDrKkwbWOXr/T8uwYY8xY6G2xM3BgJ3/ELGPGks+ZcxXwZ9wAEDer6uMi8mkReYlf7YfAJBHZiOtylR0SfcxH5lvT2IIIHDe94MapMOaoZqNijYCa8jjnLpkCwFvOnMe6Ha2868ZHeMOPH+THbzyZkliEuzfs4eQ5Ezl9/qQh9maMGa+WNLjRsNbtaOW5CyaHXBpjjBn/Bsqx07ssbS12TOhU9Xbg9rx5Hw9MdwGv6Ge7MR+Zb01jC/MmV1BVGh/LpzXGDMECO6NgSUM1N7z1VF79f/dz2TX39Vn2/GPq+MiFS1howx0bMypU9S7grpCL0a9JlSXUVZXYyFjGGDNGekfFOrhVTrbFTn9BH2NM/9Y0NtvFKWMKkAV2Rkl9VSk3vPVUvvW3jRw/s5azFk7mN49s4zt3buTCb97NO89ewOtPn82+Az1UlsZoqCkLu8jGmDGwpKHaEigbY8wYybXY6a8rVrT/odCNMf3b2dJFU1u3JU42pgBZYGcU1VeV8plLj8s9ftvz5vPyk2bwmdue4Bt/28A3/rYBgEQ0wldesZxLTpgeVlGNMWNkSUMVP960l2Q6Y3kdjDFmlGVHvOq3K1Yk0ufeGDO4NY3NgCVONqYQWWBnjE2qLOHrl63gFStnsm5HK3VVJdxw/7O856bVPLWrjYX1rovWc+ZPor66NOTSGmNG2tKGanrSGTbtbmfxVEs8aIwxoyk6WFcsa7FjzCFZ09hCNCIsbbDjF2MKjQV2QvLcBZNz/VPPP24qH7hlDd+5c1NueUTcOhcua+DcxfX0pDPcv3kfJfEIL1w6hZJYNKyiG2OOQDaYs35HmwV2jDFmlGUDOv0lSM4Ffaz1pDHD8mhjM4umVFGWsPMQYwqNBXYKQEksyjcvO4H3vXARqkpnMs2fHtvJb1dv4+pfrz1o/cmVCV532hzefva8gwI8bV1JKhIxInb1yZiCNK+ugkQ0wrodrVy6wrpfGmPMaBqsK1bcB33iljzZmCGpKmu3tXD+sVPDLooxph8W2CkQIsLcyRW5x8dOq+F9L1zEk7vauHP9bkrjEU6bN4ndbd389N4t/O9fn+KPj+3gC/+xjOqyOM/sPcAN92/lb+t3MXdSBW85cx7/ceJ0SuMWUTemkMSjERZOqeTv65t485lzqa+yLpfGGDNaBkuenG2xY8OdGzO0rfs6ae5IsswSJxtTkCywU8BEhMVTq/t011jSAGctquPv63fxwV+u4aXfvSe3bFJFgjc9dy4PPL2Pj/xmLdf8cxNf+I/lnD5/Um4dVaW5I0lNWdxa9RgTkneds4D3/mI1F37jbr72yhM4a1Fd2EUyxphxKT5Id6ve1jzWFcuYoazZ1gzA8ZY42ZiCZIGdInXO4in8+b1n8ZcndlEWjzKxIsGp8yZSEouiqvxzwx4+9tvHuPz/7uP0eZOory6hoyfNQ8/sZ9+BHhKxCHMmlfPRi5baSaUxY+z84xq4ta6Sq254mNf/6AHe/rz5vP+8RTZKljHGjLDhtNix5MnGDG1NYwuJaIRFU6rCLooxph8W2ClikypLuPyUWQfNFxGet6iOP7/3LL719w38e+MeHnm2k2hEeP4x9RwztZK97T38bX0Tb7zuQT5
zyXG8+tRZqCrrd7bx9/VNiMDLT5xhI3MZM0oWTanid/95Bp++7Qm+/49N3P/0Xr552QpmTiwPu2jGGDNuZHPr9NfdarCgjzGmr0e3NrNkWjWJmF2EMqYQWWBnHCtLRPng+YsHXP6ucxdy1Q0P85HfrOWTtz4OAj2pTG751/7yFGcunMzihmoWT63ihUunUJ6wr4wxI6UsEeUL/7GMMxZM5sO/WsOF37ybL71sORcuawi7aMYYMy70JkjuZ7jzyMCJlY0xvTIZ5bFtLbzspBlhF8UYM4AROUsXkZnAT4EpgALXqOo3RmLfZvRUlsS49vUrueGBZ9ne3EVGlXmTKzhnST0HutP8/L5nuOup3dy9YQ+pjFJZEuPCZVMRhD3t3SyaWsU5i+uZNbGcZDrDhPIEFSXuK7XvQA+Pb2/h5DkTLYGzMUO4aHkDy2fUcNWNj/DOnz/Ma06dxcdevNR+O8YYc4SG12LHWiAYM5jNe9o50JNm2XRLnGxMoRqp5hcp4P2q+rCIVAEPicgdqvrECO3fjJJYNMLrT59z8IIq+OiLl/JRIJnOsHprMzfe/yx/WLODipIYE8oT/OOp3Xzvrk25TRLRCGctmkxteYLfP7qd7lSG2vI4Fy1rYE97Nxt2tXP2MfX81wsXUlUaH7BMXck0G5vaWdpQbQmezVFj5sRyfvn20/mfvzzJD/6xmVVb9vPtV69gofVlN2ZMiciPgBcDTap6nJ83EfgFMAfYArxSVfeLiADfAC4EOoA3qOrDYZTb9C+bIDnez/FE3HLsGDMsj25tAeD4mbXhFsQYM6ARCeyo6g5gh59uE5F1wHTAAjvjQDwa4eQ5Ezl5zsQ+81u7ktyzcQ/7DiSJRYQnd7Xxx7U72NfRw8tOmsGZCyZz25od3LKqkWm1pcycWM6P73maP6zdzgkza3l6zwEAjptWw4wJZbR0JnlmXwf3b95HZzLN606bzacvORZ33GzM+BePRrj6giU8Z/5k3veL1Vz87X/xqZccyytXzrTfgTFj5zrg27iWyFkfBv6mql8UkQ/7xx8CLgAW+tupwPf8vSkQ2dY40X5a5UQHGTHLGNNr7bYWyhNR5tdVhl0UY8wARjxhiojMAVYA9/ez7ErgSoBZsw5O+muKS3VpnPOP65sL5KMXLSGd0dxB0gXLGlDV3EnpI8/u53N/WMem3QeYO7mCdEb518Y9NLV1U10ao766lFesnEEyneFn9z1DeSLKecdO4e4Ne5hUkeCCZQ1MriwB3NDt927eywNP7+OC4xo4Zqq1bDDjw/MW1fHH95zJf928mg/9ai3/2riXz7/0uEFbuhljRoaq/tMfywRdApztp38C3IUL7FwC/FRVFbhPRGpFpMFf8DIFIBYdOI9O3JInGzMsjzY2c9y0mn67NBpjCsOIBnZEpBL4FfBeVW3NX66q1wDXAKxcuVJH8rlNYRCRgw6egi0NVsyawC/f8ZyDtstktE+3K1UlFonwg39u5gf/3IwIqMInbn2cY30Ln23NnaxpdE1Dv/7XDZw2byKLp1YzqSLBgvpKjpteQ3VZnAPdKWrL45b42RSV+upSfvqmU/n+PzbxtTue4tGtzXzr8hXWDNqYcEwJBGt24nIKgmudvDWwXqOfd1Bgxy5uhWOw4E3UkicbM6RkOsMT21t53Wmzwy6KMWYQI3amKyJxXFDn56r665Harzk65OfSERE+9ZJjWTSlkgkVCc5cUMfO1i5ufXQbaxpbeGpXG/FohM+/dBnnLK7n14808puHt/Grhxtp60odtP+q0hhXnjmPN54xl8oSC/CY4hCNCP/5/AWcOnci77lpNS/73j186PzFvPmMuZZ/ypiQqKqKyCFfnLKLW+HIBnT6T5488DJjjPPUrja6UxmW24UlYwraSI2KJcAPgXWq+rWR2KcxkYjwukBi55ryOP89tf/h29959gLeefYCwCVffnJnG2u3tdCVTFNREuNv65r46h1P8aN/P83rTp/DFafPZpLv0jU
cqXSGZ/Z1WN9iE4qVcybyh3efwYd+tYbP3b6Of2/aw1dfcfwhfYeNMUdkV7aLlYg0AE1+/jZgZmC9GX6eKRDVZXFWzp7Acf2M5pMN+vQ3FLoxxsm2jl9uI2IZU9BGqunCc4HXAWtFZLWf9xFVvX2E9m/MsJXGoxw/s7ZPl5XLT5nFI8/u5zt3buKbf9vAd+/cyLHTqlkxawIzJpQxubIEEUillbl1FZwwozbXImLrvg7ec9MjPPxsMx+5cDFXnjU/pFdmjma15Qm+/9qTuP6+Z/jMH9ZxwTfu5uuXncBz5k8Ou2jGHA1uBa4AvujvfxeYf5WI3IRLmtxi+XUKSzwa6bcLOPQmTbYWO8YMbE1jC9WlMWZPKg+7KMaYQYzUqFj/Auxf0RS0FbMmcO0VK9nY1MavH97GQ8/s5xcPbqUzmT5o3fqqEpZNr0EE7t+8D4DT503i87evJ52BS1dMoyuZ4bFtLTz87H6Wz6jh0hOm28hFIRGRUuCfQAmuXvulqn4i3FKNPBHXiu2k2RO56saHec219/Ou5y/g3ecutFFdjBkhInIjLlHyZBFpBD6BC+jcLCJvBp4BXulXvx031PlG3HDnbxzzApvDlmux08+IWcYYZ01jM8tn1NoxrjEFzpKNmKPOgvoqPni+69KlqrR2pdjb3k1G3VW7R7c285cndrJlTwcAp86bxCcuXkpDTSnv+cVqvvSn9XzpT+tz+4tFhFRG+fNju/jYxUupryqxZt1jrxs4R1Xbfb6vf4nIH1X1vrALNhqWTqvmtnedwcd/9zjf/PtG7t28l29ctoJptWVhF82Yoqeqlw+w6Nx+1lXgP0e3RGa0ZIdCt+TJxvQvm97gyrPmhV0UY8wQLLBjjmoiQk1ZnJqy3mGk506u4NIV0/td/xuvOoHzlk6hoydNPBph0ZRKFk+t5rp7nuYrf36SPz2+E4Dq0hhzJ1ewpKGatz1vPnMnV4zJ6zla+ZOrdv8w7m/jOjlpeSLG/7zieM5YMJn/95u1XPCNu/nKy5dz3rFTwy6aMcYUhdxQ6NYVy5h+rdvRSiqjLJ9RG3ZRjDFDsMCOMYcgFo1wyQkHB32uPGs+zz+mnns376WlI0lTWzeb97Rz66Pb+dXDjbz6lFlUlcbZ0dJFXVUJi6dWEYkITa1dlMajnDZvIvPrKq2Z6xEQkSjwELAA+I6q3t/POuNuuOFLV0zn+Jm1vOvGh7nyZw8xe1I5C+srWVBfxcL6ShZNqWJ+fQXlCavujTEmKJYb7txa2RrTn1zi5BmWONmYQmdH+saMkIVTqlg4parPvKa2Lv7nz0/y0/ueISLC5MoE+w70kEwf3JikPBGlpizOpMoEZy2s40XHTmX5jJpcsOepXW0ALKirtKGu+6GqaeAEEakFfiMix6nqY3nrjMvhhudOruBX73gOP7v3GR7Z2szGXe3846ndfb5nMyaUsbC+koVTqlhQX5mbriyxvwFjzNEpF9ix/1Rj+rWmsYXJlSU01JSGXRRjzBDsiN6YUVRfVcqXX348H7/4WEpjEWLRCMl0hqf3HABgSlUp+zt6uP/pvazf2UZbV4qt+zr4wT838927NjG9toyzj6nj0cZmHtvWCkBVaYzT5k3i4uOn8YIl9YfUEuMfT+3m0a3NXLisgQX143PodlVtFpE7gfOBx4Zaf7woiUV5y5m9feCT6QzP7utgw642NuxqZ0OTu/170156UpncetNqSlkwxbXuccEe19on2D3xaKKqdKcydPak6Uql6exJ05lM05XMoKpkFDKqZFTR3DT+sZLJ9M4bzvo6SuHFiAiRiBCNuOloRIhm54kQjfr7iPQu9+vGIhEiEQ7eJuKmI+LWE3/vbq5ra3SI5cYUkmxLHcuxY0z/1jQ2c3zgIqMxpnBZYMeYMRBsFeFy8/S27KkpjzMnLwfP/gM9/HXdLv702E5uWdXI/PpKPnnxUipL4zz0zH7+vn4Xdzyxi5JYhOfMn8Rx02t4bFsLj29vZebEcpZ
Nr2HRlCrmTq6gLBHlQHeK6+97hj8+5nIAfe2Opzht3kT++0WLOWn2hLF5E0aRiNQBSR/UKQNeCHwp5GKFKh6NML+ukvl1lZx/XO/8dEZ7Az5N7WxsamdDUxs/v38vXcnegM+U6hIW1rvWPXMnV1BZEqMsEaUsEaU8HqU84R6XJ6KUxd38klhkVA/+0hmlM5mmoydFV08mN92Z7A3AdPSk6fKPu5IZH5TpXR58nA3cZNfLzhutYIshF/QJBoiOnVbNLW/vfzhqY0ZTdphzG+7chE1Ezge+AUSBa1X1i3nLS4CfAicBe4FXqeoWv+xq4M1AGni3qv55JMrU3p1i4+52LlreMBK7M8aMMgvsGFOAJlQkeMXKmbxi5UzSGe1ztfvlJ80gnTmOB57ex58f38mdTzZx55O7WVBfyXMXTKZxf0e/w7iXxCL894uO4aUrpvO71dv58b+f5mXfu4dLTpjGh85fXOwjKjUAP/F5diLAzap6W8hlKkjRiDB3cgVzJ1dw3rG989MZZdv+TjY0uYDPhl3tbGxq4+ZVW+noSQ+8w4CI0G/Ax03HKPfTpXF3LwKdPRk6kyk6e1xQJhikCQZrOpPpPi2NhisRi1AWj1Iaz967MpXFXdfHUj+dXV4Wj+bmZctfGndBq2zrlv5aokQC83LLI+Stc3BrltGKg2UU0mklrUo641oMpTO906mMkvGP076lkbt381KBbYLb9ttKyU8PtTzbisntM9uiSZlSbU38TTjiURvu3ITPH7t8B3dRqhF4UERuVdUnAqu9GdivqgtE5DLcxatXichS4DLgWGAa8FcRWeS7px+Rx7e1oArHW+JkY4qCBXaMKXD9XUmMRoTT50/i9PmT+CTH0tmTpiwRzS3PZJQdrV1s3t1OTypDWSLKgrpK6v0J1DvOns/rT5/N9/+xiWv+uZlLV0wv6sCOqq4BVoRdjmIWjQizJpUza1I55y6ZkpufySh7D/TkWsZ09KRzQZiOnlTf4EsuOJPKPc4u23egk86eVJ/gjQLlPpBSnguwRHP5prIBmPKEXyceoywRoSwRy80PblMWCCZl59uVeGPMQI6ZWs0xU6qosFxjJlynABtVdTOAiNwEXAIEAzuXAJ/0078Evi3uit8lwE2q2g08LSIb/f7uPdJCPb3nACKwzBInG1MU7J/MmHEgGNQBiESE6bVlTB8kWFNREuP95x3D60+fQ11VyWgX0RSpSET892PkvyOqav32jTGhed6iOp63qC7sYhgzHdgaeNwInDrQOqqaEpEWYJKff1/etgcP38qhjwx62SmzePHx02yQBWOKhLU9NeYoZ0EdExYL6hhjjDFjQ1WvUdWVqrqyrm54AU0L6hhTPCywY4wxxhhjjDHh2AbMDDye4ef1u46IxIAaXBLl4WxrjDkKWGDHGGOMMcYYY8LxILBQROaKSAKXDPnWvHVuBa7w0y8H/q6q6udfJiIlIjIXWAg8MEblNsYUEGtfZ4wxxhhjjDEh8DlzrgL+jBvu/Eeq+riIfBpYpaq3Aj8EfuaTI+/DBX/w692MS7ScAv5zJEbEMsYUHwvsGGOMMcYYY0xIVPV24Pa8eR8PTHcBrxhg288BnxvVAhpjCp64VnwhPLHIbuCZYa4+GdgzisUZScVUViiu8lpZR89wyztbVcfFECKHUAeN18+yEFhZR08xlfdQyjou6qBB6p9i+tyyiq3MVt7RVWzlBTsGGkyxfZ7FVF4r6+gppvKO2DFQaIGdQyEiq1R1ZdjlGI5iKisUV3mtrKOn2Mo7lortvSmm8lpZR08xlbeYyjraivG9KLYyW3lHV7GVF4qzzGOl2N6bYiqvlXX0FFN5R7KsljzZGGOMMcYYY4wxpkhZYMcYY4wxxhhjjDGmSBVLYOeasAtwCIqprFBc5bWyjp5iK+9YKrb3ppjKa2UdPcVU3mIq62grxvei2Mps5R1dxVZeKM4yj5Vie2+KqbxW1tFTTOUdsbIWRY4dY4wxxhhjjDHGGHO
wYmmxY4wxxhhjjDHGGGPyWGDHGGOMMcYYY4wxpkgVdGBHRM4XkSdFZKOIfDjEcvxIRJpE5LHAvIkicoeIbPD3E/x8EZFv+jKvEZETA9tc4dffICJXjFJZZ4rInSLyhIg8LiLvKdTyikipiDwgIo/6sn7Kz58rIvf7Mv1CRBJ+fol/vNEvnxPY19V+/pMi8qKRLmvgeaIi8oiI3FYEZd0iImtFZLWIrPLzCu57UMgKoQ6y+mdUy2t10CiV1eqfQ1cI9Y0vR9HUOf55rN4Z5XrHP1dR1D2B57I66AgVQp1UTPWR1UV2DBR4jnDqH1UtyBsQBTYB84AE8CiwNKSynAWcCDwWmPdl4MN++sPAl/z0hcAfAQFOA+738ycCm/39BD89YRTK2gCc6KergKeApYVYXv+clX46Dtzvy3AzcJmf/33gHX76ncD3/fRlwC/89FL//SgB5vrvTXSUvgvvA24AbvOPC7msW4DJefMK7ntQqLdCqYOs/hnV8lodNEpltfrnkN+vgqhvfFmKps7xz2X1zijXO/75iqLuCZTX6qAje/8Kok4qpvrI6iI7BgqUM5T6Z0x/nIf4hpwO/Dnw+Grg6hDLMyevUnkSaPDTDcCTfvoHwOX56wGXAz8IzO+z3iiW+3fACwu9vEA58DBwKrAHiOV/D4A/A6f76ZhfT/K/G8H1RriMM4C/AecAt/nnLsiy+n33V6kU9PegkG6FVAdZ/TP65bU6aMTLavXPob1fBVPf+OcvyjrHP5fVOyNfzqKpewL7tzroyN6/gqmTirU+srpoxMtYNPVQWPVPIXfFmg5sDTxu9PMKxRRV3eGndwJT/PRA5R7z1+Obna3ARWALsry+Sd1qoAm4Axc5bVbVVD/PmyuTX94CTBqrsgJfBz4IZPzjSQVcVgAF/iIiD4nIlX5eQX4PClQhv/aC/xyLof7x5bQ6aHTKavXPoSn011oUn53VO6P2Hn+d4ql7sqwOOjKF/NoL/nO0usiOgQih/okdaakNqKqKiIZdjiARqQR+BbxXVVtFJLeskMqrqmngBBGpBX4DLA63RP0TkRcDTar6kIicHXJxhusMVd0mIvXAHSKyPriwkL4H5vAV4udYLPUPWB00iqz+GacK9bOzemd0FGHdk2V10FGgED9Hq4tGXhHWQ6HUP4XcYmcbMDPweIafVyh2iUgDgL9v8vMHKveYvR4RieMqlJ+r6q8LvbwAqtoM3IlrRlcrItmgY/B5c2Xyy2uAvWNU1ucCLxGRLcBNuGaA3yjQsgKgqtv8fROusj6FAv8eFJhCfu0F+zkWY/0DVgeNcFmt/jl0hf5aC/qzs3pnVMtbVHVPltVBR6yQX3vBfo5WF41aeYuqHgqt/hnpPmUjdcO1JtqMS2yUTdp1bIjlmUPf/p1foW8CpC/76YvomwDpAT9/IvA0LvnRBD89cRTKKcBPga/nzS+48gJ1QK2fLgPuBl4M3ELfRFjv9NP/Sd9EWDf76WPpmwhrM6ObkO9sepN2FWRZgQqgKjB9D3B+IX4PCvVWSHWQ1T+jVl6rg0ahrFb/HNZ7VjD1jS9PUdQ5/rms3hmDesc/Z0HXPYFyWh105O9hwdRJxVIfWV1kx0D+OUKrf8b8x3mIb8yFuIzim4D/F2I5bgR2AElc/7Y34/rp/Q3YAPw1+0b7D+U7vsxrgZWB/bwJ2Ohvbxylsp6B69e3BljtbxcWYnmB5cAjvqyPAR/38+cBD/jnvQUo8fNL/eONfvm8wL7+n38NTwIXjPL3IVihFGRZfbke9bfHs7+fQvweFPKtEOogq39GtbxWB41CWa3+Oez3LfT6xpejaOoc/zxW74xBveOfr6DrnsDzWB00Mu9j6HVSMdVHVhfZMVCgTKHUP+I3MsYYY4wxxhhjjDFFppBz7BhjjDHGGGOMMcaYQVhgxxhjjDHGGGOMMaZIWWDHGGOMMcYYY4wxpkhZYMcYY4wxxhhjjDGmSFlgxxhjjDHGGGOMMaZIWWDHHBERea+
IlIddDmPM0cnqIGNMWKz+McaEyeogE2TDnZsjIiJbgJWquifsshhjjj5WBxljwmL1jzEmTFYHmSBrsWOGTUQqROQPIvKoiDwmIp8ApgF3isidfp3zROReEXlYRG4RkUo/f4uIfFlE1orIAyKyIMzXYowpPlYHGWPCYvWPMSZMVgeZoVhgxxyK84Htqnq8qh4HfB3YDjxfVZ8vIpOBjwIvUNUTgVXA+wLbt6jqMuDbfltjjDkUVgcZY8Ji9Y8xJkxWB5lBWWDHHIq1wAtF5EsicqaqtuQtPw1YCvxbRFYDVwCzA8tvDNyfPtqFNcaMO1YHGWPCYvWPMSZMVgeZQcXCLoApHqr6lIicCFwIfFZE/pa3igB3qOrlA+1igGljjBmS1UHGmLBY/WOMCZPVQWYo1mLHDJuITAM6VPV64CvAiUAbUOVXuQ94brbfpu8Luiiwi1cF7u8dm1IbY8YLq4OMMWGx+scYEyarg8xQrMWOORTLgK+ISAZIAu/ANeX7k4hs9/073wDcKCIlfpuPAk/56QkisgboBgaKJhtjzECsDjLGhMXqH2NMmKwOMoOy4c7NmBAbjs8YEyKrg4wxYbH6xxgTJquDjg7WFcsYY4wxxhhjjDGmSFmLHWOMMcYYY4wxxpgiZS12jDHGGGOMMcYYY4qUBXaMMcYYY4wxxhhjipQFdowxxhhjjDHGGGOKlAV2jDHGGGOMMcYYY4qUBXaMMcYYY4wxxhhjipQFdowxxhhjjDHGGGOKlAV2jDHGGGOMMcYYY4qUBXaMMcYYY4wxxhhjipQFdowxxhhjjDHGGGOKlAV2jDHGGGOMMcYYY4qUBXZCJCJ3ichbQnje74vIxwKP3yEiu0SkXUQmjXV5jDGHJqy6Yygi8lIR2errkhUjsL/rROSzQ6xzpog8GXj8uIic7ac/IiLXHmk5jDHjX35dYoZPROaIiIpIzD8uyP8oY452/vhsXtjlGGlWBzkW2DkCIvKkiCwS50sistffviQiMsrPvUVEXtDP/MtF5IbBtlXVt6vqZ/z6ceBrwHmqWqmqew+jLMeJyJ9FZI+IaD/LJ4rIb0TkgIg8IyKvzlv+aj//gIj8VkQmHmoZjCkmxVp3DMP/AFf5uuSRI9zXsKjq3ap6TODxsap6l5/+vKq+BQ7+0x8tIvIGEfnXaD6HMaMtW0cdxnafFJHrh7luwh83VB56CUdefl0ykkTkChF5SERaRaRRRL4crIsGqpcH2Z+KyILRKOtoEZEJIvJZEXlMRPaJyGYRuSb/JFNEGkTkVhHZ7l/nnJCKbEJ0NNZBInK2iDQewvoHHW+IyH/731ibiDwtIv8dXO6PzzaPVJmLydFQB1lg5zCJyHwgqqpPAVcClwLHA8uBi4G3hVS0i4DbD2H9KUAp8PjhPJk/MEkCNwNvHmC17wA9/rleA3xPRI712x8L/AB4nV/eAXz3cMoyVvzJuP12zGEZR3VHf2Zz+HVJ9Aif2xyi0Q5ymeKUV0flLzui32ne/+dZwGpVbT+SfRaJcuC9wGTgVOBc4ANjXYiwfvMishh4AIgBLwPqgJOAe4G/iMh5gdUzwJ/8euYoZHXQERHg9cAE4HzgKhG5LNwi9bI6aJSpqt0GuAGvAtoDt27gLr/s3cA3/fQ9wJWB7d4M3OenS4Hrgb1AM/AgMMUvuwv4DPBvoA34CzA5sJ+X4E6Smv26S/z8n+G+dJ2+XB/08yPAruw+gDN82ZqBrcAb/PzrgM8Ci4ADgPr9/N0v/4ZfvxV4CDgzUKZPAr/0r6kVeEtg2QL3lerzHlbggjqLAvN+BnzRT38euCGwbL5fv2qIz+bDwCb/vj0BvDRv+VuBdYHlJ/r5M4FfA7v9Z/LtwOu6PrD9HP++xAKf1ef8Z9XpX+sbA8+xGXhbXhkuAVb792kTroJ9BfBQ3nrvA34X9vfdbiN3o/jrjoG2/xBwf+B38Q6/Xo3fn+LqlE1++RK/fbNf7yWBMl4HfA8XTDoAvABYATzsX9M
vgJuAzw7xXp8NNAYebwFe4Kc/if9dA8/SW9e1A6cPss83+Pf220ALsB44N7C8BvghsAPYhqtPo/71dgFp/xzNQ5T9IuARXB2xFfhk3vKB6vAy4KvAM758//Lz+rwXA7wffepv4BTcgU2zfz3fBhKB7Y8F7gD2+e/IR4CpuCD8pMB6J+Lq1XjYvz+7DX1j+HVUf7/TacCv/Of9NPBuv+75uP/vpN/no37+XeT9f/r5XwPe56cnAj8GtgP7gd8GyvpWYKP/Dt4KTBvs++nnlwBf9/vb7qdLhnhP+vx+/G/nv4E1/rX/EHcB6o+4OuqvwITA+q/3v8m9wMeCv71+nut9wO/99EH1sv98ngaq/ToXADtxJyP/pLeubQdeNdRrwtXdO/1zReg9htqLuzA3MbDNQPXOgPUV/R8zvcVPJ3D1/wsHKONs4CmgNm9+zO9zTti/F7uN/A2rg/LfjwpftkzgPZnmX/dXA+vdBPyIYR5vAN8EvhV4rIHXfx3uYvof/T7+jft//7p/D9YDKwLbnujrgDbgFtxx2rCO0bA6aGx+V2EXoFhuQDXuJP5t/vGfgBf56Rbg1MC6K4E2P/024Pe4qzVRXHQw+0d9l/9SL8IdlN9Fb8AjG3R5IRDH/dFvxB9w088BA3AacK+fnu1/eJf77ScBJ/hl12V/iPk/BD/vtX79GPB+/0Ms9cs+iaswL/U/zLLAdv0FdlYAHXnzPkDvAc3vgA/lLW8HThri83gFrsKL4P4cDgANgWXbgJNxkesF/v2IAo8C/4urQEuBMwKva6jAzrO4Cjzm39OLcIEoAZ6HO9HJBpBO8d+LF/oyTgcW4yr6ffgTZb/uI8DLwv6O2210bhRf3THg9v67/E//e1mI++MP/ukHDxjifruP+G3PwdVJx/jl1/nX/1y/32rcSdF/+W1fjqtrRiqw0+c3PcQ+3wCkAmV5lS/rRL/8N7iWhhVAPe4q0NsC2/5rmN+Ns4Fl/vUvxx0UXuqXDVaHf8d/5tP9d+M5uLqlz3sxwPvRp/7Gfa9Ow9Vrc3Df1ff69atwwZ734+rLKvz3FXew+Y7A8/wvgYNHuxXPjcHrqPzfaTnugs/H/e96Hu7CRnb93G8usP+7yPv/9PPX01sf/AF3kjDBf9+f5+efA+zBnVCUAN8C/jmM7+engfv877MOd5LwmSHehz6/H//buQ8XzJkONOECzyv88/0d+IRfdynu2OUM/778j/+tDRTY+S2+zg48V369/HP//k/CnRi+OLAsV9cO4zWlgC/5968MeI9/XTP8vB8AN/r1B6t3zmbg+moOA59UXQH8wE8vw12k2A18CrjHz/9/uG68wbIX3UmV3Q7vhtVB2XKezcH/4VNxdc85uF4Pm/EXvxnieAN3fvII8PbAvOBx2nX+tZ1Eb532NC5IHcVdtLrTr5vAHaO9x78//4ELog3nGM3qoLH6LYVdgGK4+S/QbcD3/ONyXISxxD9OA4sD6y/0XwQB3uR/zMv72e9dwEcDj98J/MlPfwy4Oa8M24Cz/eMtHHwQ8BngY376auA3A7ye6xgksNPP+vuB4/30J/EVWj/r9RfYORPYmTfvrfRG5f9GoMLx83Kv8xA+o9XAJX76z8B7+lnndP9DPui1MrzAzqeHKMNvs8+Lq6T+d4D1vgd8zk8f69/fQa8k2q04b0Vadwy1/RxccHIdcHXefoIHDGfigsKRwPIb8VdYcPXQTwPLzsKdvEhg3j2EF9jJL8sD9HYZ7aZvUPtyeg9+3sAwAzv9PO/Xs/UGA9Th/vPoxNfJg70XA7wf/dbfgfXfm31e/7oeGWC9VwH/9tNR/1mfMla/LbuNzI2h66j83+mpwLN5+7ga+LGfzv3mAsvvIu//E3dRZKOfbsBdpZ7QT/l+CHw58LgSFzCZM8T3cxNwYeDxi4AtQ7wX/dUlrwk8/lX2ffKP34W/qo87ybwxsKwcd9JzUGAHV7c30reV5Zb8dYFa3MnoWvxJSWBZrq4dxmvqwV+
c8/PW0bcFYoN/T2MMcuzYz76/Tm99NYeBT6quB57vp+/HdUGO+fstfv5F+BbUgf0X3UmV3Q79htVBwW3OJu8/3M9/Ga6Fyh78BWk//w0MHtj5FO6CdklgXvA47Trg/wLL3gWsCzxehm8JhDtG20bf46J/MbxjNKuDxuhmeUKG53O4KOy7/eNzcRG+bv+4HRdtzqoG2tV9K36GCzTc5JMwfVlcwuKsnYHpDlyFAa41yjPZBaqawf2opw9SzgvpzZExE1epHDIR+YCIrBORFhFpxnU7mBxYZesh7C7/vcE/bhvm8oHK+HoRWS0izb6MxwXKONBrnwk8o6qp4Re/jz6vW0QuEJH7fAKuZtz7P1QZAH4CvNonyX0d7iS6e4B1TXErxrpj0O1VdQtwJ+5P9DuD7HMasNVvn/VMXjm25q2/zb/24Pph6a8s03BXk+LAjkD98wPcVblDIiKnisidIrJbRFqAtzN0HTIZd2XtsOp3Dq7HFonIbSKyU0Racd1jh1OP/Q5YKiJzca27WlT1gcMskwnPUHUU9P3OzAamZb/7/vv/EVzAczD5xw0X4pr/g/ue7VPV/f1sl18fteNO+qYz+Pezz3b0/n4P1a7AdGc/j4P1bu41qmqHL2cfInIp8AXgAlXdM9gTq2ozrrvDcbiul4drt6p2BR7PBn4T+PzW4S4yTGGQ93SI+mow9bgTQnAnitf747BggtuZgXXM0cXqoKH9HncB5UlVHdbgDCJyFa7lzUVDnGMcSh2Xf1w03PNBq4PGiAV2huATTl0OvFxVk3528CQIXL+94wOPj/fzUNWkqn5KVZfimsu/GPdDG8p23Bc/Ww6h75cu+MNCRKbiIp4P+1lbcdHoQyIiZ+K6XrwSF7muxTWBDI7Uo/1sOpCngJiILAzMy70/5L13PjN5id9uoDLOBv4PuAqX46EWeCxQxoFe+1Zg1gCJuw7grhJkTe1nndzrFpES3NW7/8HlPanFfSeGKgOqeh8uen0m8GrcCbwZZ4q47hh0exG5CNf67W/AV4Yox8y8ROOz6PvHGSzLDmC6f77g+iPlUOotBijLdtxvuxt3tb3W36pV9djDeJ4bcP31Z6pqDfB9hq5D9uD61fe3rE895pNM1uWtk1++7+Gaoy9U1WrcAXKwDP0Oi+oP0m7Gdd19HVaPFZ1h1lHQ9zuzFXg68N2vVdUqVb2wn3UH2kf+82wFJopIbT/b5ddHFbim+dsY5PuZvx29v9/RsgPXrQAAESnDlZPAvPNxxy4Xq+ravO0Pet9E5ARc654bcXkyDlf+vrfiAkvBz7BUVbPv6UDHjoPVV4PZg/uPAdf66LW+bnotgIichGspcKSjMpoiY3XQkGXM+hwu+NEgIpcPtb6IvAmXw+ZcVR32KFtD6O8YbeYwt7U6aIxYYGcQIrIC15fyUlXdHVh0Aa4vZtZPgfeJyHQRmYbra3md38fzRWSZ/wK14pqaBa9gD+Rm4CIROddfpX8/7mTiHr98F30rkwtwXTGyP56fAy8QkVeKSExEJvmDhKFU4fpC7sYFZD7OwS1q+vAZ5ktx/S8RkVIf+EBVD+CSFX9aRCpE5Lm4pMLZk4CfAxeLyJm+svw08GtVHazFTgWuktjtn++NuCtaWdcCHxCRk3zZFvhg0AO4iumLviylvjzgunKdJSKzRKQG1xRwMAlcAGo3kBKRC4BgRvUfAm/0n1/EfzcWB5b/FJekNDnc6LspHkVedwy4vYhMxv2+3oLrs3yxiFxI/+7HtST6oIjEReRs3KhfNw2w/r24uufdfv3/wOWqGim7ce/fQAdh+eoDZXkFLlHh7aq6A5es+qsiUu1/3/NF5Hl+u13ADBFJDOM5qnBXCbtE5BRcoDer3zrct4D6EfA1EZkmIlEROd3XuU8BpSJykf/sPoqrp4YqQyvQ7uuodwSW3YY7kHyviJSISJWInBpY/lNcU/CXYIGdonIIdVS+B4A2EfmQiJT5799xInKyX74LmCODjBwpIuW
43/adAP439Ufgu+KGo42LyFl+9Rtx/6Un+O/454H7fcvBwb6fNwIfFZE6X299nL5XZ0faL3H14XP8b/+TBE44ROQc3G/6ZQO0bOtTL/tjqutxgdY34k6o3jnQ+ofo+8Dn/HER/j26xC8b7NhxsPpqMH/H5UwD99/xVlzrhQW4YPRngNepaq51g3/92bqrxD8244jVQf3aBUzy5yHZsp6FqwNejzvu+paITA+s3+d4Q0Re48v4Qh3ZYc3vxbWqucrXDZdw+MdoVgeNEgvsDO4SXBKtf4lIu7/diesq8WxgvR/gmsmtxbUc+YOfB67lxy9xB87rgH8wjANgVX0SF0n8Fi7SeDHuKk+PX+ULuAqjWUQ+QN5Qxb58F+JOyvbhAhfBlgED+TMuadlTuC99F0M3tZuNa66XbYXTCTwZWP5OXLKsJlxF9w5VzbZKeBzXlO7nfnmVX39AqvoErlnyvbhKbRkuk3t2+S246PYNuC5dv8UlPU3j3scFuH7rjbg8EajqHbikaWtwSdluG6IMbbhmozfjcuS8GhdFzi5/AFcR/y+uxdM/6Bu5/xkuGDWaB5omPMVcdwy2/TW4EdxuV9W9uFG8rhWRPlem/X56/LYX+P18F3i9qq4foNw9uGR8b8DVWa/CBYVHhO8a8Tng3/61nzbEJvfjch7t8du93L9mcAdYCdyIe/txn1P2atDfcXXhThEZtKsFrq77tIi04Q76bg6Ud7A6/AO478yDftmXcLmMWvw+r8VdTTyAq+cG8wFc/dWGa03wi0AZ2nDdrC7Gdf3bADw/sPzfuGDZw8EDIlMUhltH9eH/R18MnIBLsrkH933Lnojc4u/3isjDB+3AOQeXrD3YNP91uOD1etyxwHv98/0Vl/frV7gLM/OBy/yywb6fnwVW4f7T1+JaJH52sDfkSPhjmXfhAtc7cN1sm3BBcfxrqAFuD7zffwzsIr9e/gKuK+v3fDeK1wKfld7Wz58EfuLXf+UhFvcbuOOVv/i65z5c3pKh6p0B66shXA+8UETOVtW1qnqyqs5Q1Q/6lo4vUdX870p2hDBw34nOQ3yNpvBZHXTwa1uPO0/a7H/bs3AXUK5S1W2qejfuwvGPRUTo/3jjs7gWRQ8G3tfvD/a8wxE4RnszbrSq1+LOlQ4nlYTVQaNEtE9XOTMUEfkgrgn+B8MuS5a4rkU7gXmq2hp2eczQxDXTbsKNorUh7PKY0Wd1R/EQkTfgku6dEXZZCp2I/B24QVWvDbss5siMVR0lIt8FHlPV747m84RJRCpxJz8LVfXpkIsTOhFZhsvLdQ3uQt42YC4uGFamqm8LsXimQFgdVFxE5H7g+6r647DLMpSjpQ6yFjuHbgtQaF/gibgRbezErHi8A3jQgjpHlS1Y3WHGEd/0/UQCrXxMUdvC2NRRq4HfjMHzjCkRuVhEysV1K/8f3FX6LeGWqjD4nEKn45Kj/g3X0vFW3BX594VYNFNYtmB1UMESkeeJyFTfReoK3JDjfwq7XMNxtNRB1mLHFCTf/PCJARYvHayZZqETkS24vveXquojIRfHmIInIh/B5ZrId7eqXnCY+/w+PnFenutxzYJHpMWOiDxO326YWW9T1Z8f6f7DIiI/AS4F3qOq14VbGmOGZzTqksC+r8XlcRBcF4x3+q6to2o0X5MxZmQV8+9VRK7E5aOpADYDV6vqH4r5NY03FtgxxhhjjDHGGGOMKVLWFcsYY4wxxhhjjDGmSMXCeuLJkyfrnDlzwnp6Y8xheOihh/aoal3Y5RgJVgcZU3zGSx1k9Y8xxWe81D9gdZAxxWioOii0wM6cOXNYtWpVWE9vjDkMIjJuhjS2OsiY4jNe6iCrf4wpPuOl/gGrg4wpRkPVQdYVyxhjjDHGGGOMMaZIWWDHGDOmROR8EXlSRDaKyIf7WX6WiDwsIikReXlg/vNFZHXg1iUil/pl14nI04FlJ4zdKzLGGGOMMcaY8ITWFcsYc/QRkSjwHeCFQCPwoIjcqqr
Boe2fBd4AfCC4rareCZzg9zMR2Aj8JbDKf6vqL0et8MYYY4wxxhhTgCywY4wZS6cAG1V1M4CI3ARcAuQCO6q6xS/LDLKflwN/VNWO0SuqMcYYY4wxxhQ+C+wYM8K2NXfyxPZWkukMyXSGnlSGZFp7H6czJFN5jwPzco/9Nm579zgWFSoSMSpKolSUxKgsiVGeiFHpH1eUxHLLK/MeV5TEKIlFEJEw357pwNbA40bg1MPYz2XA1/LmfU5EPg78DfiwqnYfXhGNMaOhJ5Xh2X0dbNlzgC17D/C0v59UUcI3L18RdvGMMQVCVdnd3s3GXe08tauNDU3tbGxqpyuZpiQWJRGLUBKL5O6D80riERLRqL93j/ssz20XzT2eVltGRYmdEhljDk86o7R3p9ytK0V7d5K2rt7HbV0p2gLL2rtTueXzJlfy1VcePyLlsFrMmCOkqqzb0cZfntjJHU/s4vHtrcPaLhoR4lEhHnUHH/FohHgs77FfXlESIxYRUr7i2N3WTXt3igM9KQ50p0imdVjPGYsI5YlA0KfEBX3e+4JFnDxn4pG8DWNGRBqAZcCfA7OvBnYCCeAa4EPAp/vZ9krgSoBZs2aNelmNGczvH93OY9tbuHj5NI6dVh120HVEJNMZGvd3smVPb+Ame79tfyeZQFVVUxZnzuQKjplSEl6BjTGhUVV2tXazoamNDbva2dDUzgYfyGnpTObWqy6NsXBKFbXlCbpTaTp6UuzvyNCdche/ulNpf+9u6czwjomyfvyGk3n+4vqRfnnGmCKWzijP7D3AU7va2djUxjN7O3LBmLbuFG1dSR+oSdHRkx7WPitLYlSVugvzlf5+cmVixMpsgR1T8Fo6k6zb0UplSYxjplYRj4af8zuZzvDg0/v4yxO7uOOJXWxr7kQETpw1gQ9fsJhT506kPBHrDdzE+gZq4tEI0cjIncT1pDIc8JFiF+xJc6A71TuvO8WBnuA8d2CUXZY5xIOgI7ANmBl4PMPPOxSvBH6jqrmjPlXd4Se7ReTH5OXnCax3DS7ww8qVK8fsRRuTb+u+Dj5wy6N0pzL84B+bmV9XwaUnTOeSE6Yza1J52MUbVDqjbNvfydN7D/QJ4GzZc4DG/Z2kAvVJVUmMOZMrOGHmBF56wnTmTK5gzuQK5k6qYELFyB3MGAMuUJBMK53JNN3JNJ3+1pXM0NmTpivpbp3JNMl0hvrqUmZOKGN6bTlliWjYxR+3VJUdLV29gZtd7S6Y09ROW1cqt15teZxF9VW8eHkDC+srWTilioX1ldRVlRxS4DvlWz/ngj3JDD1p9z3oSbvHwWDQsdOqR+NlG2OKQDqjbN3XkWsd+NSuNp7a1c6m3e30pHqzQkypLqGmLE5VaZzasjgzJpRRVdI3SFNVGqOqNJ6bVxVYVpGIERnBc7/+WGCnwKkqq7c209yRpLY8Tm15ggnlcapL46P+5QjDnvZuHt/eymPbWnh8ewuPbWvl2X29aVRK4xGWTa9hxawJrJhZy4pZE5haUzomZTvQneIfT+3mjid28ff1TbR0JknEIpy5YDLvOmcB5y6ZQl1VOFeeE7EIiViiGE6UHgQWishcXEDnMuDVh7iPy3EtdHJEpEFVd4g78rsUeGwEymrMqPnU758gGhFuf/eZPLJ1P797ZDtfveMpvnrHU5w4q5ZLV0znomUNTKoMtzVLOqM8sb2Vf2/aw6ot+3l6Tztb93XSk+492ClPRJk9qYKl06q5cFkDcyZXMM8HcCZVJMZFSyQTjqf3HOD3j25ne3OnC9L0pOlKZejqyQZseu+7khk6k+lDbq2RNakiwYwJZcyYUO7ve6enTyijPFG8h8ypdIbmziTNHT3sO5Bkf0cPzR09uVYxEREiIkQjQiQiREWIRkAkO+3mRwSi0rtOJEJuu+z8iAgHulO5VjhPNbWzqamd9u7eAM7kygQL6iu59ITpLJxSycL6KhZOqRyx+iIWjRCLRigv+EMiY8xYyWSUrfs7eCrbxTP
QzbM7EMCZXlvGgvpKzlgwiYVTqlg0pYoF9ZVUFkF3zcIv4VGqK5nm949u57p7tvTbtScirhl7bXmC2vI4E8oT1Jb1Bn5qK9zjCdnl/nF5IloQB9mqys7WLh7b1jeIs7O1K7fOrInlHDe9mledPJOl06pp70rxyLPNPLJ1P9f9ewvX+BOLhppSTphZy4pZLtCzbHoNpfGRufLW1NbFX59o4o4ndvLvTXvpSWWoLY9z7pJ6zls6hTMX1lm/7EOgqikRuQrXjSoK/EhVHxeRTwOrVPVWETkZ+A0wAbhYRD6lqscCiMgcXIuff+Tt+uciUgcIsBp4+5i8IGMOw9/W7eKv63bx4QsWs3RaNUunVfOaU2fTuL+DWx/dzu8e2c7Hf/c4n/79E5y5cDKXrpjOC5dOGZMTS1Vly94O/rVxD/ds3MM9m/bmTv7m1VWwsL6SFyydwtxJvuXN5ArqD/FqujGD2Xegh9vWbOfXD29j9dZmRKC+qoSyeJRSfyuLR6mrKqE0Hsk9zt6XJVzulLJE3/lu2975ERGa2rpo3N/pbx007u9k3Y5W7li3q8+VWoCJucBP3+DP9Fo3PVbHAt2pNM0dLjiz70APzR1Jf9/D/o4k+w/0sD873dHD/gM9tAZaxYyluqoSFk2p5OUnzWBBfWWuFc7Ewr8IZYwpUpmMsq25M9fyZsOuNp5qavN5unrr9YaaUhZOqeL0eZNYNMUFlxfUV1JVGg+x9EdGVIe+siEi5wPfwJ2IXauqXxxgvZcBvwROVtVVg+1z5cqVumrVoKsclZraurj+vme54f5n2NPew8L6St743LksaajK/ZE3dyR7/8D9FRf35+3mHxikn18iGskFgmrK40wsTzChwj2eUO5afEzwgaAJ5QkmlieoKj2ypmOqyrP7OlwQZ3uLD+S0su9ADwAiML+ukuOmVXPc9BqOnVbD0mnV1JQN/MPqTqV5Ynsrq7c254I9W/d1Ai6PzJKG6j7BnjmTyod14qGqbNrdnuti9cizzQDMnFjGC5dM5bxjp7By9gRiBdAdLAwi8pCqrgy7HCPB6iAThq5kmhf+7z8oiUW5/d1nkoj1X5es29HKb1dv49bV29nR0kV5Isp5S6dwyYrpnLlg8ojWQU2tXdyzaW8umLO9xQXYp9WU8twFk3nugsk8Z/4k6qvHpnXkYMZLHWT1T19dyTR/X9/Erx/exl1PNpHKKIunVvHSFa574li1zM3KZJQ97d1sDQR8Gvd3sq2593F+4GdCeZwZE8qZVJlAFRR3TOGmlUzG3auSm6cKGVUUXP6p3HR2mdsHQFtXashjvIpE1F3gCx7XBY7pcsd4fjp7nJVR/f/t3Xl8VPW9//HXJwsJkISwJAQI+w6CLAHcq2gVq5Uu2qptXaq1va2t9+ftYn/t9d5re+/tcn/dba+2brV1a29tqdV6raB1lyiLwAQIEJY4IYHsezLz/f0xB4yRZQKZnDnJ+/l4zCNnzjlz5jPJ8OU7n/l+P1+iUUck6oh4zx3x7kedI+rdjzp3eL9zEHFdzvEeG43C4EEpTM3LIrefDZnpL+0PqA2SYHLOUdPcwdu1LVTUtRKuayFc10pFXStv1x3a1/quETgFOZlMH53FjNHZzBj9zhTPICZwjtcGHffrBTNLBe4E3k9sBZu1ZrbKObel23nZwC3AaycX8sC0cV8t971UxhMb36Yz6lg+M5/rz5zMmdNG9vib0LbOCHUtHbFEUFMsAVTX8k4iqLapg1rv/s4DjVTvjiWEOo8yfDnFODzyZ8ShzsGRkkBDYyOGIlF3eATOpvI6toTrD8+hTksxZozO5oLZ+YeTOLPHZPf4m+iMtNTYdKwJw7n+zNi+A41trPeSPOv21PKHN/fx4Ku7gdi87YXjc1kwfjgLJ+Ry6vjcwx2aSNSxbk8Nz2zZz/9u2c+uA00AzBs3jH96/wzeP3c0M0dn6xtpETlpv3huB3urW3joxmVHTeoAzB6Tw+wxOXztolm8XlbNn9a
X85eNYf64/m1GDh3EpfPHsHLhOBaOz+1x21Tf2sFrO6t5qfQAL5UeYHtlIxBrJ8+YOpLPT40lc+JNiIuciGjUsbasmsfXlfOXt8I0tHaSn53Bp8+azIcXjmP2GP/qnqSkGPk5meTnZLJ44vD3HI9GHQea2t4z2mdfTQvVTe0YgMWmLhmx6UpmYMR+pqSAkeL99I6ZeefGtmPfpx16XKzoZm6XvtaIoe+M2D60nZGmOkEiEkzOOaqb2gl7yZmKuhbe9pI2hxI44brW9yTV01KM0TmZjBmWybzCXC6cm8mUUUO9ETjZxxwo0N/E82l6KVDqnNsJYGaPACuBLd3O+xbwXeArvRphP9YZifLXzRXc91IZb+yuYeigVD6xbCLXnTGJSaOGnvB1M9JSyc9OJT87/m+4nHM0tHUeTgQdGsp7eJivNx+7uqmdPdXNh+v+dK2z8N44Upg9JoeVC8ZyythYEmdGQVbCOh6jsjK4YM5oLpgzGoglbEorG1m3p+bwqJ7ntlVxaJDatPwspuYNpbishoNN7aSlGKdPHcmnz5zEBXNGM2bY4ITEKSKJsbe6mc/8uphPnzWZjxWNP/4D+tjug0384vkdfPDUsZwxbVRcj0lJMU6bMpLTpozkXy+by3Nbq/jT+nIeXruXB17ZzcSRQ1h56lhWLhzH1LysI16jrTPCm7trY4mcHQfYuK+OSNSRmZ7CkkkjuHxxIWdOG8WcMTn9snabJJcdVY08/mY5j68rp7y2hSGDUlkxt4APLxrHGVNH9erCAomSkmLkZ2eSn53JognvTfyI9JbjzZows88BXwAiQCNwU/cv30WOxjlHXUsHlQ1tVHm3yobWw9tVjW1U1rdR29JxuN5WWuo7dbXeuZ9CWop33/vZdTstJeXd971rNLVFYqNv6o+dtBmbm8n8wlxWzM2kYFgsiTNm2GDGDMtkVFaG+i6eeBI744C9Xe7vA5Z1PcHMFgHjnXN/MTMldo6jtrmdh1/fy4OvlPF2XSsTRgzhny+dwxVFheT4NCzMzMjJjBVlnjgyvsc452hqj7x7PndTOw7HnDHDmJo31NcpS6kpxsyCbGYWZHPl0tjS1g2tHWzcV3c42RMKN3DGtFG8f85ozp2Z59vvX0ROTkNrBzc8sJZt+xv55z9uYsH4XGaMzvY7rMOcc/zrqs2kpxjf+MDsE7pGRloqF80t4KK5BdS3dvDXTRX8aX05P11Tyk9WlzJv3DBWLhjLpfPHUtXQxks7YiNy1pZV09oRJTXFOLVwGJ8/dypnThvFwgm5+oZf+sSBxjb+vOFtHl9XzsZ9daQYnDU9j69cNJML5/ZN/SiRoIlz1sRDzrn/9s6/DPgBsKLPg5Wk0toR4UBjW7eEzTvbVYeSN41tdETeO2MjMz2F/OxM8rIzvGmV6TgHnVFHJBql05uC2RmJTcfsjB76GaUjEqWl451pm92PRyLv3M9MT2VsbianvitpE0vYjMnNZNRQJW164qT/JzWzFGKNyHVxnHsTcBPAhAkTTvapA2fb/gbue6mMx9fto7UjyhlTR/JvK09h+az8QHxD1Z2ZxZZzy0hj/IjkXp73kOzM9MM1I0Skf+iMRPniw+vYWdXEj69cwLee2MKXHl7HH79wZq8VUj9Zz2zZz5qtVXzzktm9Ui8kJzOdjxWN52NF49lf38qfN7zNH9eX8+2/hPj2X0KHz5sxOourlk7gzKmjWDZlRCDnlEswtXZEeGbLfh5fV87z26qIRB1zxuTwzUtmc9mpY5OiZpNIkjvurAnnXNcVVoYSK+8kSa61I0JVQxttnbEV/Q6t7Peu+50R2jqitHr72rwVANs63zm/+zmtHREONrYdsWC6WWz1vzwvYTMtP5u87AzyszPI826HtrMy0jQVO4DiSeyUE1uF5pBCb98h2cApwHPeG6AAWGVml3UvoOycuxu4G2JFu04i7sCIRh1rtlZy30tlvFh6gIy0FD68cBzXnTmJWQX+zR8XEekvvv2XEM9
treI/PzKPlQvGkZOZzvX3r+U7T5Xwr5fN9Ts8Wtoj/NuftzBjdBbXnjGp168/OieTG8+ewo1nT6G0spFntuxnbG4mp08d2aMpuSInKxp1vLarmsfX7eOptypoaOukICeTG8+ezEcWFjKzIHlG0YkEwHFnTQCY2ReAW4FBwPKjXWygf8Hul/bOKNv2N7BhXy0b99axsbyObfsbiByltunRZKSlHF7dLyMt9fCqgJlpqWRlpDFyaGxfLHmTcXjEzaGEzYihgwbs4i8DRTyJnbXAdDObTCyhcyVw9aGDzrk64PDwBzN7Dvjy8VbF6u8a2zr5ffFe7n+5jLKDzYzOyeArF83kqqUTtMyjiEgvefCVMu5/uYwbz5rMVd6Uy/Nm5XPdGZO4/+Uy3jcjj/Nm5fsa48+fK6W8toVHbjqN9AR3qqblx5brFOlLVQ1t/Pa13Ty2di9v17UydFAqF88bw0cWjmPZlJGBHJUsEhTOuTuBO83sauCbwLVHOW/AfcHe1yLR2Oq6G/bW8lZ5HRv21REK1x+uHZM7JJ35hblcMDuf8SOGeImZQwmbIydtMtJTyEhL0QgaOa7jJnacc51mdjPwNLHCXfc65zab2R1AsXNuVaKDDJI9B5u5/+Uyfle8l4a2ThZOyOXWC2dy8SkFCe/Qi4gMJH/fVsW//nkL58/K5+vd6tbcdvEsXt15kK/8fgNP3XIOedkZvsS460ATdz2/kw8vHMdpU+IsYCYSEG/tq+O+l3bxxMYw7ZEo58zI42sXz+LCOQUMHpQc0yBFAux4sya6ewT4RUIjksOcc+ypbmbDvjo27q1lY3kdm8vraGqPALGV7E4Zl8N1Z0xifuEwTi3MpXD4YCVoJGHiqrHjnHsSeLLbvtuPcu65Jx9WsHREoqwuqeSR1/fw3LYqUs24ZP4Yrj9zMgvG5/odnohIv7N9fwNf+O2bTM/P4sdXLXzPiIDM9FR+fOVCLvvZi3z5dxu477olfV6AzznH7X/aREZaCl//wKw+fW6RROmMRHl6837ue2kXxd6Knlcvm8A1p09kylFWZhORE3LMWRMAZjbdObfdu3sJsB3pdc45Kupb2bC3jrfKa9m4r46N++qoa+kAYFBaCnPH5nBF0XjmjRvGqeOHMWVUlgr/Sp/SMgQnoexAE4+s3cvv39jHgcY2RudkcPN50/jEsom9UhxTRETe62BjG59+YC0Z6ancc90SsjKO/F/ZzIJsvnHJbG7/02buf7mMT581uU/j/OumCl7YfoDbL52jWjcSeDVN7Ty8dg8PvrKbcJKs6CnSn8U5a+JmM7sA6ABqOMo0LOm5vdXNPBvaz4ulB9iwr46qhjYgtgT3zIJsPjCvgPmFucwvHMaM0dmamSG+U2Knh1o7Ijy9uYKHX9/DqzurSU0xzpuZz5VLxnPuzDwVpRIRSaC2zgifffANKuvbePSzpzMud/Axz//UaRN5fmsV33mqhNOnjmT2mL4pWt/c3skdT2xhVkE215w+sU+eUyQRtlY0cP/Lu3h8XTmtHVHOnDaSOwK8oqdIkBxv1oRz7pY+D6qfikYdb5XX8bfQfp7Zsp+SigYAJo8aytnTRzF/3DDmj89lzpicpFlxU6QrJXbiVFJRzyOv7+XxdeXUtXQwYcQQvnLRTC5fXMhoLdkpIpJwzjm+/j9vUby7hp9dvTCuqa5mxvcun8+KH7/Alx5ex5+/eFafdMh+urqUcF0rP7lqoRL+EjiRqGN1SSX3vbSLl3ccJCMthY8sGsd1Z0zWylYi0m+0dkR4eccBntlSybOh/VQ2tJFisGTSCL55yWzOnz2ayaOG+h2mSFyU2DmGxrZOntjwNg+v3cuGvbUMSk3holMKuGrJeE6bMlLzJkVE+tDPn9vBH9aV80/vn8Gl88fG/biRWRn8vytO5Zp7X+ff/xLiWx86JYFRQmllI796YScfXVTIkkkjEvpcIr2pobWDx4r38cDLZeypbmbMsEy+umImVy2ZwHCt6Cki/cC
BxjZWl1Tyty37eWH7AVo6IgwdlMq5M/O5YE4+587IV3sngaTETjfOOdbvreXRtXtZteFtmtsjTM/P4p8vncNHFo7TP3QRER88+VaY7z+9lQ8tGMvNy6f1+PHnzMjjxrMm86sXd/G+GXlcMGd0AqKM/R/yL6s2kZmeqoLJEhi7DjTxgLeiZ1N7hKKJw/naillcNHe0RpyJSKA5F1uC/JktlfwttJ8399TgHIwdlskVRYVcMHs0y6aMICNN06sk2JTY8dQ2t/P4unIeXbuXkooGBqen8sFTx/DxJRNYNCFXS9OJiPhkw95abn1sPYsnDuc7H51/wu3xV1bM5OUdB/nq/2zkr4Vnk5+AabR/eSvMS6UHuWPlXEZl+bPEuvQuM1sB/JhY8dJfOee+c5TzPgr8HljinCvuwxBPiHOOF7Yf4L6XdrFmaxXpqcYH54/l+jMnM69wmN/hiYicsM5IlOLdNfxty37+FtpP2cFmAE4Zl8Mt50/ngtmjmTs2R5/vpF8Z0Ikd5xyv7DzIo2v38tSmCto7o5xaOIz/+PA8PnjqGLK1yoOIiK/erm3hxl8XMyorg7s+tfik6uNkpKXyk6sWculPX+CffreBB65f2qtTahvbOvnWE1uYOzaHTyxTweT+wMxSgTuB9wP7gLVmtso5t6XbednALcBrfR9lzzjneKx4L798YRellY2MysrgHy+YztXLJmj1NhEJrMa2Tp7fWsXfQvtZXVJJXUsHg1JTOH3qSG44ewoXzM5nzLBjL7ggEmQDMrHjnOOeF3fxm1d3U3awmZzMNK5aMp6PL5nAnLF9s2KKiEhvikQdP3xmG8tn57NownC/w+kVTW2d3PhAMS3tEX5747JeGQEzLT+L2y+dy/99/C3ueXEXnzlnSi9EGvOTZ7ezv76NX3xysVYL6j+WAqXOuZ0AZvYIsBLY0u28bwHfBb7St+H13Gu7qvna/7zFKeNy+MHHTuWS+WM0BUFEAm1TeR1X/fJVGlo7GT4knfNn5/P+2aM5e0YeWRkD8uOuDEAD8p3+6s5qvv2XEIsnDueWC6Zz8SljtGydiATa+r01/GxNKfe/XMbDnzkt8FMpIlHHLY+sp6SinnuvW8KM0b23Es9VS8fz/LZKvvd0bAn0U8ad/O9q2/4G7n1xFx8vGt9vEmsCwDhgb5f7+4BlXU8ws0XAeOfcX8ws6RM7G/fVAvDgp5epbqCIBN7btS18+v615GSm88triiiaOFy1wWRAGpDv+k3ldQD88poiPrywUEkdEQm81SWVpKYYwwanc+19r1Na2eB3SCfle38t4W+h/fzLB+dy7sz8Xr22mfGdj8xnxNBBfOmRdTS3d57U9Zxz/PMfNzE0I42vrpjZS1FKEJhZCvAD4J/iOPcmMys2s+KqqqrEB3cUoXADBTmZSuqISOA1tHbw6fvX0tIe4b7rl3DalJFK6siANSDf+aFwPaNzMhihTo1InzOzFWa21cxKzey2Ixw/x8zeNLNOM7u827GIma33bqu67J9sZq9513zUzAbcP+7VJVUsnjic3964jBQzPvmr19lb3ex3WCfkkdf3cNffd3LN6RO59oxJCXmO4UMH8cOPLWDXgSa+9UTopK61asPbvLarmq+umMlIFUzub8qB8V3uF3r7DskGTgGeM7My4DRglZkVdb+Qc+5u51yRc64oLy8vgSEfWyhcz+wxvTcCTkTEDx2RKJ//7ZuUVjbyi08u7tWRvSJBNCATO1vC9cweo1o6In2tSyHSi4E5wFVmNqfbaXuA64CHjnCJFufcAu92WZf93wV+6JybBtQAN/R68Ens7doWQuF6ls/KZ9KoofzmxqW0dET45D2vUVnf6nd4PfLyjgN884+bOGdGHrdf2v2t0bvOmDaKz54zlYdf38NfN1Wc0DUaWjv49l9CzC8cxpVLJvRyhJIE1gLTveTxIOBK4HBS2TlX55wb5Zyb5JybBLwKXJasq2K1d0YprWxUH0hEAu3QSNkXth/gPz4
yj7Omj/I7JBHfDbjETntnlB1V6tSI+ORwIVLnXDtwqBDpYc65MufcRiAazwUttlblcmLLDAM8AHyo1yIOgDVbKwE4f1ZsytKsghzuu34JVQ1tfOqe16ltbvczvLjtrGrkH37zJpNHDeVnVy/sk+HUt75/BvPGDeO2P2wkXNfS48f/8JntHGhs41srT1HB5H7IOdcJ3Aw8DYSAx5xzm83sDjO77NiPTj6llY10Rh2z1AcSkQD7xfM7eGTtXm4+bxofKxp//AeIDAADLrFTWtlIR8QpsSPijyMVIh3Xg8dnejUqXjWzD3n7RgK13gewY14zWWpc9LY1JZUUDh/MtPysw/sWTRjOL68pYteBJq69by2NbSdXRybRapvbueGBYlJTjHuvW0JOZnqfPO+gtBR+fOUC2juj3ProBiJRF/djQ+F6HniljKuWTuDU8bmJC1J85Zx70jk3wzk31Tn3796+251zq45w7rnJOloHYu9ZgDmaiiUiAbVqw9t8769bWblgLP904Qy/wxFJGgMusaNOjUigTXTOFQFXAz8ys6k9eXCy1LjoTa0dEV4sPcDyWfnEBi+948xpo/jZ1QvZVF7HTb8uprUj4lOUx9beGeUffvMm5TUt3P2pxYwfMaRPn39KXhb/+sG5vLLzIHf/fWdcj3HOcfufNpGTmcZXLlTBZAmGULiejLQUJo0c6ncoIiI9trasmi//bgNLJ43ge5fPf0+/R2QgG5CJHXVqRHxzvEKkx+ScK/d+7gSeAxYCB4FcM0s7kWsG3Ss7D9LaEWX5rCOvHHXh3AL+64r5vLzjIF98eB0dkbhmuPWZQ/PkX9l5kO9ePo+iSSN8ieOKokIumTeG//e/Ww8vB30sf3iznLVlNXxtxSytLiSBUVLRwMyCbK0aIyKBs+tAE5/5dTGFuYO561OLyUjTqsYiXQ24/9lDFfXq1Ij455iFSI/FzIabWYa3PQo4E9jinHPAGuDQClrXAn/q9ciT1OpQJYPTUzltysijnvPhhYXcsXIuz2zZz1d/v5FoD6YbJdqvXtjFo8V7+eLyaXx4YaFvcZgZ//HheeRnZ3DLI+tpOsbUtbqWDv7zqRALxudqbr8EhnOOULieWQUasSwiwVLd1M71971Oihn3Xb9EX6iIHMGAym7EOjUNzFF9HRFfxFOI1MyWmNk+4ArgLjPb7D18NlBsZhuIJXK+45zb4h37GnCrmZUSq7lzT9+9Kv8451hdUsmZ00aRmX7sb66uOX0SX7loJo+vK+df/7yZWD7MX89s2c9/PBXiknlj+D8X+D9PftiQdH748QWUHWzi3/68+ajn/eB/t1Ld1M63P3QKKSqYLAFR1dDGwaZ21RgUkUBp7YjwmV8X83ZdK7+8poiJmnUhckRpxz+l/9hf30a1OjUivnLOPQk82W3f7V221xKbTtX9cS8D845yzZ3EVtwaULZXNlJe28IXzpsW1/mfP3cq9S0d3PX3nWRnpvGVi2YlOMKj21Rexy2PrGP+uGH81xWnJk2CZNmUkXzh3Gn8bE0p75uRzyXzx7zr+KbyOh58dTefPG0ip4wb5lOUIj23xasxqD6QiARFNOr48u828MbuGu68ehGLJw73OySRpDWgEjshdWpEpB95NhRb5vy8WfEVgjYzbrt4FvWtHdy5ZgfZmel87n09qj990g42tvHT1aX89rXdjMrK4JfXFDF4UHLNk7/lgum8WHqAr/9hIwsm5DIudzAQ62De/qdNDB8yiH96vwomS7CUVDQAMLtAfSARCYbv/+9WntgY5usXz3rPFy0i8m4DairWoW+rZmlFLBHpB9aUVDJnTA5jhg2O+zFmxrc/NI9L54/hO0+V8NBrexIY4Tua2zv56bPbed/3n+PBV3dz+eLx/OkLZ5Kfk9knz98T6amxJdAjUcf/eWT94SXQf//GPt7cU8ttF89i2JC+WY5dpLeEwvWMHZap966IBMLDr+/hF8/t4OplE7jpnCl+hyOS9AbciJ3C4YP
JyVSnRkSCrba5neLd1Xz+3PimYXWVmmL88OMLaG6P8I0/vkVWZhqXnTo2AVFCRyTKo2v38uNnt1PV0MaKuQV8+aKZTMvPSsjz9ZaJI4fyrQ+dwq2PbeAXz5XyydMm8p2/llA0cTgfXeRfkWeRExUK12vEsogEwvPbqvjmHzfxvhl53HHZXC1rLhKHAZfYUadGRPqD57dVEXVw3lGWOT+e9NQUfv6JRVxz7+vc+uh6sjJSWT5rdK/F55zjqU0VfP/prew60MTSSSO461OLWTQhOPPjP7xwHM9treKHf9vOqzurqW1u546Vy5KmHpBIvFo7IuyoauLCOQV+hyIickyhcD1f+O2bzBidzZ2fWKSVjEXiNGD+pbR2RNh1oEmJHRHpF9aUVDJi6CAWjM894Wtkpqdyz7VFzB6Twz/85k1e3XmwV2J7ZcdBPvTzl/n8b98kPdW459oiHv3saYFK6oA3be3DpzBmWCYvlh7gmtMnMWes/g+R4CmtbCQSdeoDiUhS21/fyqfvX0tWRhr3XldEVsaAGoMgclIGTGJna0UDUQdzVF9HRAIuEnU8t62Kc2fkkXqSo0eyM9N54NNLGT9iCDc+UMzGfbUnfK1QuJ7r7nudq375KpX1rXz/8vk8dcs5nD97dGCHUedkpvPzTyziI4vGceuF/i/JLnIiQqoxKCJJrrGtk+vvW0t9Swf3XrekR/UDRSTOxI6ZrTCzrWZWama3HeH458zsLTNbb2Yvmtmc3g/15GhFLBHpL9btqaG2uYPls09sGlZ3I4YO4jc3LCN3SDrX3vs62/c39Ojx+2qaufWx9XzgJy+wbk8t//cDs1jz5XO5omj8SSeeksH8wlx+8LEFqs8mgRUKN5CZnsKkkUP9DkVE5D06I1G++NCbbN3fwM8+sUijY0VOwHETO2aWCtwJXAzMAa46QuLmIefcPOfcAuB7wA96O9CTFQrXk5WRxvjhQ/wORUTkpKwuqSQ1xTh7enzLnMejYFgmv71xGWmpKXziV6+x52DzcR9T09TOt5/YwvL/ep4nNoa56Zwp/P0r53HTOVPJTE+uJcxFBrJQuJ6ZBTn9ItEqIv2Lc45/+/MW1myt4o6VczlvZu98aSUy0MQzYmcpUOqc2+mcawceAVZ2PcE5V9/l7lDA9V6IvSMUbmBWQbaKXopI4K0uqaRo4nCGDe7dESQTRw7lNzcsoz0S5ZP3vMb++tYjntfSHuHONaWc87013PvSLj60cCzPfflcvn7xbC2lLJJknHOUVNRrKrqIJKV7XtzFg6/u5rPvm8Inlk30OxyRwIonsTMO2Nvl/j5v37uY2RfMbAexETtfOtKFzOwmMys2s+KqqqoTifeEOOe0IpaI9AvltS2UVDRwfi9Nw+puZkE291+/lIONbXzqnteoaWo/fKwzEuWR1/dw7n+t4ftPb2XZlJH89R/P4XuXn8rYXM2FF0lG++vbqGnuYFaB+kAiklyeeivMvz8Z4pJ5Y/jaRbP8Dkck0HqteLJz7k7n3FTga8A3j3LO3c65IudcUV5e700hOJ59NS00tHUqsSMigbe6pBKA5Se4zHk8FozP5ZfXFlF2sJnr7nudhtYOnt5cwUU/+ju3/eEtxuUO5rHPns6vri1ixmiNAhBJZqoxKCLJaN2eGv7x0fUsHJ/L//vYqZpVIXKS4llDrhwY3+V+obfvaB4BfnEyQfW2LYc7NfoAIiLBtqakkgkjhjA1Lyuhz3PG1FH8/OpFfPY3b3Dmd1ZT39rJ1Lyh3PWpxVw4J7irXIkMNFu0IpaIJJk9B5u58YFiRudk8strilSXT6QXxDNiZy0w3cwmm9kg4EpgVdcTzGx6l7uXANt7L8STFwrXYxabYiAiElQt7RFeKj3A8ln5fZJYuWDOaH748QWMzR3Mf35kHk//4zlcNLdASR2RACmpaKBw+GCt6iYiSaG2uZ3r7n+diHPcf/0SRmZl+B2SSL9w3BE7zrlOM7sZeBpIBe51zm02szuAYuf
cKuBmM7sA6ABqgGsTGXRPhcL1TB45lCGD4hmgJCKSnF7ZeYC2zijnJXAaVneXnTqWy04d22fPJyK9KxSuV30dEUka3/zjJvZVt/CbG5cxJcGjj0UGkrgyHc65J4Enu+27vcv2Lb0cV68KhRuYN26Y32GIiJyU1SWVDBmUyrLJI/wORUQCoLUjws6qRj5wSoHfoYiIEIk6nt9axUcXF7JUfRmRXtVrxZOTVUNrB3uqm1VfR0QCzTnH6lAlZ04bpbnoIhKXbfsbiDoVThaR5LBtfwMNbZ0snTzc71BE+p1+n9jZWtEAwJyx6tSISHBt3d/A23WtnN+H07BEBMxshZltNbNSM7vtCMc/Z2Zvmdl6M3vRzOb4EeeRlIRjfSAldkQkGRSXVQNQNFGjdUR6W79P7GiZT5HkEseHpHPM7E0z6zSzy7vsX2Bmr5jZZjPbaGYf73LsfjPb5X2wWm9mC/ro5fSZQ8uc92V9HZGBzsxSgTuBi4E5wFVHSNw85Jyb55xbAHwP+EHfRnl0W8L1DBmUyoQRQ/wORUROQBx9plvNbIvXL3rWzCb6EWe8infXMDong8Lhg/0ORaTf6feJnS3henKHpFOQk+l3KCIDXpwfkvYA1wEPddvfDFzjnJsLrAB+ZGa5XY5/xTm3wLutT0D4vlodqmTu2BxGqy0T6UtLgVLn3E7nXDvwCLCy6wnOufoud4cCrg/jO6ZQuJ6ZBdmkpGglO5GgibPPtA4ocs7NB35PLLmctIrLaiiaOEKra4okwABI7DQwuyBHDYhIcojnQ1KZc24jEO22f5tzbru3/TZQCeT1Tdj+qmlq5809NZqGJdL3xgF7u9zf5+17FzP7gpntIPah6ktHupCZ3WRmxWZWXFVVlZBgu3LOUVLRoBHLIsEVT59pjXOu2bv7KlDYxzHG7e3aFsprWyiapPo6IonQrxM7kahja0W9OjUiySOuD0nHY2ZLgUHAji67/90bivxDM8s4yuP69INVb/n79iqiTtOwRJKVc+5O59xU4GvAN49yzt3OuSLnXFFeXuJz0uG6VupaOtQHEgmunvaZbgCeOtpBv/tAxbtrANXXEUmUfp3YKTvYRGtHVCtiifQjZjYGeBC43jl3aFTP14FZwBJgBLEPV+/R1x+sesuzoUpGDh3EqYW5fociMtCUA+O73C/09h3NI8CHEhlQvA7XGCxQH0ikvzOzTwJFwPePdo7ffaA3yqoZMihVn8tEEqRfJ3ZUOFkk6fT0Q9K7mFkO8BfgG865Vw/td86FXUwbcB+x4cv9QmckyvPbqjh3Zr7qZIj0vbXAdDObbGaDgCuBVV1PMLPpXe5eAmzvw/iO6lAfaJb6QCJBFVefycwuAL4BXOb1g5LS2rIaFk7IJS21X3/8FPFNv/6XFQrXk5ZiTB+d5XcoIhJz3A9JR+Od/zjwa+fc77sdG+P9NGLflm/qzaD99OaeWupaOliuaVgifc451wncDDwNhIDHnHObzewOM7vMO+1mb7W+9cCtwLX+RPtuoYoGJowYQlZGmt+hiMiJiSexvBC4i1hSp9KHGOPS0NpBSUW9pmGJJFC//t8+FG5gal4WGWmpfociIsQ+JJnZoQ9JqcC9hz4kAcXOuVVmtoRYAmc48EEz+zdvJayPAecAI83sOu+S13krYP3WzPIAA9YDn+vL15VIq0sqSUsxzp4xyu9QRAYk59yTwJPd9t3eZfuWPg8qDqFwvaY8iARYPH0mYlOvsoDfeQvF7HHOXXbUi/pk3Z5aog4VThZJoH6e2Kln2WRlhkWSSRwfktZyhFUdnHO/AX5zlGsu7+Uwk8aakkqWTBpBTma636GISEC0tEcoO9DEB+eP9TsUETkJcfSZLujzoE5A8e4aUgwWTlBiRyRR+u1UrJqmdsJ1rcwZq7nlIhJM+2qa2bq/QdOwRKRHtu5vIOpUY1BEkkNxWTWzx+RoaqhIAvXbxI4KJ4tI0K0piU2XXz5biR0RiV+J1weaoz6QiPisIxJ
l/d5alkzSLAqRROq3iZ0tSuyISMA9W1LJxJFDmDJqqN+hiEiAhML1ZGWkUTh8sN+hiMgAFwrX09weYfFETcMSSaR+m9gJhRvIy85gVFaG36GIiPRYS3uEV3YcZPmsfLyCiCIicQmFG5hZkE1KitoOEfFXcVkNoMLJIonWjxM79RqtIyKB9fKOA7R1RlVfR0R6xDlHqEIrYolIcnhjdw3jcgczZphGEIokUr9M7HREopRWNqpTIyKB9WxJJUMGpbJUK/uJSA+U17bQ0NqpL7dExHfOOdaWVWu0jkgf6JeJnR1VjbRHoioaKCKB5JxjTUklZ08fRUZaqt/hiEiAhMINgGoMioj/9tW0UNnQRpEKJ4skXL9M7GhFLBEJspKKBsJ1rZqGJSI9FgrXYwYzR2vUsoj4a21ZNQBFKpwsknD9NLHTwKC0FK0kIyKBtNpb5vy8mUrsiEjPhML1TBwxhKEZaX6HIiIDXPHuGrIz05ihRLNIwvXTxE49M0dnk5baL1+eiPRzq0sqmTduGPk5mX6HIiIBU1LRoBHLIpIUisuqWTRhOKlaoU8k4fpd5sM5x5a3tRqEiARTdVM7b+6p4TxNwxKRHmpu76TsYJMSOyLiu7rmDrbtb2SJCieL9Il+l9ipamjjYFO7OjUiEkjPb6vEOThfiR0R6aGSigacg1kF+nJLRPz1xp5YfZ3FE1U4WaQv9LvEzhYVThaRAFtdUsWorAzmjRvmdygiEjBaPEJEkkVxWQ1pKcaC8bl+hyIyIPS7xM7hZT4L1KkRkWDpjER5fmsl587MI0Xz0UWkh0rCDWRnplE4fLDfoYjIAFdcVsPcccMYPCjV71BEBoS4EjtmtsLMtppZqZnddoTjt5rZFjPbaGbPmtnE3g81PqFwPeNyBzNsSLpfIYiInJA3dtdQ39qpaVgickJC4XpmF+RgpsSwiPinrTPChn21LNEy5yJ95riJHTNLBe4ELgbmAFeZ2Zxup60Dipxz84HfA9/r7UDjFQqrcLKIBNPqrZWkpxpnTR/ldygiEjDRqKOkooFZ6gOJiM82ldfT1hmlSIWTRfpMPCN2lgKlzrmdzrl24BFgZdcTnHNrnHPN3t1XgcLeDTM+rR0Rdh7QahAiEkyrQ5UsnTyC7EyNOBSRntlX00JjW6f6QCLiuzd2q3CySF+LJ7EzDtjb5f4+b9/R3AA8daQDZnaTmRWbWXFVVVX8UcZp+/5GIlGnTo1IEotjauc5ZvammXWa2eXdjl1rZtu927Vd9i82s7e8a/7EAjgPYW91M9srGzlvpqZhiUjPhSpUOFlEksPashomjRxCXnaG36GIDBi9WjzZzD4JFAHfP9Jx59zdzrki51xRXl5ebz418M5qEHPUqRFJSnFO7dwDXAc81O2xI4B/AZYRG0n4L2Z2aIzvL4DPANO924oEvYSEWV1SCcBy1dcRkRMQCtdjBjNGZ/kdiogMYM453thdQ9EkjdYR6UvxJHbKgfFd7hd6+97FzC4AvgFc5pxr653wemZLuJ6hg1KZMGKIH08vIscXz9TOMufcRiDa7bEXAc8456qdczXAM8AKMxsD5DjnXnXOOeDXwIcS/UJ62+qSSiaPGsqUPH0oE0kmQVlAIhSuZ/LIoQwZlObH04uIALDzQBPVTe0UqXCySJ+KJ7GzFphuZpPNbBBwJbCq6wlmthC4i1hSp7L3w4zPlnA9MwuytUywSPLq6dTOeB47zts+7jUTPR30RDW3d/LKzoOahiWSZIK0gEQo3KBpWCLiuzfKagBUOFmkjx03seOc6wRuBp4GQsBjzrnNZnaHmV3mnfZ9IAv4nZmtN7NVR7lcwjjnvBWx1KkRkSNL9HTQE/VS6UHaO6OcP1uJHZEkE4gFJBrbOtlT3axVQUXEd8W7qxk+JJ2pGoEs0qfiGq/rnHsSeLLbvtu7bF/Qy3H1WHltCw2tWg1CJMnFNbXzGI89t9tjn/P2F3bbH+81k8LqkkqyMtJYovnoIsnmSCMFlx3
j/GMuIAHcBDBhwoTeig+ArV7h5FkF6gOJiL+Ky2pYPHE4AVzHQiTQerV4sp9C4QZAq0GIJLnjTu08hqeBC81suFc0+ULgaedcGKg3s9O81bCuAf6UiOATwTnHmpJKzpo2ikFp/aZJFhlw/FxAYsuhPtBY9YFExD8HG9vYeaBJhZNFfNBvPkUcWg1iVoGGIYskq3imdprZEjPbB1wB3GVmm73HVgPfIpYcWgvc4e0D+DzwK6AU2MFRvjFPRlvC9VTUt7Jc07BEklEgFpAIhevJyUxj7LDMvn5qEZHDind79XVUOFmkz/WbpRNC4XomjhjC0Ix+85JE+qU4pnau5Sg1Kpxz9wL3HmF/MXBK70baN9Z4y5yfOzN5av6IyGGHRxkSS+hcCVzd9YQuC0is8GsBiRKvxqCmPoiIn97YXcOgtBTmFQ7zOxSRAadfjdjRNCwRCZpnSyo5tXAY+dn6pl0k2QRhAYlo1FFSoRWxRMR/a8uqmT9uGBlpqX6HIjLg9IvhLU1tneyubuaji/p8IQoRkRN2sLGN9XtrueX86X6HIiJHkewLSOypbqa5PaIVsUTEV60dETaV13HDWVP8DkVkQOoXI3ZKKhpwToWTRSRYnt9WhXOwfJbq64jIiQmFYytiqQ8kIn7asLeWjohTfR0Rn/SLxM7hTo1WgxCRAHm2pJK87AxOGau56CJyYkIVDaQYzBitETsi4p9DhZMXK7Ej4ot+kdjZotUgRCRgOiJR/r6tivNm5pGSooKnInJiQuF6Jo8aSma6alqIiH+Ky6qZlp/F8KGD/A5FZEDqF4mdkFaDEJGAKS6roaG1U9OwROSkaPEIEfFbNOp4Y3cNSyZptI6IXwKf2IlGHVu1GoSIBMyarZWkpxpnTdcy5yJyYupbO9hX06I+kIj4antlI/WtnSyeOMLvUEQGrMAndnZ7q0HMUadGRAJkdUklyyaPJCujXyxOKCI+2FrRAKA+kIj4qnh3NYBG7Ij4KPCJHa0GISJBs+dgM6WVjZqGJSIn5VAfaJaWOhfpl8xshZltNbNSM7vtCMfPMbM3zazTzC73I0aITS8flZXBhBFD/ApBZMDrF4md1BRj+ugsv0MREYnL6pL9gJY5F5GTEwrXkzsknYIcLR4h0t+YWSpwJ3AxMAe4yszmdDttD3Ad8FDfRvduxburWTJpuOqdivioXyR2pmg1CBEJkCffqmDKqKFMGjXU71BEJMBC4QZmF2jxCJF+ailQ6pzb6ZxrBx4BVnY9wTlX5pzbCET9CBBgf30re6tbtMy5iM/6QWKngTljNQ1LRIJhbVk1r5dV84nTJvodiogEWESLR4j0d+OAvV3u7/P2JZXishoAiiapcLKInwKd2Klr7qC8VqtBiEhw/OTZ7YzKGsTVSyf4HYqIBNjug020dERUX0dE4mJmN5lZsZkVV1VV9dp115ZVk5mewlx90S7iq0AndraocLKIBMi6PTW8sP0Anzl7CoMHafqoiJy4UFgrYon0c+XA+C73C719J8Q5d7dzrsg5V5SXl3fSwR3yxu4aFozPJT010B8rRQIv0P8C31kRS99WiUjy++nqUoYPSeeTmoYlIieppCK2eMS0fC0eIdJPrQWmm9lkMxsEXAms8jmmd2lq62RLuJ4lmoYl4rvAJ3ZGZQ0iP1urQYhIcttUXsfqkkpuOGsyQzPS/A5HRAIuFK5nap4WjxDpr5xzncDNwNNACHjMObfZzO4ws8sAzGyJme0DrgDuMrPNfRnj+r21RKJOhZNFkkCgP12EKuo1DUtEAuGnq7eTk5nGNWdM8jsUEekHQuEGfZgS6eecc08CT3bbd3uX7bXEpmj5orisBjNYpLZIxHeBHbHTGYmybX+jEjsiAWNmK8xsq5mVmtltRzieYWaPesdfM7NJ3v5PmNn6LreomS3wjj3nXfPQsfy+fVXHFgrX8/Tm/Vx/5mRyMtP9DkdEAk6LR4hIMijeXc3M0dnq24gkgcAmdnYeaKK9M6r6OiI
BYmapwJ3AxcAc4Cozm9PttBuAGufcNOCHwHcBnHO/dc4tcM4tAD4F7HLOre/yuE8cOu6cq0zwS+mRn60pJSsjjU+fOdnvUESkHyipUI1BEfFXZyTKm7trVF9HJEkENrET0opYIkG0FCh1zu10zrUDjwAru52zEnjA2/49cL6ZWbdzrvIem/RKKxt48q0w15w+kWFD9I2WiJy8Q30grYglIn4pqWigqT1C0SRNwxJJBoFN7GwJ1zMoNYWpeVoNQiRAxgF7u9zf5+074jle4cA6YGS3cz4OPNxt333eNKx/PkIiCAAzu8nMis2suKqq6kRfQ4/cuWYHmWmp3HCWRuuISO8IhRsYMXQQedkZfociIgPUG7trACjSiB2RpBDYxE4o3MD00Vmkpwb2JYjICTCzZUCzc25Tl92fcM7NA872bp860mOdc3c754qcc0V5eXkJj3XXgSb+tL6cT542gZFZ+gAmIr0jtnhENkfJYYuIJNzasmrGDMtkXO5gv0MREeJM7MRR7PQcM3vTzDrN7PLeD/O9QmGtiCUSQOXA+C73C719RzzHzNKAYcDBLsevpNtoHedcufezAXiI2JQv3/18TSnpqSl85pwpfociIico2fpAkahja0UDswvUBxIRfzjnKC6r0WgdkSRy3MROnMVO9wDXEftAlXBVDW1UNbQpsSMSPGuB6WY22cwGEUvSrOp2zirgWm/7cmC1c84BmFkK8DG61NcxszQzG+VtpwOXApvw2d7qZh5fV85VSyeQn53pdzgicgKSsQ+060ATbZ1R9YFExDfltS1U1LdSpGXORZJGWhznHC52CmBmh4qdbjl0gnOuzDsWTUCM7/FO4WStBiESJM65TjO7GXgaSAXudc5tNrM7gGLn3CrgHuBBMysFqoklfw45B9h7qD3yZABPe0mdVOBvwC/74OUc0y+e30GKGZ9731S/QxGRE5e0faBZ6gOJiE/eqa+jxI5IsognsXOkYqfLTuTJzOwm4CaACRMmnMglAK0GIRJkzrkngSe77bu9y3YrcMVRHvsccFq3fU3A4l4P9CS8XdvC74r38rGi8RQM02gdkQDrtT5QbwmF60lLMabla/EIEfHH2rJqsjLSmKUpoSJJo08rD/dW4dJQuJ4xwzLJHTKoF6MTEekddz2/A+fgH87VaB0RiemtVflKKhqYlp9FRlpqL0YnIhK/4rIaFk7IJTVFBdxFkkU8iZ14ip32qVC4QXPLRSQpVda38vDavXx0USGFw4f4HY6InJxe6wP15pdb6gOJiF/qWjrYur+BookqnCySTOJJ7MRT7LTPtHVG2FHVqPo6IpKU7v77TiJRx+fP02gdkX4gqfpAtc3thOtamVWgPpCI+GPdnhqcU30dkWRz3MSOc64TOFTsNAQ8dqjYqZldBmBmS8xsH7G6GHeZ2eZEBbx9fyOdUcecMcMS9RQiIifkQGMbv31tDytPHcvEkUP9DkdETlKy9YG2HF48QiN2RMQfxWU1pKYYC8bn+h2KiHQRT/HkeIqdriU2PDnhtCKWiCSrX72wi9bOCJ8/b5rfoYhIL0mmPlBJuAFQYkdE/FO8u5o5Y3IYmhHXx0gR6SN9Wjy5N4TCDQxOT9W34SKSVGqa2nnwlTIunT9Wq9WISEKEwvWMysogLzvD71BEZADqiERZv7dW07BEklDgEjtbwnXMLMhWFXYRSSr3vbSLpvYIN2u0jogkSKiiXiOWRcQ3m9+up7UjqsLJIkkoUIkd55xWxBKRpFPX0sF9L5exYm4BM1XUVEQSoDMSZdv+RvWBRMQ3xWXVgAoniySjQCV2wnWt1LV0MEffVolIEnng5TIaWju5eblG64hIYuw60ER7Z1QjdkTEN8VlNYwfMZjROZl+hyIi3QQqsRPSahAikmQa2zq596VdXDA7n1PGabU+EUkMrYglIn5yzlG8u4YlmoYlkpQCmdiZpU6NiCSJB1/ZTW1zB19cPt3vUESkHwuFG0hPNaaMUnF2Eel7uw8
2c6CxjcWahiWSlAKW2GlgwoghZGl5PRFJAs3tnfzqhZ2cMyOPU8fn+h2OiPRjoXA90/KzGZQWqK6biPQTxbtrAFgySSN2RJJRoHoHoXA9czRaR0SSxEOv7eFgUztfUm0dEUmwEq2IJSI+Ki6rJiczjWl5GjUokowCk9hpbu9k18EmzS0XkaTQ2hHhrr/v5PQpIynSt1cikkDVTe3sr2/Tl1si4pvi3TUUTRpBSor5HYqIHEFgEjtbKxpwDn1bJSJJ4dG1e6lqaONL56u2jogk1uEagwVK7IhI36tpaqe0spHFE1VfRyRZBSaxEwo3AFoNQkT819YZ4b+f38GSScM5bYpG64hIYr2zKqi+3BKRvveG6uuIJL3AJHa2hOvIzkyjcPhgv0MRkQHu92/sI1zXyheXT8dMQ5JFJLFC4QbyszMYmZXhdygiMgCt3V1Neqoxv3CY36GIyFEEJrETCjcwuyBHH6JEAs7MVpjZVjMrNbPbjnA8w8we9Y6/ZmaTvP2TzKzFzNZ7t//u8pjFZvaW95ifWAIbio5IlF88t4NTx+dy9vRRiXoaEZHDQuF6jVgWEd+8UVbDvHHDyExP9TsUETmKQCR2olFHSVirQYgEnZmlAncCFwNzgKvMbE63024Aapxz04AfAt/tcmyHc26Bd/t(internal link)/AD4DTPduKxL1Gh5fV86+mhZuOX+aEs0iknAdkSillY3MUh9IRHzQ2hFh4746LRQhkuQCkdjZW9NMU3tE31aJBN9SoNQ5t9M51w48Aqzsds5K4AFv+/fA+ccagWNmY4Ac59yrzjkH/Br4UK9HDnRGovx8TSmnjMvhvJn5iXgKEZF32VnVRHskqhWxRMQXm8rraI9EVThZJMkFIrHzTtFAdWpEAm4csLfL/X3eviOe45zrBOqAkd6xyWa2zsyeN7Ozu5y/7zjXBMDMbjKzYjMrrqqq6nHwT2wMU3awmZvPU20dEekb6gOJiJ/WlsUKJxcpsSOS1AKR2NkSbiDFYGaBhiGLDGBhYIJzbiFwK/CQmfXok45z7m7nXJFzrigvL69HTx6JOn66ejszR2dz4ZzRPXqsiMiJCoXrGZSawpRRQ/0ORUQGoDd2VzNl1FAVbxdJcoFI7ITC9UzJy1LBLpHgKwfGd7lf6O074jlmlgYMAw4659qccwcBnHNvADuAGd75hce55kl7alOYHVVNfPH8aaSkaLSOiPSNLeF6po/OIi01EF02EelHolFH8e4aiiZptI5IsgtEL0GrQYj0G2uB6WY22cwGAVcCq7qdswq41tu+HFjtnHNmlucVX8bMphArkrzTORcG6s3sNK8WzzXAn3oz6GjU8bPVpUzNG8rFp4zpzUuLiBxTSUWD+kAi4oudBxqpbe6gaKIKJ4sku6RP7NS1dLCvpkUrYon0A17NnJuBp4EQ8JhzbrOZ3WFml3mn3QOMNLNSYlOuDi2Jfg6w0czWEyuq/DnnXLV37PPAr4BSYiN5nurNuJ8J7aekooGbl08jVaN1RKSPHGhso6qhTYkdEfHF4fo6GrEjkvTS/A7geEpUNFCkX3HOPQk82W3f7V22W4ErjvC4/wH+5yjXLAZO6d1ID1+bnzy7nYkjh/DB+WMT8RQiIkf0TuFkfbklIn2vuKyGkUMHMVk1vkSSXtKP2DnUqdEynyLihzVbK9n8dj1fOHeaalyIDEBmtsLMtppZqZnddoTjGWb2qHf8NTOb1FvPfTixU6A+kMhA5Gf7A7HCyYsnDtdKoCIBkPSfUkLhBkYMHUR+tiqxi0jfio3WKWVc7mA+vOiIK6iLSD/m1fW6E7gYmANcZWZzup12A1DjnJsG/BD4bm89f0m4gYKcTIYPHdRblxSRgPC7/alqaKPsYLOmYYkERPIndirqmT0mW5liEelzL5YeYP3eWj5/3lTSNVpHZCBaCpQ653Y659qBR4CV3c5ZCTzgbf8eON96qdOyJVyvaVgiA5ev7c8bu2NlDIsmqXCySBDE
9UnFz2GA587I41LVtRARHwwbnM6l88dw+eLC458sIv3ROGBvl/v7vH1HPMcrEF8HjOx+ITO7ycyKzay4qqoqridfPiufi+dpJT6RAarX2h/oeRs0MiuDlQvGcsrYYScSu4j0seMWT+4yDPD9xBqUtWa2yjm3pctph4cBmtmVxIYBfrw3Arz1wpm9cRkRkR6bX5jLz65e5HcYItIPOOfuBu4GKCoqcvE85qsrZiU0JhEZOHraBi2ZNIIlGq0jEhjxjNjxdRigiIiIiE/KgfFd7hd6+454jpmlAcOAg30SnYj0Z2p/RCRu8SR2fB2GLCIiIuKTtcB0M5tsZoOAK4FV3c5ZBVzrbV8OrHbOxTUiR0TkGNT+iEjc+rQaqHPubudckXOuKC8vry+fWkRERKRHvC+rbgaeBkLAY865zWZ2h5ld5p12DzDSzEqBW4H31CIUEekptT8i0hPHrbFDz4YB7tMwQBEREekvnHNPAk9223d7l+1W4Iq+jktE+j+1PyISr3hG7GgYoIiIiIiIiIhIErJ48i9m9gHgR0AqcK9z7t/N7A6g2Dm3yswygQeBhUA1cKVzbudxrlkF7I4zzlHAgTjP9VuQYoVgxatYEyfeeCc65/rFPMoetEH99W+ZDBRr4gQp3p7E2i/aIPWBkkaQ4lWsiaM+0NH1179lMlCsiROkeHutDxRXYsdvZlbsnCvyO454BClWCFa8ijVxghZvXwra7yZI8SrWxAlSvEGK1Q9B+v0EKVYIVryKNXGCFm9fCtrvJkjxKtbECVK8vRlrnxZPFhERERERERGR3qPEjoiIiIiIiIhIQAUlsXO33wH0QJBihWDFq1gTJ2jx9qWg/W6CFK9iTZwgxRukWP0QpN9PkGKFYMWrWBMnaPH2paD9boIUr2JNnCDF22uxBqLGjoiIiIiIiIiIvFdQRuyIiIiIiIiIiEg3SuyIiIiIiIiIiARUUid2zGyFmW01s1Izu83HOO41s0oz29Rl3wgze8bMtns/h3v7zcx+4sW80cwWdXnMtd75283s2gTFOt7M1pjZFjPbbGa3JGu8ZpZpZq+b2QYv1n/z9k82s9e8mB41s0He/gzvfql3fFKXa33d27/VzC7q7Vi7PE+qma0zsycCEGuZmb1lZuvNrNjbl3Tvg2SWDG2Q2p+Exqs2KEGxqv05ecnQ/nhxqA1SH+jQ8wSi/fGeR23QSUqGNkjtj/pA3WIORBvkW/vjnEvKG5AK7ACmAIOADcAcn2I5B1gEbOqy73vAbd72bcB3ve0PAE8BBpwGvObtHwHs9H4O97aHJyDWMcAibzsb2AbMScZ4vefM8rbTgde8GB4DrvT2/zfwD97254H/9ravBB71tud4748MYLL3vklN0HvhVuAh4AnvfjLHWgaM6rYv6d4HyXpLljZI7U9C41UblKBY1f6c9O8vKdofLxa1QeoDHYo5EO2P91xqg07u95cUbZDaH/WBusUciDbIr/anT/9x9vAXcjrwdJf7Xwe+7mM8k7o1KluBMd72GGCrt30XcFX384CrgLu67H/XeQmM+0/A+5M9XmAI8CawDDgApHV/HwBPA6d722needb9vdH1vF6OsRB4FlgOPOE9d1LG6l37SI1KUr8PkumWTG2Q2p/Ex6s2qNdjVftzcr+/pGl/vOdXG6Q+UGDaH+/aaoNO7veXNG2Q2h/1gbzrBqYN8qv9SeapWOOAvV3u7/P2JYvRzrmwt10BjPa2jxZ3n78eb9jZQmIZ2KSM1xtStx6oBJ4hljmtdc51HuF5D8fkHa8DRvZVrMCPgK8CUe/+yCSOFcAB/2tmb5jZTd6+pHwfJKlkfu1J/3cMQvvjxak2KDGxqv05Ocn+2pP+bxmENkjtj/pASSyZX3vS/x2D0P54caoN6kd9oLSTjVrAOefMzPkdR1dmlgX8D/CPzrl6Mzt8LJnidc5FgAVmlgs8DszyN6IjM7NLgUrn3Btmdq7P4cTrLOdcuZnlA8+YWUnXg8n0PpAT
l4x/x6C0P6A2KIHU/gwQyfi3DEobpPYnodQGDQDJ+HcMSvsDaoMSyJf2J5lH7JQD47vcL/T2JYv9ZjYGwPtZ6e0/Wtx99nrMLJ1Yg/Jb59wfkj1eAOdcLbCG2DC6XDM7lHTs+ryHY/KODwMO9lGsZwKXmVkZ8AixYYA/TtJYAXDOlXs/K4k11ktJ8vdBkknm1560f8cgtj+gNqiXY1X7c/KS/bUn7d8yiG2Q2p/epzbopCXza0/av2MQ2x9QG9TLsfrX/vT2nLLeuhEbTbSTWGGjQ0W75voYzyTePb/z+7y7ANL3vO1LeHcBpNe9/SOAXcSKHw33tkckIE4Dfg38qNv+pIsXyANyve3BwAvApcDveHchrM9721/g3YWwHvO25/LuQlg7SVDRLu/5zuWdol1JGSswFMjusv0ysCIZ3wfJekumNkjtT8LiVRuUgFjV/vTK7zBp2h8vHrVB6gMdijup2x/vedQGnfzvMGnaILU/6gN1izup2yA/258+/8fZw1/MB4hVFN8BfMPHOB4GwkAHsfltNxCbp/cssB3426FftPdHudOL+S2gqMt1Pg2UerfrExTrWcTm9W0E1nu3DyRjvMB8YJ0X6ybgdm//FOB173l/B2R4+zO9+6Xe8SldrvUN7zVsBS5O8Puha4OSlLF6cW3wbpsP/ftJxvdBMt+SoQ1S+5PQeNUGJSBWtT+99nv0vf3x4lAbpD5Q17iTuv3pEpfaoJP/PfreBqn9UR/oCHEndRvkZ/tj3oNERERERERERCRgkrnGjoiIiIiIiIiIHIMSOyIiIiIiIiIiAaXEjoiIiIiIiIhIQCmxIyIiIiIiIiISUErsiIiIiIiIiIgElBI7clLM7B/NbIjfcYjIwKQ2SET8ovZHRPykNki60nLnclLMrAwocs4d8DsWERl41AaJiF/U/oiIn9QGSVcasSNxM7OhZvYXM9tgZpvM7F+AscAaM1vjnXOhmb1iZm+a2e/MLMvbX2Zm3zOzt8zsdTOb5udrEZHgURskIn5R+yMiflIbJMejxI70xArgbefcqc65U4AfAW8D5znnzjOzUcA3gQucc4uAYuDWLo+vc87NA37mPVZEpCfUBomIX9T+iIif1AbJMSmxIz3xFvB+M/uumZ3tnKvrdvw0YA7wkpmtB64FJnY5/nCXn6cnOlgR6XfUBomIX9T+iIif1AbJMaX5HYAEh3Num5ktAj4AfNvMnu12igHPOOeuOtoljrItInJcaoNExC9qf0TET2qD5Hg0YkfiZmZjgWbn3G+A7wOLgAYg2zvlVeDMQ/M2vbmgM7pc4uNdfr7SN1GLSH+hNkhE/KL2R0T8pDZIjkcjdqQn5gHfN7Mo0AH8A7GhfH81s7e9+Z3XAQ+bWYb3mG8C27zt4Wa2EWgDjpZNFhE5GrVBIuIXtT8i4ie1QXJMWu5c+oRpOT4R8ZHaIBHxi9ofEfGT2qCBQVOxREREREREREQCSiN2REREREREREQCSiN2REREREREREQCSokdEREREREREZGAUmJHRERERERERCSglNgREREREREREQkoJXZERERERERERALq/wN2Tyd6lh5ojQAAAABJRU5ErkJggg==\n" + }, + "metadata": { + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Inference\n", + "\n", + "Using an example from the online Demo\n", + "\n", + "https://google-research.github.io/vision_transformer/lit/" + ], + "metadata": { + "id": "DDNZ7kkGNr8_" + } + }, + { + "cell_type": "code", + "source": [ + "!test -f apple-ipod.jpg || wget 
https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg\n", + "\n", + "labels = [\n", + " 'an apple',\n", + " 'an ipod',\n", + " 'granny smith',\n", + " 'an apple with a note saying \"ipod\"',\n", + " 'an adversarial attack',\n", + "]\n", + "\n", + "import PIL\n", + "import numpy as np\n", + "img = np.array(PIL.Image.open('apple-ipod.jpg'))\n", + "import matplotlib.pyplot as plt\n", + "plt.imshow(img)\n", + "img.shape, img.dtype" + ], + "metadata": { + "id": "4Fp2PiiYYnqp", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 286 + }, + "outputId": "9c8498b3-4330-46ef-e8d1-05864e59fccf" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "((322, 322, 3), dtype('uint8'))" + ] + }, + "metadata": {}, + "execution_count": 8 + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u003cFigure size 432x288 with 1 Axes\u003e" + ], + "image/png": "(internal link)\n" + }, + "metadata": { + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "code", + "source": [ + "!test -f LiT-B16B.npz || gsutil cp gs://vit_models/lit/LiT-B16B.* ." 
+ ], + "metadata": { + "id": "sQLpVwrFTt5P" + }, + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "RkWQEDrkmSf0", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "34b2cd6d-7c02-457b-a127-30f7e6676b35" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u003cIPython.core.display.Javascript object\u003e" + ], + "application/javascript": [ + "\n", + " ((filepath) =\u003e {{\n", + " if (!google.colab.kernel.accessAllowed) {{\n", + " return;\n", + " }}\n", + " google.colab.files.view(filepath);\n", + " }})(\"/content/big_vision/big_vision/configs/proj/image_text/lit_coco.py\")" + ] + }, + "metadata": {} + } + ], + "source": [ + "files.view('big_vision/big_vision/configs/proj/image_text/siglip_lit_coco.py')\n", + "from big_vision.configs.proj.image_text import siglip_lit_coco as lit_coco\n", + "arg = 'txt=bert_base,img=B/16,img_head,init=LiT-B16B.npz'\n", + "config = lit_coco.get_config(arg)" + ] + }, + { + "cell_type": "code", + "source": [ + "# Initialize template params...\n", + "import importlib\n", + "import jax.numpy as jnp\n", + "\n", + "model_mod = importlib.import_module(f'big_vision.models.{config.model_name}')\n", + "\n", + "model = model_mod.Model(**config.model)\n", + "\n", + "init_params = [\n", + " jnp.zeros(shape, dtype)\n", + " for shape, dtype in zip(config.init_shapes, config.init_types)\n", + "]\n", + "\n", + "params0 = model.init(jax.random.PRNGKey(42), *init_params)['params'].unfreeze()" + ], + "metadata": { + "id": "2Frpb9t8NsnH" + }, + "execution_count": 12, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# ... 
and load/modify pre-trained params.\n", + "from big_vision import utils\n", + "# Note that `.load()` is responsible for parameter tree surgery to adapt old\n", + "# checkpoints to most recent source code.\n", + "params = model_mod.load(params0, 'LiT-B16B.npz', config.model)" + ], + "metadata": { + "id": "gyZMqL5C9usk", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "61b595ff-fccf-4367-b565-d7bcc74fbea6" + }, + "execution_count": 13, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:absl:ViT: Loading and fixing VERY old posemb\n", + "INFO:absl:ViT: Loading and fixing combined cls+posemb\n", + "/content/./big_vision/big_vision/utils.py:593: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.\n", + " vals, tree_def = jax.tree_flatten(tree)\n", + "INFO:absl:\n", + "INFO:absl:\n", + "INFO:absl:Could not find original BERT checkpoint path 'LiT-B16B.npz:txt/bert_model.ckpt', loading big_vision checkpoint 'LiT-B16B.npz:txt'\n", + "INFO:absl:\n", + "INFO:absl:\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# The preprocessing is optimized for efficiently streaming through TFDS\n", + "# datasets - below code runs it separately on the image and every text.\n", + "set_max_height(222)\n", + "\n", + "from big_vision.pp import builder as pp_builder\n", + "for pp_mod in config.pp_modules:\n", + " importlib.import_module(f'big_vision.pp.{pp_mod}')\n", + "\n", + "pp_str = config.evals.val.pp_fn.replace('decode|', '')\n", + "imgs = np.array(pp_builder.get_preprocess_fn(pp_str)({\n", + " 'image': img[None],\n", + " 'captions/text': np.array(['']),\n", + "})['image'])\n", + "txts = np.stack([\n", + " pp_builder.get_preprocess_fn(pp_str, log_data=False)({\n", + " 'image': img[None],\n", + " 'captions/text': np.array([label]),\n", + " })['labels']\n", + " for label in labels\n", + "])\n", + "imgs.shape, txts.shape" + ], + "metadata": { 
+ "id": "EWq4la1QZqZh", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 222 + }, + "outputId": "cfe73bae-b359-4470-bf25-f8f0eb9254de" + }, + "execution_count": 14, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u003cIPython.core.display.Javascript object\u003e" + ], + "application/javascript": [ + "\n", + " google.colab.output.setIframeHeight(0, true, {maxHeight: 222})\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:absl:Data before pre-processing:\n", + "{'image': array([[[[123, 94, 52],\n", + " [123, 96, 53],\n", + " [139, 114, 74],\n", + " ...,\n", + " [ 64, 54, 19],\n", + " [ 60, 54, 22],\n", + " [ 54, 48, 16]],\n", + "\n", + " [[149, 125, 81],\n", + " [149, 126, 84],\n", + " [152, 130, 89],\n", + " ...,\n", + " [ 69, 55, 26],\n", + " [ 61, 54, 28],\n", + " [ 48, 41, 15]],\n", + "\n", + " [[152, 133, 91],\n", + " [155, 138, 95],\n", + " [153, 135, 95],\n", + " ...,\n", + " [ 75, 59, 33],\n", + " [ 68, 60, 39],\n", + " [ 45, 37, 16]],\n", + "\n", + " ...,\n", + "\n", + " [[112, 101, 73],\n", + " [115, 103, 77],\n", + " [177, 165, 141],\n", + " ...,\n", + " [ 15, 9, 9],\n", + " [ 1, 0, 0],\n", + " [ 1, 0, 0]],\n", + "\n", + " [[113, 98, 79],\n", + " [160, 145, 126],\n", + " [191, 175, 159],\n", + " ...,\n", + " [ 3, 4, 6],\n", + " [ 0, 0, 0],\n", + " [ 0, 0, 0]],\n", + "\n", + " [[113, 98, 79],\n", + " [160, 145, 126],\n", + " [191, 175, 159],\n", + " ...,\n", + " [ 3, 4, 6],\n", + " [ 0, 0, 0],\n", + " [ 0, 0, 0]]]], dtype=uint8), 'captions/text': array([''], dtype='\u003cU1')}\n", + "INFO:absl:Data after pre-processing:\n", + "{'image': \u003ctf.Tensor: shape=(1, 224, 224, 3), dtype=float32, numpy=\n", + "array([[[[ 0.00392163, -0.20784312, -0.54509807],\n", + " [ 0.07450986, -0.12156862, -0.44313723],\n", + " [ 0.0196079 , -0.14509803, -0.44313723],\n", + " ...,\n", + " [-0.5137255 , -0.5921569 , -0.8352941 ],\n", + " [-0.5058824 , 
-0.5764706 , -0.8352941 ],\n", + " [-0.5764706 , -0.62352943, -0.8666667 ]],\n", + "\n", + " [[ 0.18431377, 0.02745104, -0.3098039 ],\n", + " [ 0.19215691, 0.04313731, -0.27843136],\n", + " [-0.01960784, -0.1372549 , -0.44313723],\n", + " ...,\n", + " [-0.41176468, -0.5294118 , -0.7411765 ],\n", + " [-0.45098037, -0.5529412 , -0.7490196 ],\n", + " [-0.60784316, -0.67058825, -0.8509804 ]],\n", + "\n", + " [[ 0.00392163, -0.12156862, -0.41960782],\n", + " [-0.01176471, -0.1372549 , -0.42745095],\n", + " [-0.16862744, -0.27843136, -0.5686275 ],\n", + " ...,\n", + " [-0.3098039 , -0.44313723, -0.64705884],\n", + " [-0.35686272, -0.49019605, -0.6862745 ],\n", + " [-0.5372549 , -0.654902 , -0.8039216 ]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.09803921, -0.19215685, -0.3960784 ],\n", + " [ 0.0196079 , -0.0745098 , -0.27058822],\n", + " [ 0.5529412 , 0.45098042, 0.28627455],\n", + " ...,\n", + " [-0.84313726, -0.88235295, -0.90588236],\n", + " [-0.94509804, -0.9607843 , -0.9764706 ],\n", + " [-0.99215686, -1. , -1. ]],\n", + "\n", + " [[-0.09019607, -0.19215685, -0.38039213],\n", + " [ 0.28627455, 0.17647064, 0.00392163],\n", + " [ 0.52156866, 0.41960788, 0.26274514],\n", + " ...,\n", + " [-0.8745098 , -0.92941177, -0.9372549 ],\n", + " [-0.94509804, -0.96862745, -0.9607843 ],\n", + " [-1. , -1. , -1. ]],\n", + "\n", + " [[-0.03529412, -0.15294117, -0.30196077],\n", + " [ 0.41176474, 0.28627455, 0.15294123],\n", + " [ 0.52156866, 0.39607847, 0.27058828],\n", + " ...,\n", + " [-0.9607843 , -0.9843137 , -0.9764706 ],\n", + " [-0.99215686, -0.9843137 , -0.9764706 ],\n", + " [-1. , -1. , -1. 
]]]], dtype=float32)\u003e, 'labels': \u003ctf.Tensor: shape=(16,), dtype=int32, numpy=\n", + "array([101, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0], dtype=int32)\u003e}\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "((1, 224, 224, 3), (5, 16))" + ] + }, + "metadata": {}, + "execution_count": 14 + } + ] + }, + { + "cell_type": "code", + "source": [ + "%debug" + ], + "metadata": { + "id": "2KQGmbebjTmt", + "outputId": "c88fbf9c-550a-464e-d8d7-c066b94671d2", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": 15, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "ERROR:root:No traceback has been produced, nothing to debug.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "zimg, _, _ = model.apply({'params': params}, imgs, None)\n", + "_, ztxt, _ = model.apply({'params': params}, None, txts)" + ], + "metadata": { + "id": "naL1Ul1ablZ5" + }, + "execution_count": 16, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "probs = jax.nn.softmax((zimg[0] @ ztxt.T * np.exp(params['t'])))\n", + "list(zip(labels, probs.tolist()))" + ], + "metadata": { + "id": "TIdAVw9VGEAw", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "15d3f416-438e-4529-c43e-0b29702320fd" + }, + "execution_count": 17, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[('an apple', 0.0028924213256686926),\n", + " ('an ipod', 1.852169361882261e-06),\n", + " ('granny smith', 1.0219791874988005e-05),\n", + " ('an apple with a note saying \"ipod\"', 0.9970954656600952),\n", + " ('an adversarial attack', 4.8630912630187595e-08)]" + ] + }, + "metadata": {}, + "execution_count": 17 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Run evaluation\n", + "\n", + "The code below runs a minimal version of the `big_vision.tools.eval_only` script."
+ ], + "metadata": { + "id": "gFsqgYov9hAx" + } + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "juyZDho1n51-", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "35023da6-b321-4657-9b46-bc46a3f39593" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u003cIPython.core.display.Javascript object\u003e" + ], + "application/javascript": [ + "\n", + " ((filepath) =\u003e {{\n", + " if (!google.colab.kernel.accessAllowed) {{\n", + " return;\n", + " }}\n", + " google.colab.files.view(filepath);\n", + " }})(\"/content/big_vision/big_vision/tools/eval_only.py\")" + ] + }, + "metadata": {} + } + ], + "source": [ + "files.view('big_vision/big_vision/tools/eval_only.py')\n", + "from big_vision.tools import eval_only" + ] + }, + { + "cell_type": "code", + "source": [ + "!test -f LiT-B16B.npz || gsutil cp gs://vit_models/lit/LiT-B16B.* ." + ], + "metadata": { + "id": "9wcyh0wJJb5z" + }, + "execution_count": 19, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "ZLKRKtQlEbH8" + }, + "outputs": [], + "source": [ + "from big_vision.configs.proj.image_text import lit_coco\n", + "arg = 'txt=bert_base,img=B/16,img_head,init=LiT-B16B.npz'\n", + "config = lit_coco.get_config(arg)" + ] + }, + { + "cell_type": "code", + "source": [ + "# From all the pre-defined evaluators...\n", + "set_max_height(222)\n", + "config.evals" + ], + "metadata": { + "id": "P7GJX-ekFaJM", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 222 + }, + "outputId": "12768e66-3d59-462e-c92c-7a2b31a5dcde" + }, + "execution_count": 21, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u003cIPython.core.display.Javascript object\u003e" + ], + "application/javascript": [ + "\n", + " google.colab.output.setIframeHeight(0, true, {maxHeight: 222})\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": 
"execute_result", + "data": { + "text/plain": [ + "coco:\n", + " data:\n", + " name: coco_captions\n", + " split: val\n", + " log_steps: 500\n", + " pp_fn: decode|resize(224)|value_range(-1,1)|flatten|bert_tokenize(inkey=\"captions/text\",\n", + " max_len=16, vocab_path=\"LiT-B16B.txt\")|keep(\"image\", \"labels\")\n", + " type: proj.image_text.contrastive\n", + " use_global_batch: true\n", + "disclf:\n", + " log_steps: 500\n", + " pp_img: resize(224)|value_range(-1,1)\n", + " pp_txt: bert_tokenize(inkey=\"texts\", max_len=16, vocab_path=\"LiT-B16B.txt\")\n", + " prefix: z/0shot/\n", + " type: proj.image_text.discriminative_classifier\n", + "imagenet:\n", + " data:\n", + " name: imagenet2012\n", + " split: validation\n", + " log_steps: 500\n", + " pp_fn: decode|resize(224)|value_range(-1,1)|clip_i1k_label_names|bert_tokenize(inkey=\"labels\",\n", + " max_len=16, vocab_path=\"LiT-B16B.txt\")|keep(\"image\", \"labels\")\n", + " type: proj.image_text.contrastive\n", + " use_global_batch: true\n", + "retrieval_coco:\n", + " dataset: coco_captions\n", + " log_steps: 500\n", + " pp_img: resize(224)|value_range(-1, 1)\n", + " pp_txt: bert_tokenize(inkey=\"texts\", max_len=16, vocab_path=\"LiT-B16B.txt\")\n", + " prefix: z/retr/coco_\n", + " txt_name: !!python/tuple\n", + " - captions\n", + " - text\n", + " type: proj.image_text.retrieval\n", + "val:\n", + " data:\n", + " name: coco_captions\n", + " split: val\n", + " log_steps: 500\n", + " pp_fn: decode|resize(224)|value_range(-1,1)|flatten|bert_tokenize(inkey=\"captions/text\",\n", + " max_len=16, vocab_path=\"LiT-B16B.txt\")|keep(\"image\", \"labels\")\n", + " type: proj.image_text.contrastive\n", + " use_global_batch: true" + ] + }, + "metadata": {}, + "execution_count": 21 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# ... 
run a single zeroshot discriminative classifier on pets.\n", + "config.evals = {\n", + " 'disclf': {\n", + " **config.evals.disclf,\n", + " 'dataset_names': ['oxford_iiit_pet'],\n", + " 'dataset_overrides': (),\n", + " },\n", + "}\n", + "config.evals" + ], + "metadata": { + "id": "aqqr9PB0Fh8w", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "33ecdf97-74fb-4b1c-8b5c-1b16e6287f27" + }, + "execution_count": 22, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "disclf:\n", + " dataset_names:\n", + " - oxford_iiit_pet\n", + " dataset_overrides: !!python/tuple []\n", + " log_steps: 500\n", + " pp_img: resize(224)|value_range(-1,1)\n", + " pp_txt: bert_tokenize(inkey=\"texts\", max_len=16, vocab_path=\"LiT-B16B.txt\")\n", + " prefix: z/0shot/\n", + " type: proj.image_text.discriminative_classifier" + ] + }, + "metadata": {}, + "execution_count": 22 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Prepare pets dataset.\n", + "tfds.builder('oxford_iiit_pet').download_and_prepare()" + ], + "metadata": { + "id": "AZMFa5NjAwu-", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "e183d317-5cf5-400f-af10-dd2ee30ee4a7" + }, + "execution_count": 23, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:absl:Load dataset info from /root/tensorflow_datasets/oxford_iiit_pet/3.2.0\n", + "INFO:absl:Reusing dataset oxford_iiit_pet (/root/tensorflow_datasets/oxford_iiit_pet/3.2.0)\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "id": "YWipdlfQn7Zg", + "outputId": "2d4c760f-85d8-4cab-8bea-d4816709f975", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "mkdir: cannot create directory ‘lit_coco_B16B_eval’: File exists\n" + ] + } + ], + "source": [ + "workdir = 'lit_coco_B16B_eval'\n", + "!mkdir $workdir\n", + "flags.FLAGS.workdir = 
workdir\n", + "config.input.batch_size = 512\n", + "flags.FLAGS.config = config" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "id": "JRKXQ9KyoCVi", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 444 + }, + "outputId": "dcd18c18-20ec-4108-8139-e68218b45164" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u003cIPython.core.display.Javascript object\u003e" + ], + "application/javascript": [ + "\n", + " google.colab.output.setIframeHeight(0, true, {maxHeight: 444})\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:absl:Workdir: lit_coco_B16B_eval\n", + "INFO:absl:NOTE: Initializing proj.image_text.two_towers model...\n", + "INFO:absl:\u001b[35m[0]\u001b[0m z/secs/init = 53.88540307900007\n", + "INFO:absl:TIMING[z/secs/init]: 53.88540307900007\n", + "INFO:absl:\n", + "init params\n", + "+-------------------------------------------------------------------------------------------------------+------------------+------------+-----------+----------+\n", + "| Name | Shape | Size | Mean | Std |\n", + "+-------------------------------------------------------------------------------------------------------+------------------+------------+-----------+----------+\n", + "| img/Transformer/encoder_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoder_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_0/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_0/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_0/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_0/LayerNorm_1/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_0/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | 6.23e-09 | 1.03e-06 |\n", + "| img/Transformer/encoderblock_0/MlpBlock_0/Dense_0/kernel | (768, 
3072) | 2,359,296 | -1.19e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_0/MlpBlock_0/Dense_1/bias | (768,) | 768 | -4.47e-08 | 1e-06 |\n", + "| img/Transformer/encoderblock_0/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | -3.51e-07 | 0.0228 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 9.09e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | -7.25e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 8.03e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -6.62e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_1/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_1/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_1/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_1/LayerNorm_1/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_1/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | 1.16e-08 | 1.01e-06 |\n", + "| img/Transformer/encoderblock_1/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | -8.66e-06 | 0.0228 |\n", + "| img/Transformer/encoderblock_1/MlpBlock_0/Dense_1/bias | (768,) | 768 | 3.66e-08 | 1e-06 |\n", + "| img/Transformer/encoderblock_1/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | -6.84e-06 
| 0.0228 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 4.01e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 8.36e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | -1.96e-06 | 0.036 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -5.22e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_10/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_10/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_10/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_10/LayerNorm_1/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_10/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | 2.45e-08 | 1.02e-06 |\n", + "| img/Transformer/encoderblock_10/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 1.55e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_10/MlpBlock_0/Dense_1/bias | (768,) | 768 | -1.63e-08 | 9.76e-07 |\n", + "| img/Transformer/encoderblock_10/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 7.67e-07 | 0.0228 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 
3.51e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | -6.83e-07 | 0.0361 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | -5.43e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -1.21e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_11/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_11/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_11/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_11/LayerNorm_1/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_11/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | 2.03e-08 | 9.84e-07 |\n", + "| img/Transformer/encoderblock_11/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 1.7e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_11/MlpBlock_0/Dense_1/bias | (768,) | 768 | -1.2e-08 | 1.01e-06 |\n", + "| img/Transformer/encoderblock_11/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 1.59e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 7.33e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) 
| 589,824 | 1.83e-05 | 0.036 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 1.72e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | 1.88e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_2/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_2/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_2/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_2/LayerNorm_1/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_2/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -1.56e-08 | 1e-06 |\n", + "| img/Transformer/encoderblock_2/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | -8.46e-06 | 0.0228 |\n", + "| img/Transformer/encoderblock_2/MlpBlock_0/Dense_1/bias | (768,) | 768 | -6.15e-08 | 9.86e-07 |\n", + "| img/Transformer/encoderblock_2/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | -8.32e-06 | 0.0228 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 7.9e-06 | 0.0361 |\n", + "INFO:absl:\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 1.59e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| 
img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | -2.98e-06 | 0.0361 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -3.39e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_3/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_3/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_3/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_3/LayerNorm_1/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_3/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -5.33e-09 | 1.01e-06 |\n", + "| img/Transformer/encoderblock_3/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 7.81e-06 | 0.0228 |\n", + "| img/Transformer/encoderblock_3/MlpBlock_0/Dense_1/bias | (768,) | 768 | -1.28e-08 | 9.99e-07 |\n", + "| img/Transformer/encoderblock_3/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 1.42e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 1.1e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | -0.000122 | 0.0361 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | -4.36e-06 | 0.0361 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + 
"| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | 2.49e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_4/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_4/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_4/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_4/LayerNorm_1/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_4/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | 4.7e-09 | 9.95e-07 |\n", + "| img/Transformer/encoderblock_4/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 1.16e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_4/MlpBlock_0/Dense_1/bias | (768,) | 768 | -3.34e-08 | 9.85e-07 |\n", + "| img/Transformer/encoderblock_4/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 1.28e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | -6.55e-06 | 0.0361 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 4.19e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 6.5e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | 2.84e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_5/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| 
img/Transformer/encoderblock_5/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_5/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_5/LayerNorm_1/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_5/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -5.96e-09 | 1.01e-06 |\n", + "| img/Transformer/encoderblock_5/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 3.21e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_5/MlpBlock_0/Dense_1/bias | (768,) | 768 | 3.41e-08 | 9.66e-07 |\n", + "| img/Transformer/encoderblock_5/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | -8.61e-06 | 0.0228 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 6.99e-06 | 0.0361 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | -4.31e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | -3.99e-06 | 0.0361 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -4.3e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_6/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_6/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_6/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_6/LayerNorm_1/scale | (768,) | 768 | 1.0 
| 0.0 |\n", + "| img/Transformer/encoderblock_6/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | 1.27e-08 | 1.01e-06 |\n", + "| img/Transformer/encoderblock_6/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | -1.43e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_6/MlpBlock_0/Dense_1/bias | (768,) | 768 | -4.46e-08 | 1.01e-06 |\n", + "| img/Transformer/encoderblock_6/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 3.42e-06 | 0.0228 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 1.97e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 2.28e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | -2.81e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | 2.17e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_7/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_7/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_7/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_7/LayerNorm_1/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_7/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -1.34e-08 | 9.87e-07 |\n", + "| img/Transformer/encoderblock_7/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 3.61e-07 | 0.0228 |\n", + "| 
img/Transformer/encoderblock_7/MlpBlock_0/Dense_1/bias | (768,) | 768 | -7.23e-08 | 1e-06 |\n", + "| img/Transformer/encoderblock_7/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 7.83e-06 | 0.0228 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | -3.07e-05 | 0.0361 |\n", + "INFO:absl:\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 5.08e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | -3.23e-05 | 0.036 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -2.02e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_8/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_8/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_8/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_8/LayerNorm_1/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_8/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | 3.85e-10 | 1e-06 |\n", + "| img/Transformer/encoderblock_8/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 1.22e-06 | 0.0228 |\n", + "| img/Transformer/encoderblock_8/MlpBlock_0/Dense_1/bias | (768,) | 768 | 3.61e-08 | 1.01e-06 |\n", + "| img/Transformer/encoderblock_8/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 4.09e-07 | 0.0228 |\n", + "| 
img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 1.49e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | -5.84e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 4.68e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | 4.5e-05 | 0.036 |\n", + "| img/Transformer/encoderblock_9/LayerNorm_0/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_9/LayerNorm_0/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_9/LayerNorm_1/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_9/LayerNorm_1/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| img/Transformer/encoderblock_9/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -1.35e-08 | 9.85e-07 |\n", + "| img/Transformer/encoderblock_9/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | -1.19e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_9/MlpBlock_0/Dense_1/bias | (768,) | 768 | 1.02e-08 | 9.43e-07 |\n", + "| img/Transformer/encoderblock_9/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 2.57e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | -0.000109 | 0.0361 |\n", + "| 
img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 6.98e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 6.36e-05 | 0.0361 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0 | 0.0 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -1.12e-05 | 0.0361 |\n", + "| img/cls | (1, 1, 768) | 768 | 0.0 | 0.0 |\n", + "| img/embedding/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/embedding/kernel | (16, 16, 3, 768) | 589,824 | 4.59e-06 | 0.0361 |\n", + "| img/head/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| img/head/kernel | (768, 768) | 589,824 | -3.08e-05 | 0.0361 |\n", + "| img/pos_embedding | (1, 196, 768) | 150,528 | 4.29e-05 | 0.0361 |\n", + "| t | (1,) | 1 | 2.3 | 0.0 |\n", + "| txt/BertEncoder_0/embedder/embedders_position_ids/embedding | (512, 768) | 393,216 | 2.24e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/embedder/embedders_segment_ids/embedding | (2, 768) | 1,536 | -0.000444 | 0.0177 |\n", + "| txt/BertEncoder_0/embedder/embedders_token_ids/embedding | (30522, 768) | 23,440,896 | 8.75e-07 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 9.06e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 5.48e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | 3.62e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -2.55e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -5.72e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | -5.07e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/key/kernel | (768, 768) 
| 589,824 | -3.57e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 4.81e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | 2.48e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -3.36e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -3.38e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 1.24e-05 | 0.0176 |\n", + "INFO:absl:\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 2.89e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | -5.85e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -2.57e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -1.51e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | 1.63e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 4.42e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 3.27e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | -2.42e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | 3.01e-07 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -1.24e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | 1.53e-06 | 0.0176 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 1.69e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | -2.72e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | -8.02e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -9.89e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -8.38e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -4.68e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 1.8e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 2.57e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 4.71e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -8.88e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/dense_layer/kernel | (768, 768) | 589,824 | 1.81e-05 | 0.0176 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | 1.51e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 1.5e-07 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | -6.07e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 2.15e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -2.21e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/dense_layer/bias | 
(768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -6.73e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -5.38e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | -1.15e-05 | 0.0176 |\n", + "INFO:absl:\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 3.9e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 2.32e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -5.67e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -2.08e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -1.6e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 3.79e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 4.87e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/query/kernel | (768, 768) 
| 589,824 | -2.25e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | 1.32e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -3.61e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -4.74e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 1.82e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 5.25e-06 | 0.0176 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | -1.27e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -9.13e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -1.41e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | 6.44e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 1.11e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/key/bias | (768,) | 
768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 1.21e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 1.18e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | 3.39e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -4.93e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | 1.02e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 4.39e-08 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | -2.86e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 1.33e-06 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/value/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -1.02e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/dense_layer/kernel | (768, 768) | 589,824 | 1.6e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/dense_layer/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | 2.31e-05 | 0.0176 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 5.85e-06 | 0.0176 |\n", + "INFO:absl:\n", + "| txt/BertEncoder_0/layer_norm/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/BertEncoder_0/layer_norm/scale | (768,) | 768 | 1.0 | 0.0 |\n", + "| txt/head/bias | (768,) | 768 | 0.0 | 0.0 |\n", + "| txt/head/kernel | (768, 768) | 589,824 | 5.69e-05 | 0.0361 |\n", + "+-------------------------------------------------------------------------------------------------------+------------------+------------+-----------+----------+\n", + "Total: 195,870,721\n", + "/content/./big_vision/big_vision/tools/eval_only.py:93: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. 
Use jax.tree_util.tree_leaves instead.\n", + " num_params = sum(p.size for p in jax.tree_leaves(params_cpu))\n", + "INFO:absl:\u001b[35m[0]\u001b[0m num_params = 195870721.0\n", + "INFO:absl:NOTE: Initialize model from LiT-B16B.npz...\n", + "INFO:absl:ViT: Loading and fixing VERY old posemb\n", + "INFO:absl:ViT: Loading and fixing combined cls+posemb\n", + "INFO:absl:\n", + "INFO:absl:\n", + "INFO:absl:Could not find original BERT checkpoint path 'LiT-B16B.npz:txt/bert_model.ckpt', loading big_vision checkpoint 'LiT-B16B.npz:txt'\n", + "INFO:absl:\n", + "INFO:absl:\n", + "INFO:absl:\n", + "loaded params\n", + "+-------------------------------------------------------------------------------------------------------+------------------+------------+-----------+--------+\n", + "| Name | Shape | Size | Mean | Std |\n", + "+-------------------------------------------------------------------------------------------------------+------------------+------------+-----------+--------+\n", + "| img/Transformer/encoder_norm/bias | (768,) | 768 | 0.0156 | 0.669 |\n", + "| img/Transformer/encoder_norm/scale | (768,) | 768 | 2.83 | 0.42 |\n", + "| img/Transformer/encoderblock_0/LayerNorm_0/bias | (768,) | 768 | 0.00364 | 0.216 |\n", + "| img/Transformer/encoderblock_0/LayerNorm_0/scale | (768,) | 768 | 0.479 | 0.652 |\n", + "| img/Transformer/encoderblock_0/LayerNorm_1/bias | (768,) | 768 | 0.0148 | 0.243 |\n", + "| img/Transformer/encoderblock_0/LayerNorm_1/scale | (768,) | 768 | 1.49 | 2.31 |\n", + "| img/Transformer/encoderblock_0/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -2.17 | 1.25 |\n", + "| img/Transformer/encoderblock_0/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | -4.87e-05 | 0.0203 |\n", + "| img/Transformer/encoderblock_0/MlpBlock_0/Dense_1/bias | (768,) | 768 | -0.0293 | 0.857 |\n", + "| img/Transformer/encoderblock_0/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | -4.37e-05 | 0.0204 |\n", + "| 
img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0052 | 0.517 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 3.47e-05 | 0.0305 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.00287 | 0.316 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 1.99e-05 | 0.0189 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | -0.0976 | 1.21 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 6.25e-05 | 0.03 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | -0.00342 | 0.0574 |\n", + "| img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | 1.58e-05 | 0.0161 |\n", + "| img/Transformer/encoderblock_1/LayerNorm_0/bias | (768,) | 768 | 0.00114 | 0.201 |\n", + "| img/Transformer/encoderblock_1/LayerNorm_0/scale | (768,) | 768 | 0.577 | 0.342 |\n", + "| img/Transformer/encoderblock_1/LayerNorm_1/bias | (768,) | 768 | -0.00644 | 0.252 |\n", + "| img/Transformer/encoderblock_1/LayerNorm_1/scale | (768,) | 768 | 0.93 | 0.62 |\n", + "| img/Transformer/encoderblock_1/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -1.94 | 1.13 |\n", + "| img/Transformer/encoderblock_1/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 1.62e-05 | 0.0219 |\n", + "| img/Transformer/encoderblock_1/MlpBlock_0/Dense_1/bias | (768,) | 768 | -0.0188 | 0.45 |\n", + "| img/Transformer/encoderblock_1/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | -4.68e-05 | 0.0195 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.00409 | 0.577 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 
| -5.47e-05 | 0.03 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | -0.0102 | 0.522 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | -8.04e-06 | 0.0188 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | -0.00333 | 0.937 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | -3.2e-05 | 0.029 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.000525 | 0.0589 |\n", + "| img/Transformer/encoderblock_1/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -3.73e-06 | 0.0187 |\n", + "| img/Transformer/encoderblock_10/LayerNorm_0/bias | (768,) | 768 | 0.026 | 0.372 |\n", + "| img/Transformer/encoderblock_10/LayerNorm_0/scale | (768,) | 768 | 1.52 | 0.266 |\n", + "| img/Transformer/encoderblock_10/LayerNorm_1/bias | (768,) | 768 | 0.341 | 3.69 |\n", + "| img/Transformer/encoderblock_10/LayerNorm_1/scale | (768,) | 768 | 12.9 | 1.44 |\n", + "| img/Transformer/encoderblock_10/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -2.35 | 0.379 |\n", + "| img/Transformer/encoderblock_10/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | -0.000734 | 0.0239 |\n", + "| img/Transformer/encoderblock_10/MlpBlock_0/Dense_1/bias | (768,) | 768 | -0.00256 | 0.823 |\n", + "| img/Transformer/encoderblock_10/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 0.000145 | 0.029 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0212 | 0.483 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | -3.03e-05 | 0.0234 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.00581 | 0.446 |\n", + "| 
img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 5.76e-07 | 0.0239 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | -0.00684 | 0.722 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 2.64e-05 | 0.023 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.00383 | 0.112 |\n", + "| img/Transformer/encoderblock_10/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -1.78e-05 | 0.0251 |\n", + "| img/Transformer/encoderblock_11/LayerNorm_0/bias | (768,) | 768 | 0.0345 | 0.563 |\n", + "| img/Transformer/encoderblock_11/LayerNorm_0/scale | (768,) | 768 | 1.56 | 0.249 |\n", + "| img/Transformer/encoderblock_11/LayerNorm_1/bias | (768,) | 768 | 0.00785 | 0.773 |\n", + "| img/Transformer/encoderblock_11/LayerNorm_1/scale | (768,) | 768 | 1.83 | 0.17 |\n", + "| img/Transformer/encoderblock_11/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -1.99 | 0.585 |\n", + "| img/Transformer/encoderblock_11/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | -0.000184 | 0.0239 |\n", + "| img/Transformer/encoderblock_11/MlpBlock_0/Dense_1/bias | (768,) | 768 | 0.0356 | 0.813 |\n", + "| img/Transformer/encoderblock_11/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 1.74e-05 | 0.0252 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | -0.0142 | 0.464 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 8.66e-06 | 0.0241 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | 0.0271 | 0.974 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | -1.33e-06 | 0.0251 |\n", + "| 
img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | -0.0283 | 0.608 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 1.1e-05 | 0.023 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0065 | 0.186 |\n", + "| img/Transformer/encoderblock_11/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | 3.94e-05 | 0.0273 |\n", + "| img/Transformer/encoderblock_2/LayerNorm_0/bias | (768,) | 768 | -0.00249 | 0.194 |\n", + "| img/Transformer/encoderblock_2/LayerNorm_0/scale | (768,) | 768 | 0.721 | 0.304 |\n", + "| img/Transformer/encoderblock_2/LayerNorm_1/bias | (768,) | 768 | 0.0129 | 0.228 |\n", + "| img/Transformer/encoderblock_2/LayerNorm_1/scale | (768,) | 768 | 1.14 | 0.448 |\n", + "| img/Transformer/encoderblock_2/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -1.81 | 0.824 |\n", + "| img/Transformer/encoderblock_2/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 5.89e-05 | 0.022 |\n", + "| img/Transformer/encoderblock_2/MlpBlock_0/Dense_1/bias | (768,) | 768 | -0.00659 | 0.28 |\n", + "| img/Transformer/encoderblock_2/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 2.74e-06 | 0.0187 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.00162 | 0.536 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 2.21e-05 | 0.0274 |\n", + "INFO:absl:\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | -0.00715 | 0.291 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 1.15e-05 | 0.0187 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | -0.00435 | 0.808 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/query/kernel 
| (768, 12, 64) | 589,824 | 2.73e-05 | 0.0278 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | -0.00516 | 0.0618 |\n", + "| img/Transformer/encoderblock_2/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | 6.66e-06 | 0.0191 |\n", + "| img/Transformer/encoderblock_3/LayerNorm_0/bias | (768,) | 768 | -0.000398 | 0.224 |\n", + "| img/Transformer/encoderblock_3/LayerNorm_0/scale | (768,) | 768 | 0.859 | 0.275 |\n", + "| img/Transformer/encoderblock_3/LayerNorm_1/bias | (768,) | 768 | 0.0115 | 0.263 |\n", + "| img/Transformer/encoderblock_3/LayerNorm_1/scale | (768,) | 768 | 1.12 | 0.257 |\n", + "| img/Transformer/encoderblock_3/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -1.81 | 0.753 |\n", + "| img/Transformer/encoderblock_3/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 7.85e-05 | 0.0216 |\n", + "| img/Transformer/encoderblock_3/MlpBlock_0/Dense_1/bias | (768,) | 768 | -0.00283 | 0.194 |\n", + "| img/Transformer/encoderblock_3/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 5.27e-06 | 0.0188 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | -0.000427 | 0.543 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 5.47e-06 | 0.0257 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | -0.0088 | 0.165 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 7.88e-07 | 0.019 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | -0.0371 | 0.781 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 5.56e-06 | 0.0261 |\n", + "| img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0039 | 0.0784 |\n", + "| 
img/Transformer/encoderblock_3/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -3.02e-05 | 0.0198 |\n", + "| img/Transformer/encoderblock_4/LayerNorm_0/bias | (768,) | 768 | 0.000815 | 0.21 |\n", + "| img/Transformer/encoderblock_4/LayerNorm_0/scale | (768,) | 768 | 0.91 | 0.279 |\n", + "| img/Transformer/encoderblock_4/LayerNorm_1/bias | (768,) | 768 | 0.0128 | 0.256 |\n", + "| img/Transformer/encoderblock_4/LayerNorm_1/scale | (768,) | 768 | 1.14 | 0.209 |\n", + "| img/Transformer/encoderblock_4/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -1.7 | 0.642 |\n", + "| img/Transformer/encoderblock_4/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 0.000135 | 0.0214 |\n", + "| img/Transformer/encoderblock_4/MlpBlock_0/Dense_1/bias | (768,) | 768 | -0.00268 | 0.149 |\n", + "| img/Transformer/encoderblock_4/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | -9.72e-07 | 0.0189 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0218 | 0.562 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 2e-05 | 0.025 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | -0.0103 | 0.154 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 5.23e-06 | 0.019 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0216 | 0.725 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 1.28e-05 | 0.0253 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | -0.000919 | 0.0754 |\n", + "| img/Transformer/encoderblock_4/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -1.21e-05 | 0.0195 |\n", + "| img/Transformer/encoderblock_5/LayerNorm_0/bias | (768,) | 768 | 0.004 | 0.199 |\n", 
+ "| img/Transformer/encoderblock_5/LayerNorm_0/scale | (768,) | 768 | 0.937 | 0.325 |\n", + "| img/Transformer/encoderblock_5/LayerNorm_1/bias | (768,) | 768 | 0.00563 | 0.288 |\n", + "| img/Transformer/encoderblock_5/LayerNorm_1/scale | (768,) | 768 | 1.22 | 0.231 |\n", + "| img/Transformer/encoderblock_5/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -1.78 | 0.594 |\n", + "| img/Transformer/encoderblock_5/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 2.95e-05 | 0.0216 |\n", + "| img/Transformer/encoderblock_5/MlpBlock_0/Dense_1/bias | (768,) | 768 | -0.00481 | 0.16 |\n", + "| img/Transformer/encoderblock_5/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | -1.36e-06 | 0.0198 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0157 | 0.535 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 2.92e-05 | 0.0234 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | -0.0106 | 0.149 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 1.08e-06 | 0.0204 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | -0.00102 | 0.668 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 3.46e-05 | 0.0239 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0078 | 0.0672 |\n", + "| img/Transformer/encoderblock_5/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -7.83e-07 | 0.0212 |\n", + "| img/Transformer/encoderblock_6/LayerNorm_0/bias | (768,) | 768 | 0.00578 | 0.175 |\n", + "| img/Transformer/encoderblock_6/LayerNorm_0/scale | (768,) | 768 | 0.971 | 0.349 |\n", + "| img/Transformer/encoderblock_6/LayerNorm_1/bias | (768,) | 768 | 0.00789 | 0.278 |\n", + "| 
img/Transformer/encoderblock_6/LayerNorm_1/scale | (768,) | 768 | 1.3 | 0.22 |\n", + "| img/Transformer/encoderblock_6/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -1.82 | 0.551 |\n", + "| img/Transformer/encoderblock_6/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | 4.07e-05 | 0.0216 |\n", + "| img/Transformer/encoderblock_6/MlpBlock_0/Dense_1/bias | (768,) | 768 | -0.00527 | 0.191 |\n", + "| img/Transformer/encoderblock_6/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | -7.42e-06 | 0.021 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0274 | 0.523 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | -9.41e-06 | 0.0228 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | -0.00814 | 0.153 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 5.8e-06 | 0.0213 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | -0.0326 | 0.634 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | -4.62e-06 | 0.0231 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.00246 | 0.0661 |\n", + "| img/Transformer/encoderblock_6/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -3.42e-06 | 0.0221 |\n", + "| img/Transformer/encoderblock_7/LayerNorm_0/bias | (768,) | 768 | 0.00394 | 0.19 |\n", + "| img/Transformer/encoderblock_7/LayerNorm_0/scale | (768,) | 768 | 1.04 | 0.357 |\n", + "| img/Transformer/encoderblock_7/LayerNorm_1/bias | (768,) | 768 | 0.0107 | 0.298 |\n", + "| img/Transformer/encoderblock_7/LayerNorm_1/scale | (768,) | 768 | 1.46 | 0.211 |\n", + "| img/Transformer/encoderblock_7/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -1.97 | 0.549 |\n", + "| 
img/Transformer/encoderblock_7/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | -4.23e-05 | 0.0215 |\n", + "| img/Transformer/encoderblock_7/MlpBlock_0/Dense_1/bias | (768,) | 768 | -0.00318 | 0.452 |\n", + "| img/Transformer/encoderblock_7/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 1.13e-05 | 0.0228 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | 0.0143 | 0.525 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 1.5e-05 | 0.0227 |\n", + "INFO:absl:\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | -0.00723 | 0.188 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 9.01e-06 | 0.0218 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0222 | 0.576 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 7.02e-06 | 0.0229 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.0016 | 0.0952 |\n", + "| img/Transformer/encoderblock_7/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | -4.78e-06 | 0.0225 |\n", + "| img/Transformer/encoderblock_8/LayerNorm_0/bias | (768,) | 768 | 0.00784 | 0.224 |\n", + "| img/Transformer/encoderblock_8/LayerNorm_0/scale | (768,) | 768 | 1.13 | 0.296 |\n", + "| img/Transformer/encoderblock_8/LayerNorm_1/bias | (768,) | 768 | 0.108 | 0.646 |\n", + "| img/Transformer/encoderblock_8/LayerNorm_1/scale | (768,) | 768 | 3.29 | 0.521 |\n", + "| img/Transformer/encoderblock_8/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -2.01 | 0.613 |\n", + "| img/Transformer/encoderblock_8/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | -0.000557 | 0.0217 |\n", + "| img/Transformer/encoderblock_8/MlpBlock_0/Dense_1/bias | (768,) | 768 | -0.00247 | 
0.407 |\n", + "| img/Transformer/encoderblock_8/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | 1.16e-05 | 0.0236 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 768 | -0.00489 | 0.508 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 9.53e-06 | 0.0228 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | -0.00214 | 0.455 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | -3.41e-06 | 0.022 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.0533 | 0.646 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 2.5e-05 | 0.0225 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | -0.000388 | 0.125 |\n", + "| img/Transformer/encoderblock_8/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | 1.94e-05 | 0.0229 |\n", + "| img/Transformer/encoderblock_9/LayerNorm_0/bias | (768,) | 768 | 0.0155 | 0.265 |\n", + "| img/Transformer/encoderblock_9/LayerNorm_0/scale | (768,) | 768 | 1.31 | 0.255 |\n", + "| img/Transformer/encoderblock_9/LayerNorm_1/bias | (768,) | 768 | 0.297 | 1.35 |\n", + "| img/Transformer/encoderblock_9/LayerNorm_1/scale | (768,) | 768 | 7.01 | 0.96 |\n", + "| img/Transformer/encoderblock_9/MlpBlock_0/Dense_0/bias | (3072,) | 3,072 | -2.14 | 0.514 |\n", + "| img/Transformer/encoderblock_9/MlpBlock_0/Dense_0/kernel | (768, 3072) | 2,359,296 | -0.000968 | 0.0227 |\n", + "| img/Transformer/encoderblock_9/MlpBlock_0/Dense_1/bias | (768,) | 768 | -0.00751 | 0.479 |\n", + "| img/Transformer/encoderblock_9/MlpBlock_0/Dense_1/kernel | (3072, 768) | 2,359,296 | -1.9e-06 | 0.0251 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/key/bias | (12, 64) | 
768 | -0.00403 | 0.497 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/key/kernel | (768, 12, 64) | 589,824 | 1.15e-05 | 0.0224 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/out/bias | (768,) | 768 | -0.000966 | 0.233 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/out/kernel | (12, 64, 768) | 589,824 | 3.96e-06 | 0.0227 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/query/bias | (12, 64) | 768 | 0.00383 | 0.759 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/query/kernel | (768, 12, 64) | 589,824 | 2.61e-05 | 0.0221 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/value/bias | (12, 64) | 768 | 0.00333 | 0.111 |\n", + "| img/Transformer/encoderblock_9/MultiHeadDotProductAttention_0/value/kernel | (768, 12, 64) | 589,824 | 1.24e-05 | 0.024 |\n", + "| img/cls | (1, 1, 768) | 768 | 0.0472 | 0.959 |\n", + "| img/embedding/bias | (768,) | 768 | -0.00594 | 0.185 |\n", + "| img/embedding/kernel | (16, 16, 3, 768) | 589,824 | 7.7e-06 | 0.0157 |\n", + "| img/head/bias | (768,) | 768 | -0.00346 | 1.43 |\n", + "| img/head/kernel | (768, 768) | 589,824 | 5.02e-06 | 0.0398 |\n", + "| img/pos_embedding | (1, 196, 768) | 150,528 | -0.00125 | 0.344 |\n", + "| t | (1,) | 1 | 4.39 | 0.0 |\n", + "| txt/BertEncoder_0/embedder/embedders_position_ids/embedding | (512, 768) | 393,216 | -5.31e-05 | 0.0191 |\n", + "| txt/BertEncoder_0/embedder/embedders_segment_ids/embedding | (2, 768) | 1,536 | -0.00186 | 0.0445 |\n", + "| txt/BertEncoder_0/embedder/embedders_token_ids/embedding | (30522, 768) | 23,440,896 | -0.0278 | 0.0966 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/key/bias | (768,) | 768 | 0.00173 | 0.0804 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | -5.01e-05 | 0.0453 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0172 | 0.487 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | -8.24e-07 | 0.0453 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/value/bias | (768,) | 768 | 0.000783 | 0.0517 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -1.96e-05 | 0.031 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/dense_layer/bias | (768,) | 768 | -0.00214 | 0.0456 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -2.01e-05 | 0.0311 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/layer_norm/bias | (768,) | 768 | -0.023 | 0.352 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/attention_block/layer_norm/scale | (768,) | 768 | 0.982 | 0.172 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/dense_layer/bias | (768,) | 768 | -0.00283 | 0.0899 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -0.000185 | 0.0384 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/layer_norm/bias | (768,) | 768 | -0.0371 | 0.0578 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/layer_norm/scale | (768,) | 768 | 0.716 | 0.0649 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.209 | 0.0789 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_0/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.000162 | 0.0395 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/key/bias | (768,) | 768 | -0.00332 | 0.0765 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 1.99e-05 | 0.0446 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/query/bias | (768,) | 768 | 0.018 | 0.266 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | -0.000109 | 0.0444 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/value/bias | (768,) | 768 | 0.000419 | 0.0381 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -9.83e-06 | 0.0322 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/dense_layer/bias | (768,) | 768 | -0.000375 | 0.0332 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -1.98e-05 | 0.0311 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/layer_norm/bias | (768,) | 768 | -0.0277 | 0.177 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/attention_block/layer_norm/scale | (768,) | 768 | 0.929 | 0.107 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/dense_layer/bias | (768,) | 768 | -0.00157 | 0.054 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -7.92e-05 | 0.0389 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/layer_norm/bias | (768,) | 768 | -0.0381 | 0.0436 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/layer_norm/scale | (768,) | 768 | 0.85 | 0.0859 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.18 | 0.0889 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_1/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.000555 | 0.0409 |\n", + "INFO:absl:\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/key/bias | (768,) | 768 | -0.00276 | 0.0785 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | -1.02e-05 | 0.0451 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/query/bias | (768,) | 768 | -0.00648 | 0.277 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | -4.15e-05 | 0.0446 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/value/bias | (768,) | 768 | 0.000935 | 0.0378 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -1.83e-06 | 0.0298 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/dense_layer/bias | (768,) | 768 | 0.000223 | 0.0434 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -7.34e-06 | 0.0299 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/layer_norm/bias | (768,) | 768 | -0.0451 | 0.204 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/attention_block/layer_norm/scale | (768,) | 768 | 0.689 | 0.109 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/dense_layer/bias | (768,) | 768 | -0.00139 | 0.107 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | 
6.03e-05 | 0.0434 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/layer_norm/bias | (768,) | 768 | -0.0359 | 0.0716 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/layer_norm/scale | (768,) | 768 | 0.88 | 0.0729 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.213 | 0.0613 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_10/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.00169 | 0.0427 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/key/bias | (768,) | 768 | -0.00566 | 0.0766 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 0.000108 | 0.0482 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/query/bias | (768,) | 768 | 0.0135 | 0.268 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | -3.91e-05 | 0.0484 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/value/bias | (768,) | 768 | -2.05e-05 | 0.016 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -1.32e-05 | 0.0346 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/dense_layer/bias | (768,) | 768 | -4.59e-05 | 0.0201 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/dense_layer/kernel | (768, 768) | 589,824 | 9.95e-08 | 0.0344 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/layer_norm/bias | (768,) | 768 | -0.0645 | 0.147 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/attention_block/layer_norm/scale | (768,) | 768 | 0.711 | 
0.0417 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/dense_layer/bias | (768,) | 768 | -0.00111 | 0.0884 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -1.39e-05 | 0.0346 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/layer_norm/bias | (768,) | 768 | -0.0794 | 0.116 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/layer_norm/scale | (768,) | 768 | 0.616 | 0.0297 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.172 | 0.0627 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_11/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.0027 | 0.0366 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/key/bias | (768,) | 768 | 0.00158 | 0.0787 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 3.48e-05 | 0.0444 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/query/bias | (768,) | 768 | 0.00346 | 0.193 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 2.12e-06 | 0.0442 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/value/bias | (768,) | 768 | 0.00371 | 0.0626 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -4.83e-06 | 0.0303 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/dense_layer/bias | (768,) | 768 | -2.89e-05 | 0.0633 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -9.67e-06 | 0.0292 |\n", + 
"| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/layer_norm/bias | (768,) | 768 | -0.0236 | 0.166 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/attention_block/layer_norm/scale | (768,) | 768 | 0.905 | 0.0763 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/dense_layer/bias | (768,) | 768 | -0.00248 | 0.0592 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -7.06e-05 | 0.0394 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/layer_norm/bias | (768,) | 768 | -0.036 | 0.0455 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/layer_norm/scale | (768,) | 768 | 0.802 | 0.0567 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.18 | 0.0884 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_2/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.000542 | 0.0415 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/key/bias | (768,) | 768 | -0.00127 | 0.0755 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | 3.12e-05 | 0.0414 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/query/bias | (768,) | 768 | -0.0105 | 0.204 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 3.34e-05 | 0.0416 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/value/bias | (768,) | 768 | 0.00178 | 0.0306 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -1.34e-05 | 0.0303 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/dense_layer/bias | (768,) | 768 | -0.00116 | 0.0364 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/dense_layer/kernel | (768, 768) | 589,824 | 1.66e-05 | 0.0297 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/layer_norm/bias | (768,) | 768 | -0.0146 | 0.149 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/attention_block/layer_norm/scale | (768,) | 768 | 0.897 | 0.217 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/dense_layer/bias | (768,) | 768 | -0.00258 | 0.0632 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -4.58e-05 | 0.0396 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/layer_norm/bias | (768,) | 768 | -0.0281 | 0.0433 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/layer_norm/scale | (768,) | 768 | 0.736 | 0.047 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.18 | 0.0783 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_3/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.000188 | 0.0418 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/key/bias | (768,) | 768 | 0.0025 | 0.0828 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | -3.43e-06 | 0.0411 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/query/bias | (768,) | 768 | 0.00266 | 0.215 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | -3.55e-05 | 0.0414 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/value/bias | (768,) | 768 | -9.8e-05 | 0.0278 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -2.48e-05 | 0.0328 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/dense_layer/bias | (768,) | 768 | 0.000161 | 0.021 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -5.22e-06 | 0.0317 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/layer_norm/bias | (768,) | 768 | -0.0255 | 0.148 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/attention_block/layer_norm/scale | (768,) | 768 | 0.87 | 0.186 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/dense_layer/bias | (768,) | 768 | -0.00135 | 0.0519 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -2.25e-05 | 0.0397 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/layer_norm/bias | (768,) | 768 | -0.0268 | 0.0392 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/layer_norm/scale | (768,) | 768 | 0.757 | 0.0459 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.182 | 0.0968 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_4/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.000773 | 0.0423 |\n", + "INFO:absl:\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/key/bias | (768,) | 768 | 0.000957 | 0.0788 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | -0.000109 | 0.0416 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/query/bias | (768,) | 768 | -0.00894 | 0.199 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 0.000138 | 0.0417 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/value/bias | (768,) | 768 | 0.000613 | 0.0285 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | -8.91e-05 | 0.0347 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/dense_layer/bias | (768,) | 768 | 0.000115 | 0.019 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/dense_layer/kernel | (768, 768) | 589,824 | 1.99e-06 | 0.0335 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/layer_norm/bias | (768,) | 768 | -0.0237 | 0.134 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/attention_block/layer_norm/scale | (768,) | 768 | 0.872 | 0.159 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/dense_layer/bias | (768,) | 768 | -0.000342 | 0.053 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -1.83e-05 | 0.0404 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/layer_norm/bias | (768,) | 768 | -0.0257 | 0.0449 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/layer_norm/scale | (768,) | 768 | 0.748 | 0.0402 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.181 | 0.093 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_5/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.000795 | 0.0425 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/key/bias | (768,) | 768 | 0.00393 | 0.0801 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | -7.02e-05 | 0.0423 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/query/bias | (768,) | 768 | -0.0112 | 0.209 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 0.000111 | 0.0423 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/value/bias | (768,) | 768 | 0.000317 | 0.0351 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | 5.99e-06 | 0.0351 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/dense_layer/bias | (768,) | 768 | 0.000102 | 0.0233 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/dense_layer/kernel | (768, 768) | 589,824 | 1.2e-05 | 0.0342 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/layer_norm/bias | (768,) | 768 | -0.026 | 0.14 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/attention_block/layer_norm/scale | (768,) | 768 | 0.849 | 0.125 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/dense_layer/bias | (768,) | 768 | -0.000724 | 0.0653 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -3.74e-06 | 0.0416 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/layer_norm/bias | (768,) | 768 | -0.0266 | 0.0467 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/layer_norm/scale | (768,) | 768 | 0.771 | 0.0392 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.178 | 0.0903 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_6/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.000836 | 0.0433 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/key/bias | (768,) | 768 | 0.00154 | 0.0742 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | -3.73e-05 | 0.043 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/query/bias | (768,) | 768 | -0.0066 | 0.23 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 3.89e-05 | 0.0429 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/value/bias | (768,) | 768 | -0.00204 | 0.0397 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | 2.6e-05 | 0.0357 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/dense_layer/bias | (768,) | 768 | 0.000126 | 0.0222 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -3.72e-06 | 0.035 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/layer_norm/bias | (768,) | 768 | -0.0382 | 0.145 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/attention_block/layer_norm/scale | (768,) | 768 | 0.805 | 0.105 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/dense_layer/bias | (768,) | 768 | -0.000484 | 0.0755 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | -4.84e-06 | 0.0436 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/layer_norm/bias | (768,) | 768 | -0.0271 | 0.0477 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/layer_norm/scale | (768,) | 768 | 0.787 | 0.0368 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.183 | 0.0784 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_7/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.00155 | 0.0443 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/key/bias | (768,) | 768 | -0.00364 | 0.0721 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | -8.37e-05 | 0.0435 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/query/bias | (768,) | 768 | -0.0128 | 0.212 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | 0.000144 | 0.044 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/value/bias | (768,) | 768 | -0.00064 | 0.0414 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | 3.29e-05 | 0.036 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/dense_layer/bias | (768,) | 768 | 0.000144 | 0.0389 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/dense_layer/kernel | (768, 768) | 589,824 | -1.05e-05 | 0.0366 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/layer_norm/bias | (768,) | 768 | -0.0412 | 0.159 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/attention_block/layer_norm/scale | (768,) | 768 | 0.75 | 0.106 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/dense_layer/bias | (768,) | 768 | -0.000383 | 0.0733 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | 2.33e-06 | 0.0465 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/layer_norm/bias | (768,) | 768 | -0.0255 | 0.103 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/layer_norm/scale | (768,) | 768 | 0.782 | 0.0419 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.188 | 0.0604 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_8/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.00191 | 0.046 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/key/bias | (768,) | 768 | 0.00177 | 0.0769 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/key/kernel | (768, 768) | 589,824 | -2.25e-05 | 0.0445 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/query/bias | (768,) | 768 | 0.00698 | 0.221 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/query/kernel | (768, 768) | 589,824 | -9.69e-05 | 0.0451 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/value/bias | (768,) | 768 | -0.000478 | 0.0379 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/attention_layer/value/kernel | (768, 768) | 589,824 | 9.23e-06 | 0.0272 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/dense_layer/bias | (768,) | 768 | 0.000851 | 0.0372 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/dense_layer/kernel | (768, 768) | 589,824 | 6.89e-06 | 0.0278 |\n", + "| 
txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/layer_norm/bias | (768,) | 768 | -0.0508 | 0.206 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/attention_block/layer_norm/scale | (768,) | 768 | 0.66 | 0.0814 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/dense_layer/bias | (768,) | 768 | -0.00159 | 0.1 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/dense_layer/kernel | (3072, 768) | 2,359,296 | 3.22e-05 | 0.0476 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/layer_norm/bias | (768,) | 768 | -0.0281 | 0.0884 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/layer_norm/scale | (768,) | 768 | 0.782 | 0.0604 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/mlp/dense_layer/bias | (3072,) | 3,072 | -0.202 | 0.0655 |\n", + "| txt/BertEncoder_0/encoder_block/layer_sequence/layers_9/mlp_block/mlp/dense_layer/kernel | (768, 3072) | 2,359,296 | 0.00262 | 0.0463 |\n", + "INFO:absl:\n", + "| txt/BertEncoder_0/layer_norm/bias | (768,) | 768 | -0.0183 | 0.0679 |\n", + "| txt/BertEncoder_0/layer_norm/scale | (768,) | 768 | 0.836 | 0.125 |\n", + "| txt/head/bias | (768,) | 768 | -0.0015 | 0.0879 |\n", + "| txt/head/kernel | (768, 768) | 589,824 | 7.58e-05 | 0.0409 |\n", + "+-------------------------------------------------------------------------------------------------------+------------------+------------+-----------+--------+\n", + "Total: 195,870,721\n", + "INFO:absl:NOTE: Replicating...\n", + "INFO:absl:NOTE: Initializing evaluator: disclf...\n", + "INFO:absl:Using 81 prompts_templates: ['a bad photo of a {}', 'a photo of many {}', 'a sculpture of a {}', 'a photo of the hard to see {}', 'a low resolution photo of the {}', 'a rendering of a {}', 'graffiti of a {}', 'a bad photo of the {}', 'a cropped photo of the {}', 'a tattoo of a {}', 'the embroidered {}', 'a photo of a hard to see {}', 'a 
bright photo of a {}', 'a photo of a clean {}', 'a photo of a dirty {}', 'a dark photo of the {}', 'a drawing of a {}', 'a photo of my {}', 'the plastic {}', 'a photo of the cool {}', 'a closeup photo of a {}', 'a black and white photo of the {}', 'a painting of the {}', 'a painting of a {}', 'a pixelated photo of the {}', 'a sculpture of the {}', 'a bright photo of the {}', 'a cropped photo of a {}', 'a plastic {}', 'a photo of the dirty {}', 'a jpeg corrupted photo of a {}', 'a blurry photo of the {}', 'a photo of the {}', 'a good photo of the {}', 'a rendering of the {}', 'a {} in a video game', 'a photo of one {}', 'a doodle of a {}', 'a closeup photo of the {}', 'a photo of a {}', 'the origami {}', 'the {} in a video game', 'a sketch of a {}', 'a doodle of the {}', 'a origami {}', 'a low resolution photo of a {}', 'the toy {}', 'a rendition of the {}', 'a photo of the clean {}', 'a photo of a large {}', 'a rendition of a {}', 'a photo of a nice {}', 'a photo of a weird {}', 'a blurry photo of a {}', 'a cartoon {}', 'art of a {}', 'a sketch of the {}', 'a embroidered {}', 'a pixelated photo of a {}', 'itap of the {}', 'a jpeg corrupted photo of the {}', 'a good photo of a {}', 'a plushie {}', 'a photo of the nice {}', 'a photo of the small {}', 'a photo of the weird {}', 'the cartoon {}', 'art of the {}', 'a drawing of the {}', 'a photo of the large {}', 'a black and white photo of a {}', 'the plushie {}', 'a dark photo of a {}', 'itap of a {}', 'graffiti of the {}', 'a toy {}', 'itap of my {}', 'a photo of a cool {}', 'a photo of a small {}', 'a tattoo of the {}', '{}']\n", + "INFO:absl:Load dataset info from /root/tensorflow_datasets/oxford_iiit_pet/3.2.0\n", + "INFO:absl:Using 37 class_names: ['abyssinian', 'american bulldog', 'american pit bull terrier', 'basset hound', 'beagle', 'bengal', 'birman', 'bombay', 'boxer', 'british shorthair', 'chihuahua', 'egyptian mau', 'english cocker spaniel', 'english setter', 'german shorthaired', 'great pyrenees', 
'havanese', 'japanese chin', 'keeshond', 'leonberger', 'maine coon', 'miniature pinscher', 'newfoundland', 'persian', 'pomeranian', 'pug', 'ragdoll', 'russian blue', 'saint bernard', 'samoyed', 'scottish terrier', 'shiba inu', 'siamese', 'sphynx', 'staffordshire bull terrier', 'wheaten terrier', 'yorkshire terrier']\n", + "INFO:absl:Load dataset info from /root/tensorflow_datasets/oxford_iiit_pet/3.2.0\n", + "INFO:absl:Constructing tf.data.Dataset oxford_iiit_pet for split _EvenSplit(split='test', index=0, count=1, drop_remainder=False), from /root/tensorflow_datasets/oxford_iiit_pet/3.2.0\n", + "INFO:absl:Data before pre-processing:\n", + "{'file_name': \u003ctf.Tensor 'args_0:0' shape=() dtype=string\u003e, 'image': \u003ctf.Tensor 'args_1:0' shape=(None, None, 3) dtype=uint8\u003e, 'label': \u003ctf.Tensor 'args_2:0' shape=() dtype=int64\u003e, 'segmentation_mask': \u003ctf.Tensor 'args_3:0' shape=(None, None, 1) dtype=uint8\u003e, 'species': \u003ctf.Tensor 'args_4:0' shape=() dtype=int64\u003e}\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:absl:Data after pre-processing:\n", + "{'image': \u003ctf.Tensor 'add:0' shape=(224, 224, 3) dtype=float32\u003e, 'label': \u003ctf.Tensor 'args_2:0' shape=() dtype=int64\u003e}\n", + "INFO:absl:Data before pre-processing:\n", + "{'label': \u003ctf.Tensor 'args_0:0' shape=() dtype=int64\u003e, 'texts': \u003ctf.Tensor 'args_1:0' shape=() dtype=string\u003e}\n", + "INFO:absl:Data after pre-processing:\n", + "{'label': \u003ctf.Tensor 'args_0:0' shape=() dtype=int64\u003e, 'labels': \u003ctf.Tensor 'strided_slice_4:0' shape=(16,) dtype=int32\u003e}\n", + "WARNING:tensorflow:From /usr/local/lib/python3.9/dist-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 
2023-09-23.\n", + "Instructions for updating:\n", + "Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089\n", + "INFO:absl:Initialized evaluator in 3.0 seconds\n", + "INFO:absl:NOTE: disclf evaluation step 0...\n", + "INFO:absl:Starting text embedding...\n", + "INFO:absl:Compiled text embeddings in 7.6s\n", + "INFO:absl:Embedded oxford_iiit_pet text in 7 steps - ...[512 512 512 512 512 437 0]\n", + "INFO:absl:Totalling 2997 text in 2.6s\n", + "INFO:absl:Total texts embeddings size 9.2M\n", + "INFO:absl:Starting image embedding...\n", + "INFO:absl:Compiled image embeddings in 17.5s\n", + "INFO:absl:Embedded oxford_iiit_pet image in 9 steps - ...[512 512 512 512 512 512 512 85 0]\n", + "INFO:absl:Totalling 3669 image in 46.1s\n", + "INFO:absl:Dataset oxford_iiit_pet, results {'accuracy': 0.8103025347506132, 'correct': 2973, 'count': 3669}\n", + "INFO:absl:\u001b[35m[0]\u001b[0m z/0shot/oxford_iiit_pet_accuracy = 0.8103025347506132\n", + "INFO:absl:\u001b[35m[0]\u001b[0m z/secs/eval/disclf = 73.91012993699997\n", + "INFO:absl:TIMING[z/secs/eval/disclf]: 73.91012993699997\n", + "INFO:absl:NOTE: Done!\n" + ] + } + ], + "source": [ + "# Should run in ~5 minutes on a T4 GPU...\n", + "set_max_height(444)\n", + "eval_only.main([])" + ] + }, + { + "cell_type": "code", + "source": [ + "# ... 
and yield a final 81% accuracy.\n", + "import json\n", + "json.loads(open(flags.FLAGS.workdir + '/big_vision_metrics.txt').readline())" + ], + "metadata": { + "id": "J14NHGBAGXYp", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "0d5af276-a60f-44cd-c25f-0bf54f0e7a93" + }, + "execution_count": 26, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'step': 0,\n", + " 'z/0shot/oxford_iiit_pet_accuracy': 0.8103025347506132,\n", + " 'z/secs/eval/disclf': 73.91012993699997}" + ] + }, + "metadata": {}, + "execution_count": 26 + } + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "lit OSS", + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/big_vision/configs/proj/image_text/siglip_lit_coco.py b/big_vision/configs/proj/image_text/siglip_lit_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..f1dbd472a1d440351b3e176696873fc566797ffd --- /dev/null +++ b/big_vision/configs/proj/image_text/siglip_lit_coco.py @@ -0,0 +1,115 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Minimal SigLIP (https://arxiv.org/abs/2303.15343) example. 
+ +Example training: + +big_vision.trainers.proj.image_text.siglip \ + --config big_vision/configs/proj/image_text/siglip_lit_coco.py:batch_size=512 \ + --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%Y-%m-%d_%H%M'` +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.image_text import common +from ml_collections import ConfigDict + + +def get_config(arg=None): + """The base configuration.""" + arg = bvcc.parse_arg( + arg, res=224, runlocal=False, token_len=16, txt='bert_base', img='B/16', + init='', img_head=False, batch_size=512) + img_name, img_init = common.inits[arg.img] + txt_name, txt_init = common.inits[arg.txt] + config = ConfigDict() + + config.input = {} + config.input.data = dict(name='coco_captions', split='train') + config.input.batch_size = arg.batch_size if not arg.runlocal else 32 + config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50 + + config.total_steps = 5_000 if not arg.runlocal else 1 + + config.init_shapes = [(1, arg.res, arg.res, 3), (1, arg.token_len,)] + config.init_types = ['float32', 'int32'] + + if arg.init: + vocab_path = arg.init.rsplit('.', 1)[0] + '.txt' + else: + vocab_path = f'{txt_init}/vocab.txt' + tokenizer = lambda inkey: ( + f'bert_tokenize(inkey="{inkey}", max_len={arg.token_len}, ' + f'vocab_path="{vocab_path}")') + config.input.pp = ( + f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)' + f'|flatten|{tokenizer("captions/text")}|keep("image", "labels")' + ) + config.pp_modules = ['ops_general', 'ops_image', 'ops_text', + 'proj.flaxformer.bert_ops', 'archive.randaug'] + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + + # Model section + config.model_name = 'proj.image_text.two_towers' + config.model_load = {} + if arg.init: + config.model_init = arg.init + else: + config.model_init = {'image': img_init, 'text': txt_init} + config.model_load['txt_load_kw'] = {'dont_load': ['head/kernel', 'head/bias']} + if not arg.img_head: +
config.model_load['img_load_kw'] = {'dont_load': ['head/kernel', 'head/bias']} + config.model = ConfigDict() + config.model.image_model = 'vit' + config.model.text_model = 'proj.flaxformer.bert' + config.model.image = ConfigDict({ + 'variant': img_name, + 'pool_type': 'tok', + 'head_zeroinit': False, + }) + config.model.text = ConfigDict({ + 'config': txt_name, + 'head_zeroinit': False, + }) + config.model.temperature_init = 10.0 + dim = {'B': 768, 'L': 1024}[arg.img[0]] + config.model.out_dim = (dim if arg.img_head else None, dim) # (image_out_dim, text_out_dim) + config.model.bias_init = -2.71 + + if txt_name == 'base': + config.optax_name = 'scale_by_adam' + else: + config.optax_name = 'big_vision.scale_by_adafactor' + + config.lr = 0.001 + config.wd = 0.01 + warmup_steps = max(int(0.03 * config.total_steps), 100) + config.schedule = [ + ('img/.*', None), # Freezes image tower. + ('.*', dict(decay_type='cosine', warmup_steps=warmup_steps)), + ] + + config.grad_clip_norm = 1.0 + + config.evals = {} + config.evals.retrieval_coco = common.get_coco( + pp_img=f'resize({arg.res})|value_range(-1, 1)', + pp_txt=tokenizer('texts'), + log_steps=1000, + ) + + return config diff --git a/big_vision/configs/proj/paligemma/README.md b/big_vision/configs/proj/paligemma/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1ebe888c65eb4602e0e6849821cddd901e510669 --- /dev/null +++ b/big_vision/configs/proj/paligemma/README.md @@ -0,0 +1,270 @@ +# PaliGemma model README + +PaliGemma is an open vision-language model (VLM) inspired by PaLI-3, built with +open components, such as +the [SigLIP vision model](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb) +and +the [Gemma language model](https://ai.google.dev/gemma). 
+PaliGemma is designed as a versatile model for transfer to a wide range of +vision-language tasks such as image and short video captioning, visual question +answering, text reading, object detection and object segmentation. Together with +the pretrained and transfer checkpoints at multiple resolutions, we provide a +checkpoint transferred to a mixture of tasks that can be used for off-the-shelf +exploration. + +## Quick Reference + +This is the reference repository of the model; you may also want to check out the following resources: + + - [ArXiv](https://arxiv.org/abs/2407.07726): Technical report. + - [Kaggle](https://www.kaggle.com/models/google/paligemma): + All pre-trained / mix checkpoints and model card. + - [Kaggle-FT](https://www.kaggle.com/models/google/paligemma-ft): + All fine-tuned checkpoints and model card. + - [VertexAI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363): + PaliGemma models on GCP. + - [Hugging Face](https://huggingface.co/google/paligemma-3b-pt-224): + PyTorch port of PaliGemma models. + - [Light finetuning colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/finetune_paligemma.ipynb): + Lightweight colab for fine-tuning PaliGemma. It can be run on a single T4 GPU (16GB) + available on free Colab. + - [HuggingFace demo](https://hf.co/spaces/google/paligemma): live demo.
+ +### Citation BibTeX + +``` +@article{beyer2024paligemma, + title={{PaliGemma: A versatile 3B VLM for transfer}}, + author={Lucas Beyer and Andreas Steiner and André Susano Pinto and Alexander Kolesnikov and Xiao Wang and Daniel Salz and Maxim Neumann and Ibrahim Alabdulmohsin and Michael Tschannen and Emanuele Bugliarello and Thomas Unterthiner and Daniel Keysers and Skanda Koppula and Fangyu Liu and Adam Grycner and Alexey Gritsenko and Neil Houlsby and Manoj Kumar and Keran Rong and Julian Eisenschlos and Rishabh Kabra and Matthias Bauer and Matko Bošnjak and Xi Chen and Matthias Minderer and Paul Voigtlaender and Ioana Bica and Ivana Balazevic and Joan Puigcerver and Pinelopi Papalampidi and Olivier Henaff and Xi Xiong and Radu Soricut and Jeremiah Harmsen and Xiaohua Zhai}, + year={2024}, + journal={arXiv preprint arXiv:2407.07726} +} +``` + +## Model description + +### Overview + +PaliGemma-3B is a Vision-Language model that was inspired by the PaLI-3 recipe. +It is built on the SigLIP visual encoder (specifically, SigLIP-So400m/14) and the +Gemma 2B language model. PaliGemma takes as input one or more images, +which are turned into "soft tokens" by the SigLIP encoder, and input text +(codenamed the "prefix") that is tokenized by Gemma's tokenizer. The image +tokens and prefix tokens are concatenated (in this order) and passed to the +Gemma decoder with full block-attention, which then generates an output text +(the "suffix") auto-regressively with masked attention. + +![PaliGemma model](paligemma.png) + +### Training stages + +Similar to PaLI-3, PaliGemma's training consists of multiple stages: + + - Stage 0: the unimodal pre-training. We use publicly available off-the-shelf + SigLIP and Gemma models which have been pre-trained unimodally by their + respective authors. + - Stage 1: multimodal pre-training.
The combined PaliGemma model is now + pre-trained on a fully multimodal training dataset, at a low resolution + of 224px² and a prefix+suffix sequence length of 128 tokens. This results in + the first base model that we release. + - Stage 2: high-resolution pre-training. We continue pre-training of the + Stage 1 model at resolution 448px² with sequence length 512 tokens for a short + duration on the same multimodal training data, but re-weighted with more + emphasis on examples that make use of higher resolution or longer sequence + length. We repeat this once more at resolution 896px². This results in two + further "high res" base models that we also release. + - Stage 3: fine-tune. The base models are transferred to + specific tasks by fine-tuning. To facilitate further research and + reproducibility, we release checkpoints fine-tuned on most of the benchmarks + we evaluate on. We also provide a "mix" transfer model, fine-tuned on a wide + variety of data, for use in interactive demos. + +Most of the code examples, use-cases, and code release are about Stage 3: +transferring to a task or dataset of interest to the user. + +### Tokenizer + +PaliGemma uses the Gemma tokenizer with 256'000 tokens, but we further extend +its vocabulary with 1024 entries that represent coordinates in normalized +image-space (`<loc0000>`...`<loc1023>`), and another with 128 entries +(`<seg000>`...`<seg127>`) that are codewords used by a lightweight +referring-expression segmentation vector-quantized variational auto-encoder +(VQ-VAE) with the architecture of [Ning et al. (2023)](https://arxiv.org/abs/2301.02229) and trained on OpenImages +as in PaLI-3. While the `big_vision` codebase is flexible enough to extend +tokenizers on-the-fly, we also provide a SentencePiece model file of the Gemma +tokenizer with these additional tokens baked in, for the convenience of +other codebases.
+ +## Checkpoints + +The PaliGemma models are released under the same open license as the Gemma +models, and hence require manual acknowledgement of the license terms on kaggle: +https://www.kaggle.com/models/google/paligemma. The reference checkpoints are +available on +[Kaggle](https://www.kaggle.com/models/google/paligemma), +[VertexAI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363) and +[Hugging Face](https://huggingface.co/google/paligemma-3b-pt-224). + +### Pretrained checkpoints + +Use one of these checkpoints as initialization for fine-tuning: + + - pt-224: Versatile pretrained model for tasks that do not require seeing + small details in the image. + Examples: natural image captioning and question-answering, detection and + segmentation of medium-large objects. This model was trained with + sequence length 128. + - pt-448: Versatile base model for mid/higher resolution tasks with access + to smaller details. Besides higher resolution, it has gotten more weight on + text reading, detection, and segmentation during its pre-training. Examples: + as above, plus detection, segmentation, text/diagram reading. This model was + trained with sequence length 512. + - pt-896: Further scaled-up version of pt-448, especially good at reading + very small texts as often found in documents and infographics. This model + was trained with sequence length 512. + +Besides the reference float32 checkpoint (11GB), we further provide +bfloat16 and float16 variants of each, to reduce download and storage time. +These are good for inference and frozen transfers, but full fine-tuning +should happen in float32 or mixed precision. + +### Mixture checkpoint + +This checkpoint is trained on a mixture of all our transfer tasks, +with a balancing intended to make it "nice to use" out of the box for +predictions. This model is multilingual and should +understand prompts in various languages, although English +is still its "mother tongue". 
+Questions can be asked in a natural way (including asking for a caption or +reading the text), and detection and segmentation should still work with the +structured `detect {things}` and `segment {things}` prompts as in the base model. + + - mix-224: Similarly to pt-224, this model is good at many natural image + tasks that do not require high resolution. Unlike the raw pre-trained model, + however, it can be interacted with more freely. For example, ask it to + "describe this image in great detail, please" or "How many coins do you see + in the picture?". This model was trained with sequence length 256. + - mix-448: As above, but it is better at tasks that require higher-resolution + input. For example, one could ask it "what is written in the "sum" field?", + ask it to "describe this figure", or ask "what is the GDP of France?" when shown an + infographic of countries' GDPs. This model was trained with + sequence length 512. + +### Transfers results and checkpoints + +We provide checkpoints transferred to most of the tasks we evaluated +transfer on; see the [kaggle page](https://www.kaggle.com/models/google/paligemma). +These are intended for use when a specialised model corresponding +to one of the tasks is needed, for academic research purposes only. +Depending on the task, they may require a specialised preprocessing format. + +The transfer setup is reasonably unified, with the main factors of variation +being the training duration, learning-rate, and whether or not to use dropout +and label-smoothing. Details can be found in the corresponding config files or +in an upcoming tech report. + +Importantly, none of these tasks or datasets are part of the pre-training data +mixture, and their images are explicitly removed from the web-scale +pretraining data.
+ +#### Captioning + +Benchmark (train split) | Metric (split) | pt-224 | pt-448 | pt-896 +-----------------------|----------------|--------|--------|-------- +[COCO captions](https://cocodataset.org/#home) (train+restval) | CIDEr (val) | 141.92 | 144.60 | +[NoCaps](https://nocaps.org/) (Eval of COCO captions transfer) | CIDEr (val) | 121.72 | 123.58 | +[COCO-35L](https://arxiv.org/abs/2205.12522) (train) | CIDEr dev (en / avg-34 / avg) | 139.2 / 115.8 / 116.4 | 141.2 / 118.0 / 118.6 | +[XM3600](https://arxiv.org/abs/2205.12522) (Eval of COCO-35L transfer) | CIDEr test (en / avg-35 / avg) | 78.1 / 41.3 / 42.4 | 80.0 / 41.9 / 42.9 | +[TextCaps](https://textvqa.org/textcaps/) (train) | CIDEr (val) | 127.48 | 153.94 | +[SciCap](https://arxiv.org/abs/2110.11624) (first sentence, no subfigure) (train+val) | CIDEr / BLEU-4 (test) | 162.25 / 0.192 | 181.49 / 0.211 | +[Screen2words](https://arxiv.org/abs/2108.03353) (train+dev) | CIDEr (test) | 117.57 | 119.59 | +[Widget Captioning](https://arxiv.org/abs/2010.04295) (train+dev) | CIDEr (test) | 136.07 | 148.36 | + +#### Question Answering + +Benchmark (train split) | Metric (split) | pt-224 | pt-448 | pt-896 +-----------------------|----------------|--------|--------|-------- +[VQAv2](https://visualqa.org/index.html) (train+validation) | Accuracy (Test server - std) | 83.19 | 85.64 | +[MMVP](https://arxiv.org/abs/2401.06209) (Eval of VQAv2 transfer) | Paired Accuracy | 47.33 | 45.33 | +[POPE](https://arxiv.org/abs/2305.10355) (Eval of VQAv2 transfer) | Accuracy (random / popular / adversarial) | 87.80 / 85.87 / 84.27 | 88.23 / 86.77 / 85.90 | +[Objaverse Multiview](https://arxiv.org/abs/2311.17851) (Eval of VQAv2 transfer) | Cosine Similarity (USEv4) | 62.7 | 62.8 | +[OKVQA](https://okvqa.allenai.org/) (train) | Accuracy (val) | 63.54 | 63.15 | +[A-OKVQA](https://allenai.org/project/a-okvqa/home) (MC) (train+val) | Accuracy (Test server) | 76.37 | 76.90 | +[A-OKVQA](https://allenai.org/project/a-okvqa/home) (DA) 
(train+val) | Accuracy (Test server) | 61.85 | 63.22 | +[GQA](https://cs.stanford.edu/people/dorarad/gqa/about.html) (train_balanced+val_balanced) | Accuracy (testdev balanced) | 65.61 | 67.03 | +[xGQA](https://aclanthology.org/2022.findings-acl.196/) (Eval of GQA transfer) | Mean Accuracy (bn,de,en,id,ko,pt,ru,zh) | 58.37 | 59.07 | +[NLVR2](https://lil.nlp.cornell.edu/nlvr/) (train+dev) | Accuracy (test) | 90.02 | 88.93 | +[MaRVL](https://marvl-challenge.github.io/) (Eval of NLVR2 transfer) | Mean Accuracy (test) (id,sw,ta,tr,zh) | 80.57 | 76.78 | +[AI2D](https://allenai.org/data/diagrams) (train) | Accuracy (test) | 72.12 | 73.28 | +[ScienceQA](https://scienceqa.github.io/) (Img subset, no CoT) (train+val) | Accuracy (test) | 95.39 | 95.93 | +[RSVQA-LR](https://zenodo.org/records/6344334) (Non numeric) (train+val) | Mean Accuracy (test) | 92.65 | 93.11 | +[RSVQA-HR](https://zenodo.org/records/6344367) (Non numeric) (train+val) | Mean Accuracy (test/test2) | 92.61 / 90.58 | 92.79 / 90.54 | +[ChartQA](https://arxiv.org/abs/2203.10244) (human+aug)x(train+val) | Mean Relaxed Accuracy (test_human, test_aug) | 57.08 | 71.36 | +[VizWiz](https://vizwiz.org/tasks-and-datasets/vqa/) VQA (train+val) | Accuracy (Test server - std) | 73.7 | 75.52 | +[TallyQA](https://arxiv.org/abs/1810.12440) (train) | Accuracy (test_simple/test_complex) | 81.72 / 69.56 | 84.86 / 72.27 | +[OCR-VQA](https://ocr-vqa.github.io/) (train+val) | Accuracy (test) | 73.24 | 75.60 | 75.90 +[TextVQA](https://textvqa.org/) (train+val) | Accuracy (Test server - std) | 55.47 | 73.15 | 76.48 +[DocVQA](https://www.docvqa.org/) (train+val) | ANLS (Test server) | 43.74 | 78.02 | 84.77 +[Infographic VQA](https://openaccess.thecvf.com/content/WACV2022/papers/Mathew_InfographicVQA_WACV_2022_paper.pdf) (train+val) | ANLS (Test server) | 28.46 | 40.47 | 47.75 +[SceneText VQA](https://arxiv.org/abs/1905.13648) (train+val) | ANLS (Test server) | 63.29 | 81.82 | 84.40 + +#### Segmentation + +Benchmark (train split) | 
Metric (split) | pt-224 | pt-448 | pt-896 +-----------------------|----------------|--------|--------|-------- +[RefCOCO](https://arxiv.org/abs/1608.00272) (combined refcoco, refcoco+, refcocog excluding val and test images) | MIoU (validation) refcoco / refcoco+ / refcocog | 73.40 / 68.32 / 67.65 | 75.57 / 69.76 / 70.17 | 76.94 / 72.18 / 72.22 + +#### Video tasks (Caption/QA) + +Benchmark (train split) | Metric (split) | pt-224 | pt-448 | pt-896 +-----------------------|----------------|--------|--------|-------- +[MSR-VTT](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/) (Captioning) | CIDEr (test) | 70.54 | +[MSR-VTT](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/) (QA) | Accuracy (test) | 50.09 | +[ActivityNet](http://activity-net.org/) (Captioning) | CIDEr (test) | 34.62 | +[ActivityNet](http://activity-net.org/) (QA) | Accuracy (test) | 50.78 | +[VATEX](https://eric-xw.github.io/vatex-website/about.html) (Captioning) | CIDEr (test) | 79.73 | +[MSVD](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/) (QA) | Accuracy (test) | 60.22 | + +#### Mix model (finetune on mixture of transfer tasks) + +Benchmark | Metric (split) | mix-224 | mix-448 +----------|----------------|---------|--------- +[MMVP](https://arxiv.org/abs/2401.06209) | Paired Accuracy | 46.00 | 45.33 +[POPE](https://arxiv.org/abs/2305.10355) | Accuracy (random / popular / adversarial) | 88.00 / 86.63 / 85.67 | 89.37 / 88.40 / 87.47 + + +## How to run PaliGemma fine-tuning + +To run PaliGemma fine-tuning, set up the `big_vision` repository by following the +main README file. Here we provide PaliGemma-specific instructions. + +Checkpoints can be downloaded from Kaggle. You need to create an account and acknowledge the checkpoint usage policy.
You can then download any checkpoint: + +``` +export KAGGLE_USERNAME= +export KAGGLE_KEY= + +# See https://www.kaggle.com/models/google/paligemma for a full list of models. +export MODEL_NAME=paligemma-3b-pt-224 +export CKPT_FILE=paligemma-3b-pt-224.npz + +mkdir ckpts/ +cd ckpts/ + +curl -L -u $KAGGLE_USERNAME:$KAGGLE_KEY\ + -o pt_224.npz \ + https://www.kaggle.com/api/v1/models/google/paligemma/jax/$MODEL_NAME/1/download/$CKPT_FILE +``` + +As an example, we provide the `forkme.py` config that is based on the easily-adjustable jsonl data source: + +``` +BV_GEMMA_DIR=ckpts/ python -m big_vision.trainers.proj.paligemma.train --config big_vision/configs/proj/paligemma/transfers/forkme.py --workdir workdirs/`date '+%m-%d_%H%M'` +``` + +If you want to use TFDS-based data, check out the other transfer configs. Remember to set `TFDS_DATA_DIR` to point to the folder with the data (this can be a GCP data bucket). + + +## Model Development Contributions + +See the Appendix of the [technical report](https://arxiv.org/abs/2407.07726).
diff --git a/big_vision/configs/proj/paligemma/finetune_paligemma.ipynb b/big_vision/configs/proj/paligemma/finetune_paligemma.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3047d666827681dda62f01c5b95df35c1ad95bfe --- /dev/null +++ b/big_vision/configs/proj/paligemma/finetune_paligemma.ipynb @@ -0,0 +1,1167 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "wR53lePHuiP-" + }, + "source": [ + "# Finetune PaliGemma\n", + "\n", + "> *These models and code are not official Google products and were trained and released for research purposes.*\n", + "\n", + "\n", + "**This notebook shows how to finetune PaliGemma on a vision-language task.**\n", + "The training data consists of 90 pairs of images and long captions describing them.\n", + "To make it runnable on a T4 colab runtime with 16GB HBM and 12GB RAM, we opt to only finetune the attention layers of the language model and freeze the other parameters.\n", + "\n", + " **This setup is illustrative**. In a real usecase, the amount of data, trainable parameters, training steps and hyper-parameters and obtained results could be significantly different.\n", + "\n", + "This notebook uses the model reference implementation from [big_vision](https://github.com/google-research/big_vision).\n", + "and shows how to:\n", + "\n", + " * Install deps, download model checkpoint and training data.\n", + " * Load the model onto GPU devices.\n", + " * Prepare the input to the model for training and inference.\n", + " * Finetune the model and inspect output in validation split." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6U0QUFveqSP2" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "DfxKb3F839Ks", + "outputId": "d02e98d5-8334-463f-f529-6292dd73b04b", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for ml_collections (setup.py) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "# @title Fetch big_vision code and install dependencies.\n", + "import os\n", + "import sys\n", + "\n", + "# TPUs with\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " raise \"It seems you are using Colab with remote TPUs which is not supported.\"\n", + "\n", + "# Fetch big_vision repository if python doesn't know about it and install\n", + "# dependencies needed for this notebook.\n", + "if not os.path.exists(\"big_vision_repo\"):\n", + " !git clone --quiet --branch=main --depth=1 \\\n", + " https://github.com/google-research/big_vision big_vision_repo\n", + "\n", + "# Append big_vision code to python import path\n", + "if \"big_vision_repo\" not in sys.path:\n", + " sys.path.append(\"big_vision_repo\")\n", + "\n", + "# Install missing dependencies. 
Assume jax~=0.4.25 with GPU available.\n", + "!pip3 install -q \"overrides\" \"ml_collections\" \"einops~=0.7\" \"sentencepiece\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "azmRZvgGyhAb" + }, + "source": [ + "### Configure your API key to access Kaggle\n", + "\n", + "To use PaliGemma, you must provide your Kaggle username and a Kaggle API key.\n", + "\n", + "1. To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This will trigger the download of a `kaggle.json` file containing your API credentials.\n", + "1. In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.\n", + "\n", + "To be able to download, you will also need to acknowledge the Terms and Conditions of the PaliGemma on:\n", + "\n", + "* https://www.kaggle.com/models/google/paligemma/\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "zGLIp1Cx3_CX" + }, + "outputs": [], + "source": [ + "import os\n", + "from google.colab import userdata\n", + "\n", + "# Note: `userdata.get` is a Colab API. 
If you're not using Colab, set the env\n", + "# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json\n", + "\n", + "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n", + "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gQNOTfF24AV4", + "outputId": "54f8aeed-bdbd-4ab3-941b-373392591505" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading the checkpoint from Kaggle, this could take a few minutes....\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading from https://www.kaggle.com/api/v1/models/google/paligemma/jax/paligemma-3b-pt-224/1/download/paligemma-3b-pt-224.f16.npz...\n", + "100%|██████████| 5.45G/5.45G [01:00<00:00, 95.9MB/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model path: /root/.cache/kagglehub/models/google/paligemma/jax/paligemma-3b-pt-224/1/./paligemma-3b-pt-224.f16.npz\n", + "Downloading the model tokenizer...\n", + "Copying gs://big_vision/paligemma_tokenizer.model...\n", + "- [1 files][ 4.1 MiB/ 4.1 MiB] \n", + "Operation completed over 1 objects/4.1 MiB. 
\n", + "Tokenizer path: ./paligemma_tokenizer.model\n", + "Downloading the dataset...\n", + "Data path: ./longcap100\n" + ] + } + ], + "source": [ + "# @title Download checkpoint, tokenizer and dataset to local filesystem.\n", + "#\n", + "import os\n", + "import kagglehub\n", + "\n", + "MODEL_PATH = \"./paligemma-3b-pt-224.f16.npz\"\n", + "if not os.path.exists(MODEL_PATH):\n", + " print(\"Downloading the checkpoint from Kaggle, this could take a few minutes....\")\n", + " # Note: kaggle archive contains the same checkpoint in multiple formats.\n", + " # Download only the float16 model.\n", + " MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', MODEL_PATH)\n", + " print(f\"Model path: {MODEL_PATH}\")\n", + "\n", + "TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n", + "if not os.path.exists(TOKENIZER_PATH):\n", + " print(\"Downloading the model tokenizer...\")\n", + " !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}\n", + " print(f\"Tokenizer path: {TOKENIZER_PATH}\")\n", + "\n", + "DATA_DIR=\"./longcap100\"\n", + "if not os.path.exists(DATA_DIR):\n", + " print(\"Downloading the dataset...\")\n", + " !gsutil -m -q cp -n -r gs://longcap100/ .\n", + " print(f\"Data path: {DATA_DIR}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zDoq0O77GF30" + }, + "source": [ + "## Notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dTfe2k8J4Bw0", + "outputId": "b9864437-9e35-493a-bf52-019c18d5dfd9" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "JAX version: 0.4.26\n", + "JAX platform: gpu\n", + "JAX devices: 1\n" + ] + } + ], + "source": [ + "import base64\n", + "import functools\n", + "import html\n", + "import io\n", + "import os\n", + "import warnings\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import ml_collections\n", + 
"\n", + "import tensorflow as tf\n", + "import sentencepiece\n", + "\n", + "from IPython.core.display import display, HTML\n", + "from PIL import Image\n", + "\n", + "# Import model definition from big_vision\n", + "from big_vision.models.proj.paligemma import paligemma\n", + "from big_vision.trainers.proj.paligemma import predict_fns\n", + "\n", + "# Import big vision utilities\n", + "import big_vision.datasets.jsonl\n", + "import big_vision.utils\n", + "import big_vision.sharding\n", + "\n", + "# Don't let TF use the GPU or TPUs\n", + "tf.config.set_visible_devices([], \"GPU\")\n", + "tf.config.set_visible_devices([], \"TPU\")\n", + "\n", + "backend = jax.lib.xla_bridge.get_backend()\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX platform: {backend.platform}\")\n", + "print(f\"JAX devices: {jax.device_count()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "1aghcULcEdtv" + }, + "outputs": [], + "source": [ + "# @title Construct model and load params into RAM.\n", + "\n", + "# Define model\n", + "model_config = ml_collections.FrozenConfigDict({\n", + " \"llm\": {\"vocab_size\": 257_152},\n", + " \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n", + "})\n", + "model = paligemma.Model(**model_config)\n", + "tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)\n", + "\n", + "# Load params - this can take up to 1 minute in T4 colabs.\n", + "params = paligemma.load(None, MODEL_PATH, model_config)\n", + "\n", + "# Define `decode` function to sample outputs from the model.\n", + "decode_fn = predict_fns.get_all(model)['decode']\n", + "decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RWOdf_fw2SAO", + "outputId": "6d48433f-7410-480d-b889-e2b679caa8a6" + }, + 
"outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " == Model params == \n", + "img/Transformer/encoder_norm/bias (1152,) float16\n", + "img/Transformer/encoder_norm/scale (1152,) float16\n", + "img/Transformer/encoderblock/LayerNorm_0/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/LayerNorm_0/scale (27, 1152) float16\n", + "img/Transformer/encoderblock/LayerNorm_1/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/LayerNorm_1/scale (27, 1152) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias (27, 4304) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel (27, 1152, 4304) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel (27, 4304, 1152) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias (27, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel (27, 1152, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel (27, 16, 72, 1152) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias (27, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel (27, 1152, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias (27, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel (27, 1152, 16, 72) float16\n", + "img/embedding/bias (1152,) float16\n", + "img/embedding/kernel (14, 14, 3, 1152) float16\n", + "img/head/bias (2048,) float16\n", + "img/head/kernel (1152, 2048) float16\n", + "img/pos_embedding (1, 256, 1152) float16\n", + "llm/embedder/input_embedding (257152, 2048) float16\n", + "llm/final_norm/scale (2048,) float16\n", + 
"llm/layers/attn/attn_vec_einsum/w (18, 8, 256, 2048) float32\n", + "llm/layers/attn/kv_einsum/w (18, 2, 1, 2048, 256) float32\n", + "llm/layers/attn/q_einsum/w (18, 8, 2048, 256) float32\n", + "llm/layers/mlp/gating_einsum (18, 2, 2048, 16384) float16\n", + "llm/layers/mlp/linear (18, 16384, 2048) float16\n", + "llm/layers/pre_attention_norm/scale (18, 2048) float16\n", + "llm/layers/pre_ffw_norm/scale (18, 2048) float16\n" + ] + } + ], + "source": [ + "# @title Move params to GPU/TPU memory.\n", + "#\n", + "# To keep HBM usage low and fit in a T4 GPU (16GB HBM) we opt to only finetune\n", + "# a part of the parameters. Additionally we keep the frozen params in float16\n", + "# and cast trainable to float32.\n", + "\n", + "# Create a pytree mask of the trainable params.\n", + "def is_trainable_param(name, param): # pylint: disable=unused-argument\n", + " if name.startswith(\"llm/layers/attn/\"): return True\n", + " if name.startswith(\"llm/\"): return False\n", + " if name.startswith(\"img/\"): return False\n", + " raise ValueError(f\"Unexpected param name {name}\")\n", + "trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)\n", + "\n", + "#\n", + "# If more than one device is available (e.g. 
multiple GPUs) the parameters can\n", + "# be sharded across them to reduce HBM usage per device.\n", + "mesh = jax.sharding.Mesh(jax.devices(), (\"data\"))\n", + "\n", + "data_sharding = jax.sharding.NamedSharding(\n", + " mesh, jax.sharding.PartitionSpec(\"data\"))\n", + "\n", + "params_sharding = big_vision.sharding.infer_sharding(\n", + " params, strategy=[('.*', 'fsdp(axis=\"data\")')], mesh=mesh)\n", + "\n", + "# Yes: Some donated buffers are not usable.\n", + "warnings.filterwarnings(\n", + " \"ignore\", message=\"Some donated buffers were not usable\")\n", + "\n", + "@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n", + "def maybe_cast_to_f32(params, trainable):\n", + " return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,\n", + " params, trainable)\n", + "\n", + "# Loading all params in simultaneous - albeit much faster and more succinct -\n", + "# requires more RAM than the T4 colab runtimes have by default (12GB RAM).\n", + "# Instead we do it param by param.\n", + "params, treedef = jax.tree.flatten(params)\n", + "sharding_leaves = jax.tree.leaves(params_sharding)\n", + "trainable_leaves = jax.tree.leaves(trainable_mask)\n", + "for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):\n", + " params[idx] = big_vision.utils.reshard(params[idx], sharding)\n", + " params[idx] = maybe_cast_to_f32(params[idx], trainable)\n", + " params[idx].block_until_ready()\n", + "params = jax.tree.unflatten(treedef, params)\n", + "\n", + "# Print params to show what the model is made of.\n", + "def parameter_overview(params):\n", + " for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:\n", + " print(f\"{path:80s} {str(arr.shape):22s} {arr.dtype}\")\n", + "\n", + "print(\" == Model params == \")\n", + "parameter_overview(params)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "8SRW0NuU4UcW" + }, + "outputs": [], + "source": [ + "# @title Define preprocess 
functions to create inputs to the model.\n", + "\n", + "def preprocess_image(image, size=224):\n", + " # Model has been trained to handle images of different aspects ratios\n", + " # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize\n", + " # options are helpful to improve quality in some tasks.\n", + " image = np.asarray(image)\n", + " if image.ndim == 2: # Convert image without last channel into greyscale.\n", + " image = np.stack((image,)*3, axis=-1)\n", + " image = image[..., :3] # Remove alpha layer.\n", + " assert image.shape[-1] == 3\n", + "\n", + " image = tf.constant(image)\n", + " image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)\n", + " return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1]\n", + "\n", + "def preprocess_tokens(prefix, suffix=None, seqlen=None):\n", + " # Model has been trained to handle tokenized text composed of a prefix with\n", + " # full attention and a suffix with causal attention.\n", + " separator = \"\\n\"\n", + " tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)\n", + " mask_ar = [0] * len(tokens) # 0 to use full attention for prefix.\n", + " mask_loss = [0] * len(tokens) # 0 to not use prefix tokens in the loss.\n", + "\n", + " if suffix:\n", + " suffix = tokenizer.encode(suffix, add_eos=True)\n", + " tokens += suffix\n", + " mask_ar += [1] * len(suffix) # 1 to use causal attention for suffix.\n", + " mask_loss += [1] * len(suffix) # 1 to use suffix tokens in the loss.\n", + "\n", + " mask_input = [1] * len(tokens) # 1 if its a token, 0 if padding.\n", + " if seqlen:\n", + " padding = [0] * max(0, seqlen - len(tokens))\n", + " tokens = tokens[:seqlen] + padding\n", + " mask_ar = mask_ar[:seqlen] + padding\n", + " mask_loss = mask_loss[:seqlen] + padding\n", + " mask_input = mask_input[:seqlen] + padding\n", + "\n", + " return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))\n", + "\n", + "def postprocess_tokens(tokens):\n", + " tokens 
= tokens.tolist() # np.array to list[int]\n", + " try: # Remove tokens at and after EOS if any.\n", + " eos_pos = tokens.index(tokenizer.eos_id())\n", + " tokens = tokens[:eos_pos]\n", + " except ValueError:\n", + " pass\n", + " return tokenizer.decode(tokens)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "whzWOojGOtzi" + }, + "outputs": [], + "source": [ + "# @title Function to iterate over train and validation examples.\n", + "SEQLEN = 128\n", + "\n", + "# TODO: Consider data iterators skipping big_vision and tf.data?\n", + "train_dataset = big_vision.datasets.jsonl.DataSource(\n", + " os.path.join(DATA_DIR, \"data_train90.jsonl\"),\n", + " fopen_keys={\"image\": DATA_DIR})\n", + "\n", + "val_dataset = big_vision.datasets.jsonl.DataSource(\n", + " os.path.join(DATA_DIR, \"data_val10.jsonl\"),\n", + " fopen_keys={\"image\": DATA_DIR})\n", + "\n", + "\n", + "def train_data_iterator():\n", + " \"\"\"Never ending iterator over training examples.\"\"\"\n", + " # Shuffle examples and repeat so one can train for many epochs.\n", + " dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()\n", + " for example in dataset.as_numpy_iterator():\n", + " image = Image.open(io.BytesIO(example[\"image\"]))\n", + " image = preprocess_image(image)\n", + "\n", + " prefix = \"caption en\" # Could also be a different prefix per example.\n", + " suffix = example[\"suffix\"].decode().lower()\n", + " tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)\n", + "\n", + " yield {\n", + " \"image\": np.asarray(image),\n", + " \"text\": np.asarray(tokens),\n", + " \"mask_ar\": np.asarray(mask_ar),\n", + " \"mask_loss\": np.asarray(mask_loss),\n", + " }\n", + "\n", + "\n", + "def validation_data_iterator():\n", + " \"\"\"Single iterator over validation examples.\"\"\"\n", + " for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():\n", + " image = Image.open(io.BytesIO(example[\"image\"]))\n", + " image = 
preprocess_image(image)\n", + "\n", + " prefix = \"caption en\" # Could also be a different prefix per example.\n", + " tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)\n", + "\n", + " yield {\n", + " \"image\": np.asarray(image),\n", + " \"text\": np.asarray(tokens),\n", + " \"mask_ar\": np.asarray(mask_ar),\n", + " \"mask_input\": np.asarray(mask_input),\n", + " }\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 516 + }, + "id": "BzJfb5t0nsLq", + "outputId": "1f6640f7-09b4-41a3-c713-62966b0df7e7" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Training examples\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

a tall crane stands in the heart of the city, casting long shadows across the streets below. the sky is clear and blue, with fluffy white clouds drifting lazily. a tall white building dominates the skyline, its windows reflecting the afternoon sun. a tall black building casts a long shadow on the ground.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a neon sign on a brick wall reads "this is the sign you ve been looking for." the sign is lit up and the letters are white. there are several pillows on the wall, including a pillow with a skull and crossbones. the wall is made of bricks and the sign is on the wall. the sign is neon and the letters are white.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a tool box filled with a variety of tools, including a wrench with a silver head, a screwdriver with a gray handle,a wrench with a gray head, a screwdriver with a gray handle, a metal socket with a silver head...

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a person stands on a sidewalk, their shoes and legs visible. the ground is made of concrete. the person wears black pants, with a white sole on their shoe and a white lace on their shoe. the shoes are black and white. the person's legs are visible. the words "passion led us here" are written on the ground in red. the concrete has a shadow on it, and the sun shines on the ground

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a bowl of steaming instant noodle soup with a spoon resting in the center. the broth is clear and the vegetables, including carrots, peas, and green beans, are floating gently in the liquid. the spoon is long and silver, with a reflection of light on its handle. the overall image is simple and straightforward, with a focus on the deliciousness of the soup.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a kitchen with a stove and a window. the room has a green floor and a white wall. there is a white towel hanging on the oven door. the stove has a black oven door, a black knob on the stove, and a black and white oven. there is a a silver pot on the stove. the window has a yellow frame and there is a white plastic bag hanging on the oven door. the door is open.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a plate of colorful desserts and a cup of coffee. the plate features a variety of sweet treats, including a macaron with a bite taken out, a green macaron, a pink macaron, and a white macaron. the plate is adorned with a white flower on a tree branch and a green leaf on a plant. the coffee cup has a white handle. the table is covered in a white and gold tablecloth.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a car is parked in front of a building with a green umbrella. the building has a green and white sign, a green and white umbrella, and a white sign with black lettering. there is a tall palm tree and a tall tree. the car is parked next to a yellow car and a black car. the road is grey and the sky is white.

\n", + "
\n", + " " + ] + }, + "metadata": {} + } + ], + "source": [ + "# @title Inspect training examples.\n", + "def render_inline(image, resize=(128, 128)):\n", + " \"\"\"Convert image into inline html.\"\"\"\n", + " image = Image.fromarray(image)\n", + " image.resize(resize)\n", + " with io.BytesIO() as buffer:\n", + " image.save(buffer, format='jpeg')\n", + " image_b64 = str(base64.b64encode(buffer.getvalue()), \"utf-8\")\n", + " return f\"data:image/jpeg;base64,{image_b64}\"\n", + "\n", + "def render_example(image, caption):\n", + " image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -> [0, 255]\n", + " return f\"\"\"\n", + "
\n", + " \n", + "

{html.escape(caption)}

\n", + "
\n", + " \"\"\"\n", + "\n", + "html_out = \"\"\n", + "for idx, example in zip(range(8), train_data_iterator()):\n", + " caption = postprocess_tokens(example[\"text\"]) # detokenize model input.\n", + " caption = caption[len(\"caption en\\n\"):] # strip prefix\n", + " html_out += render_example(example[\"image\"], caption)\n", + "\n", + "print(\"Training examples\")\n", + "display(HTML(html_out))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "dwUV_imW3WQJ" + }, + "outputs": [], + "source": [ + "# @title Define the training step and evaluation loop.\n", + "#\n", + "# The main update_fn using simple SGD.\n", + "#\n", + "@functools.partial(jax.jit, donate_argnums=(0,))\n", + "def update_fn(params, batch, learning_rate):\n", + " imgs, txts, mask_ar = batch[\"image\"], batch[\"text\"], batch[\"mask_ar\"]\n", + "\n", + " def loss_fn(params):\n", + " text_logits, _ = model.apply({\"params\": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)\n", + " logp = jax.nn.log_softmax(text_logits, axis=-1)\n", + "\n", + " # The model takes as input txts[:, :-1] but the loss is defined as predicting\n", + " # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens\n", + " # are part of the loss (e.g. prefix and padded tokens are not included).\n", + " mask_loss = batch[\"mask_loss\"][:, 1:]\n", + " targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n", + "\n", + " # Compute the loss per example. i.e. 
the mean of per token pplx.\n", + " # Since each example has a different number of tokens we normalize it.\n", + " token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.\n", + " example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.\n", + " example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.\n", + "\n", + " # batch_loss: mean of per example loss.\n", + " return jnp.mean(example_loss)\n", + "\n", + " loss, grads = jax.value_and_grad(loss_fn)(params)\n", + "\n", + " # Apply gradients to trainable params using SGD.\n", + " def apply_grad(param, gradient, trainable):\n", + " if not trainable: return param\n", + " return param - learning_rate * gradient\n", + "\n", + " params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)\n", + "\n", + " return params, loss\n", + "\n", + "# Evaluation/inference loop.\n", + "def make_predictions(data_iterator, *, num_examples=None,\n", + " batch_size=4, seqlen=SEQLEN, sampler=\"greedy\"):\n", + " outputs = []\n", + " while True:\n", + " # Construct a list of examples in the batch.\n", + " examples = []\n", + " try:\n", + " for _ in range(batch_size):\n", + " examples.append(next(data_iterator))\n", + " examples[-1][\"_mask\"] = np.array(True) # Indicates true example.\n", + " except StopIteration:\n", + " if len(examples) == 0:\n", + " return outputs\n", + "\n", + " # Not enough examples to complete a batch. 
Pad by repeating last example.\n", + " while len(examples) % batch_size:\n", + " examples.append(dict(examples[-1]))\n", + " examples[-1][\"_mask\"] = np.array(False) # Indicates padding example.\n", + "\n", + " # Convert list of examples into a dict of np.arrays and load onto devices.\n", + " batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n", + " batch = big_vision.utils.reshard(batch, data_sharding)\n", + "\n", + " # Make model predictions\n", + " tokens = decode({\"params\": params}, batch=batch,\n", + " max_decode_len=seqlen, sampler=sampler)\n", + "\n", + " # Fetch model predictions to host and detokenize.\n", + " tokens, mask = jax.device_get((tokens, batch[\"_mask\"]))\n", + " tokens = tokens[mask] # remove padding examples.\n", + " responses = [postprocess_tokens(t) for t in tokens]\n", + "\n", + " # Append to html output.\n", + " for example, response in zip(examples, responses):\n", + " outputs.append((example[\"image\"], response))\n", + " if num_examples and len(outputs) >= num_examples:\n", + " return outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "067wj_6bZAG3", + "outputId": "e1aa2df0-502e-4a70-c88d-db98739c01d5" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 1/64 lr: 0.00500 loss: 2.7898\n", + "Model predictions at step 1\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

the beauty of a puff sleeve

\n", + "
\n", + " \n", + "
\n", + " \n", + "

how to wear a maxi dress for summer

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a red blazer and a black bag

\n", + "
\n", + " \n", + "
\n", + " \n", + "

how to wear boyfriend jeans like a fashion blogger

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 2/64 lr: 0.01000 loss: 2.1176\n", + "step: 3/64 lr: 0.01500 loss: 1.7491\n", + "step: 4/64 lr: 0.02000 loss: 1.5594\n", + "step: 5/64 lr: 0.02500 loss: 1.6047\n", + "step: 6/64 lr: 0.03000 loss: 1.3865\n", + "step: 7/64 lr: 0.02998 loss: 1.4946\n", + "step: 8/64 lr: 0.02992 loss: 1.6175\n", + "step: 9/64 lr: 0.02981 loss: 1.3377\n", + "step: 10/64 lr: 0.02966 loss: 1.4888\n", + "step: 11/64 lr: 0.02947 loss: 1.3479\n", + "step: 12/64 lr: 0.02924 loss: 1.3211\n", + "step: 13/64 lr: 0.02897 loss: 1.0806\n", + "step: 14/64 lr: 0.02866 loss: 1.1590\n", + "step: 15/64 lr: 0.02831 loss: 1.1158\n", + "step: 16/64 lr: 0.02792 loss: 1.1702\n", + "Model predictions at step 16\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

a woman wearing a pink blouse with a large puffy sleeve stands on a white wall. the woman's hand rests on the wall, and her fingers are intertwined. the wall is white, and the light is shining on the woman's hand. the sky is clear, and the sun is shining. the woman is wearing a pink blouse, and her hair is in a bun.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman wearing a white floral dress sits on a stone wall overlooking the ocean. the dress is flowing in the wind and the hat is on her head. the sky is clear and the sun is shining. the woman is wearing a hat and a bag.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a person wearing a red blazer and a black belt bag. the person is standing in the woods and the grass is green. the person is wearing a black belt bag and a black belt. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman wearing a pink shirt and jeans stands on a stone staircase. she is holding a pink bag and wearing a bracelet. the stairs are made of stone and the woman is wearing a bracelet.

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 17/64 lr: 0.02750 loss: 1.1610\n", + "step: 18/64 lr: 0.02704 loss: 1.1972\n", + "step: 19/64 lr: 0.02655 loss: 1.1947\n", + "step: 20/64 lr: 0.02602 loss: 1.3566\n", + "step: 21/64 lr: 0.02546 loss: 1.1505\n", + "step: 22/64 lr: 0.02488 loss: 1.1470\n", + "step: 23/64 lr: 0.02426 loss: 0.9735\n", + "step: 24/64 lr: 0.02362 loss: 1.1087\n", + "step: 25/64 lr: 0.02296 loss: 0.9770\n", + "step: 26/64 lr: 0.02227 loss: 1.0618\n", + "step: 27/64 lr: 0.02156 loss: 0.9121\n", + "step: 28/64 lr: 0.02083 loss: 0.9501\n", + "step: 29/64 lr: 0.02009 loss: 0.9369\n", + "step: 30/64 lr: 0.01933 loss: 1.0276\n", + "step: 31/64 lr: 0.01856 loss: 0.9005\n", + "step: 32/64 lr: 0.01778 loss: 0.8751\n", + "Model predictions at step 32\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

a person wearing a pink blouse with a white wall in the background. the blouse has a white collar and a white wall. the person is standing and wearing a white shirt. the wall is white and the person is standing on a white step. the person is wearing a white shirt and a white wall. the person is wearing a pink blouse and a white wall.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman wearing a white floral dress with a brown belt, holding a white bag. the dress has a v-neckline and short sleeves. the woman is standing on a stone wall overlooking the ocean. the sky is clear and blue, with a few white clouds. the water is calm and blue, with a few white boats on the horizon. the woman is wearing a brown belt and a brown bag. the woman's hair is in a ponytail.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a person wearing a red blazer with a black fanny pack around their waist. the blazer has a white button down and a black belt. the person is standing in front of a green plant. the plant has a green leaf and a green stem. the person is wearing a black shirt and black pants. the person is wearing a black belt and a black fanny pack. the person is wearing a red blazer and a red shirt. the person is wearing a black jacket and a black pants. the person is wearing a black jacket and a black pants. the person is wearing a black jacket and a black pants. the person is wearing a

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman wearing a pink shirt and blue jeans, holding a pink bag. the woman is standing on a stone staircase, wearing a white cardigan and a gold bracelet. the bag is pink and has a black strap. the woman is wearing a gold bracelet and a gold bracelet on her wrist. the woman is standing on a stone staircase.

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 33/64 lr: 0.01699 loss: 0.8570\n", + "step: 34/64 lr: 0.01620 loss: 1.0866\n", + "step: 35/64 lr: 0.01540 loss: 0.7469\n", + "step: 36/64 lr: 0.01460 loss: 0.8605\n", + "step: 37/64 lr: 0.01380 loss: 0.6830\n", + "step: 38/64 lr: 0.01301 loss: 0.7631\n", + "step: 39/64 lr: 0.01222 loss: 0.7814\n", + "step: 40/64 lr: 0.01144 loss: 0.8579\n", + "step: 41/64 lr: 0.01067 loss: 0.7825\n", + "step: 42/64 lr: 0.00991 loss: 0.6906\n", + "step: 43/64 lr: 0.00917 loss: 0.7922\n", + "step: 44/64 lr: 0.00844 loss: 0.7030\n", + "step: 45/64 lr: 0.00773 loss: 0.7501\n", + "step: 46/64 lr: 0.00704 loss: 0.6384\n", + "step: 47/64 lr: 0.00638 loss: 0.6309\n", + "step: 48/64 lr: 0.00574 loss: 0.6149\n", + "Model predictions at step 48\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

a person wearing a pink blouse with a puffy sleeve. the blouse has a white wall in the background. the person is standing on a white step, and the sun is shining on the wall. the person is wearing a bracelet on their wrist, and their hand is on the step. the person is wearing a watch on their wrist, and their nails are painted red.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman wears a white floral dress with a brown belt. the dress has a v-neckline and short sleeves. the woman is standing on a stone wall overlooking the ocean. the sky is clear and blue, and the boats are visible on the horizon. the woman is holding a white wicker bag. the woman's hair is tied back.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a person wearing a red blazer with a black fanny pack around their waist. the blazer has a white button down and a black belt. the person is standing and wearing a black shirt. the fanny pack is black and has a white loading on it. the person is wearing black pants and black shoes. the person is standing and the grass is green.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman wears a pink shirt and blue jeans. she has a pink bag on her shoulder and a bracelet on her wrist. the stairs are made of stone and the woman is standing on the stairs.

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 49/64 lr: 0.00512 loss: 0.7896\n", + "step: 50/64 lr: 0.00454 loss: 0.6380\n", + "step: 51/64 lr: 0.00398 loss: 0.6263\n", + "step: 52/64 lr: 0.00345 loss: 0.6160\n", + "step: 53/64 lr: 0.00296 loss: 0.6626\n", + "step: 54/64 lr: 0.00250 loss: 0.5598\n", + "step: 55/64 lr: 0.00208 loss: 0.5567\n", + "step: 56/64 lr: 0.00169 loss: 0.7069\n", + "step: 57/64 lr: 0.00134 loss: 0.5293\n", + "step: 58/64 lr: 0.00103 loss: 0.5725\n", + "step: 59/64 lr: 0.00076 loss: 0.5477\n", + "step: 60/64 lr: 0.00053 loss: 0.6153\n", + "step: 61/64 lr: 0.00034 loss: 0.5260\n", + "step: 62/64 lr: 0.00019 loss: 0.5673\n", + "step: 63/64 lr: 0.00008 loss: 0.5721\n", + "step: 64/64 lr: 0.00002 loss: 0.6681\n", + "Model predictions at step 64\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

a person wearing a pink blouse with a puffy sleeve. the blouse has a white wall in the background. the person is standing and their hand is on the wall. the wall has a white line on it. the person is wearing a pink blouse with a puffy sleeve.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman wearing a white floral dress stands on a pier overlooking the ocean. she is wearing a brown bag and a brown hat. the dress has a v-neckline and a tie belt. the woman has her hand in her pocket and her leg is visible. the sky is clear and blue.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a person wearing a red blazer with a black fanny pack around their waist. the blazer has a white button down and a black belt. the person is standing and wearing a black shirt. the fanny pack is black and has a white loading on it. the person is wearing black pants and black shoes. the tree behind them is green.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman wears a pink shirt and blue jeans. she has a pink bag on her shoulder and a bracelet on her wrist. the stairs are made of gray stone. the woman is standing on the stairs.

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "CPU times: user 12min 34s, sys: 6.31 s, total: 12min 40s\n", + "Wall time: 13min 5s\n" + ] + } + ], + "source": [ + "# @title Run training loop.\n", + "#\n", + "# Run a short training loop with cosine learning rate schedule.\n", + "#\n", + "# Note: the first step can be quite slow on some machines (up to several minutes)\n", + "# due to XLA compilation of the jax.jit'd function.\n", + "#\n", + "%%time\n", + "\n", + "BATCH_SIZE = 8\n", + "TRAIN_EXAMPLES = 512\n", + "LEARNING_RATE = 0.03\n", + "\n", + "TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE\n", + "EVAL_STEPS = TRAIN_STEPS // 4\n", + "\n", + "train_data_it = train_data_iterator()\n", + "\n", + "sched_fn = big_vision.utils.create_learning_rate_schedule(\n", + " total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,\n", + " decay_type=\"cosine\", warmup_percent=0.10)\n", + "\n", + "for step in range(1, TRAIN_STEPS+1):\n", + " # Make list of N training examples.\n", + " examples = [next(train_data_it) for _ in range(BATCH_SIZE)]\n", + "\n", + " # Convert list of examples into a dict of np.arrays and load onto devices.\n", + " batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n", + " batch = big_vision.utils.reshard(batch, data_sharding)\n", + "\n", + " # Training step and report training loss\n", + " learning_rate = sched_fn(step)\n", + " params, loss = update_fn(params, batch, learning_rate)\n", + "\n", + " loss = jax.device_get(loss)\n", + " print(f\"step: {step:2d}/{TRAIN_STEPS:2d} lr: {learning_rate:.5f} loss: {loss:.4f}\")\n", + "\n", + " if step == 1 or (step % EVAL_STEPS) == 0:\n", + " print(f\"Model predictions at step {step}\")\n", + " html_out = \"\"\n", + " for image, caption in make_predictions(\n", + " validation_data_iterator(), num_examples=4, batch_size=4):\n", + " html_out += render_example(image, caption)\n", + " display(HTML(html_out))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 
12, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 699 + }, + "id": "hgUhEKjzPdMQ", + "outputId": "63037cd6-151c-4802-9de8-be2cb7818d12" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model predictions\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

a person wearing a pink blouse with a puffy sleeve. the blouse has a white wall in the background. the person is standing and their hand is on the wall. the wall has a white line on it. the person is wearing a pink blouse with a puffy sleeve.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman wearing a white floral dress stands on a pier overlooking the ocean. she is wearing a brown bag and a brown hat. the dress has a v-neckline and a tie belt. the woman has her hand in her pocket and her leg is visible. the sky is clear and blue.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a person wearing a red blazer with a black fanny pack around their waist. the blazer has a white button down and a black belt. the person is standing and wearing a black shirt. the fanny pack is black and has a white loading on it. the person is wearing black pants and black shoes. the tree behind them is green.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman wears a pink shirt and blue jeans. she has a pink bag on her shoulder and a bracelet on her wrist. the stairs are made of gray stone. the woman is standing on the stairs.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a pink sweatshirt with a red slogan lies on a bed next to a pair of jeans and a pair of white sneakers. the sweatshirt features long sleeves and a crew neck. the text on the sweatshirt reads "love well, save us." the sneakers are white and have white laces. the hand on the sweatshirt is gentle.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a man with long blonde hair covers his face with his hand. he wears a navy sweater and a black and white checkered shirt. the sweater has long sleeves and a collar. the man has a beard and mustache. the hair on his face is long and wavy. the man is standing and his hands are on his head. the background is pink.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a white metal rack with a white metal pole and a white metal pole. the rack has a white metal pole and a white metal pole. the rack has a white metal pole and a white metal pole.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a white hoodie hangs on a black coat rack, with a white drawstring on the left side of the hoodie. the coat rack is black, and the wall is white. there is a black circle on the wall, and a black circle on the wall. the hoodie has a white drawstring on the left side of the hoodie.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman wears a pair of blue jeans with a black bag. the jeans have a gray wall behind them. the woman is wearing a black bag with a gold chain. the bag has a silver chain and a silver lock. the woman is wearing black boots with black socks. the bag has a silver chain and a silver lock. the woman is wearing a black bracelet on her wrist. the woman is standing and the bag is on her shoulder.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a man stands on a sidewalk wearing a blue denim jacket with a white t-shirt, brown pants, and white shoes. he has his hands in his pockets and his legs are stretched out. the jacket has a blue collar and a white t-shirt. the pants have a brown stripe on the side and a white shoe. the man is wearing a white t-shirt and brown pants.

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "CPU times: user 26.2 s, sys: 112 ms, total: 26.3 s\n", + "Wall time: 32.9 s\n" + ] + } + ], + "source": [ + "# @title Evaluate the model on all examples.\n", + "#\n", + "# The validation data consists of 10 images in a different domain than training\n", + "# data.\n", + "%%time\n", + "\n", + "print(\"Model predictions\")\n", + "html_out = \"\"\n", + "for image, caption in make_predictions(validation_data_iterator(), batch_size=4):\n", + " html_out += render_example(image, caption)\n", + "display(HTML(html_out))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ai0NMbAwsr0j" + }, + "source": [ + "# Save the final checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5H_3CV33_JkV" + }, + "outputs": [], + "source": [ + "def npsave(pytree, path):\n", + " names_and_vals, _ = big_vision.utils.tree_flatten_with_names(pytree)\n", + " with open(path, \"wb\") as f:\n", + " np.savez(f, **{k:v for k, v in names_and_vals})\n", + "\n", + "# Takes around 4 minutes\n", + "npsave(params, 'my-custom-paligemma-ckpt.npz')" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/big_vision/configs/proj/paligemma/paligemma.png b/big_vision/configs/proj/paligemma/paligemma.png new file mode 100644 index 0000000000000000000000000000000000000000..949d4af23958646ff18899028c906022e9aa1357 Binary files /dev/null and b/big_vision/configs/proj/paligemma/paligemma.png differ diff --git a/big_vision/configs/proj/paligemma/transfers/activitynet_cap.py b/big_vision/configs/proj/paligemma/transfers/activitynet_cap.py new file mode 100644 index 
0000000000000000000000000000000000000000..2176595c357eeab9c2a4f57f819513409e97be23 --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/activitynet_cap.py @@ -0,0 +1,209 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to ActivityNet Video captioning. + +IMPORTANT: This config is based on an unreleased version of DeepMind Video +Readers (DMVR). Users can either set up DMVR using the open source code from +GitHub (see below for details), or add their own data loader of choice. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + +TEXT_LEN = 64 +DATASET_NAME = 'activitynet_captions_mr' +# Numbers might need to be updated due to wipeout. Current from 2024-04-28 +SPLIT_SIZE = {'train': 30545, 'valid': 14338, 'test': 13982} + + +def training_data(res, *, final_split, num_frames=8, stride=None): + """Creates training data config. + + Args: + res: The requested image resolution (eg 224). + final_split: Train on all train+valid data. + num_frames: number of sampled frames per video. + stride: stride at which the frames are sampled. + + Returns: + The ConfigDict for the input section. 
+ """ + pp = '|'.join([ + # prepare the frames by decoding, resizing, replicating, sampling: + f'video_decode({res})|video_replicate_img({num_frames},{num_frames})', + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + # pick one caption at random during training (there is actually just one!) + 'strfmt("caption en", outkey="prefix")', + 'video_choice(inkey="caption/string", outkey="suffix")', + combine_and_keep_train(TEXT_LEN), + ]) + + c = bvcc.parse_arg('') + c.data = {} + splits = ['train', 'valid'] if final_split else ['train'] + raise NotImplementedError('Please implement a video reader of choice!') + # For example DMVR https://github.com/google-deepmind/dmvr + # The reader should support the following arguments: + # - name: Name of the reader. + # - dataset_name: Name of the data set. + # - split: Data set split. + # - num_frames: Number of frames sampled from the video. + # - stride: Stride at which the video frames are sampled. + # - deterministic_fs: Whether to sample the frames starting at the first + # frame or whether an offset should be chosen at random (if there are more + # frames than num_frames * stride) + # - first_k_shards: Whether to only use the first k shards of the data + # (optional but useful for speeding up intermediate evaluations). 
+ for split in splits: + c.data[split] = SPLIT_SIZE[split] + c[split] = {'pp': pp} + c[split].data = dict( + # PLEASE ADD YOUR READER HERE: + name='', + dataset_name=DATASET_NAME, split=split, + num_frames=num_frames, stride=stride, + deterministic_fs=False) + return c + + +def add_eval(c, res, num_frames=8, stride=None): # pylint: disable=unused-argument + """Captioning evaluator.""" + c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride) + + pp = '|'.join([ + f'video_decode({res})|video_replicate_img({num_frames},{num_frames})', + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + 'strfmt("caption en", outkey="prefix")', + 'strfmt("{example/video_id}[{segment_start}-{segment_end}]", outkey="image/id")', + 'copy("caption/string", "captions")', + combine_and_keep_eval(TEXT_LEN, keep=('image/id', 'captions')), + ]) + + for freq, name, split, first_k_shards, skip_first_eval in [ + (1/8, 'minitrain', 'train', 2, False), # To gauge memorization. + (1/4, 'minival', 'valid', 2, False), # To monitor val progress. + (1, 'val', 'valid', None, False), # To tune hparams. + (1, 'eval', 'test', None, False), # final metric + ]: + c.evals[f'{DATASET_NAME}/{name}'] = dict( + type='proj.paligemma.transfers.coco_caption', + pred='decode', pred_kw={'max_decode_len': TEXT_LEN}, + data={**c_train.train.data, 'split': split, + 'first_k_shards': first_k_shards, + 'deterministic_fs': True}, + log_percent=freq, tokenizer=TOKENIZER, + pp_fn=pp, skip_first=skip_first_eval) + + +def add_eval_pplx(c, res, num_frames=8, stride=None): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride) + + for name, split, first_k_shards in [ + ('minitrain', 'train', 2), # To gauge memorization. 
+ ]: + c.evals[f'{DATASET_NAME}/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=1/8, # Not too cheap, do 10x per run. + data={**c_train.train.data, 'split': split, + 'first_k_shards': first_k_shards, + 'deterministic_fs': True}, + pp_fn=c_train.train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + add(lr=1e-5, wd=1e-6, total_epochs=1, **bvcc.arg(freeze_vit=True, res=224, **c)) + + +sweep = sweep_best + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', num_frames=16, stride=30, res=224, + freeze_vit=False, freeze_llm=False, final_split=False) + + c.input = training_data( + c.res, final_split=c.final_split, + num_frames=c.num_frames, stride=c.stride) + + c.total_epochs = 3 + c.input.batch_size = 128 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 3e-6 + c.wd = 3e-7 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + + # Learning-rate schedule. + sched = dict(decay_type='cosine', warmup_percent=0.05) + c.schedule = [ + ('img/.*', None if c.freeze_vit else sched), + ('llm/.*', None if c.freeze_llm else sched), + ] + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, c.num_frames, c.stride) + add_eval_pplx(c, c.res, c.num_frames, c.stride) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. 
+ c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + for split in c.input.data.keys(): + c.input[split].shuffle_buffer_size = 10_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops', + 'proj.paligemma.video'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + for split in c.input.data.keys(): + c.input[split].shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.first_k_shards = 1 + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('minitrain', 'minival', 'val', 'eval'): + m.append(('epoch', f'{DATASET_NAME}/{split}/cider')) + for split in ('minitrain',): # add_eval_pplx only registers the minitrain evaluator. + m.append(('epoch', f'{DATASET_NAME}/{split}/pplx/avg')) + return m + diff --git a/big_vision/configs/proj/paligemma/transfers/activitynet_qa.py b/big_vision/configs/proj/paligemma/transfers/activitynet_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..556ce79d3dc6ee855e0e1c33d1c94605972f3a1e --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/activitynet_qa.py @@ -0,0 +1,213 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def training_data(res, *, final_split, num_frames, stride):
  """Builds the input config for ActivityNet-QA training.

  Args:
    res: The requested image resolution (eg 224).
    final_split: If true, train on train+valid; otherwise on train only.
    num_frames: Number of sampled frames per video.
    stride: Stride at which the frames are sampled.

  Returns:
    The ConfigDict for the input section (once a video reader is plugged in).

  Raises:
    NotImplementedError: Always, until a video reader has been added below.
  """
  frame_shape = (num_frames, res, res, 3)
  pp_ops = [
      # Decode and resize the frames, then pad/replicate to `num_frames`.
      f'video_decode({res})|video_replicate_img({num_frames},{num_frames})',
      f'video_ensure_shape("image", {frame_shape})',
      # Each example carries exactly one question/answer pair.
      'reshape([], key="question")|reshape([], key="answer")',
      'strfmt("answer en {question}", outkey="prefix")',
      'copy("answer", "suffix")',
      combine_and_keep_train(TEXT_LEN),
  ]
  pp = '|'.join(pp_ops)

  c = bvcc.parse_arg('')
  c.data = {}
  splits = ['train'] + (['valid'] if final_split else [])
  raise NotImplementedError('Please implement a video reader of choice!')
  # For example DMVR https://github.com/google-deepmind/dmvr
  # The reader should support the following arguments:
  # - name: Name of the reader.
  # - dataset_name: Name of the data set.
  # - split: Data set split.
  # - num_frames: Number of frames sampled from the video.
  # - stride: Stride at which the video frames are sampled.
  # - deterministic_fs: Whether to sample the frames starting at the first
  #   frame or whether an offset should be chosen at random (if there are more
  #   frames than num_frames * stride).
  # - first_k_shards: Whether to only use the first k shards of the data
  #   (optional but useful for speeding up intermediate evaluations).
  for split in splits:
    c.data[split] = SPLIT_SIZE[split]
    c[split] = {'pp': pp}
    # PLEASE ADD YOUR READER HERE (the `name` field):
    reader = dict(
        name='',
        dataset_name=DATASET_NAME,
        split=split,
        num_frames=num_frames,
        stride=stride,
        deterministic_fs=False,
    )
    c[split].data = reader
  return c
def add_eval_pplx(c, res, num_frames, stride):
  """Adds perplexity evaluators on small train/valid subsets to `c.evals`.

  Args:
    c: The config whose `evals` mapping receives the new entries.
    res: Image resolution forwarded to `training_data`.
    num_frames: Number of frames per video forwarded to `training_data`.
    stride: Frame sampling stride forwarded to `training_data`.
  """
  train_cfg = training_data(
      res, final_split=True, num_frames=num_frames, stride=stride).train

  monitored = (
      ('minitrain', 'train', 2),  # To gauge memorization.
      ('minival', 'valid', 2),
  )
  for name, split, first_k_shards in monitored:
    # Start from the training reader settings and override what differs.
    overrides = {
        'split': split,
        'first_k_shards': first_k_shards,
        'deterministic_fs': True,
    }
    c.evals[f'activitynet_qa/{name}/pplx'] = dict(
        type='proj.paligemma.perplexity',
        pred='logits',
        key='text',
        shift_labels=True,
        log_percent=1/8,
        data={**train_cfg.data, **overrides},
        pp_fn=train_cfg.pp,
    )
def metrics(arg=None):  # pylint: disable=unused-argument
  """Lists the metric names associated with this config.

  Args:
    arg: Unused; accepted so the signature matches the other config hooks.

  Returns:
    A list starting with 'training_loss', followed by ('epoch', name) tuples
    for the accuracy of every evaluator split and for the average perplexity
    of the two small monitoring splits.
  """
  acc = [('epoch', f'{DATASET_NAME}/{split}/acc')
         for split in ('minitrain', 'minival', 'val', 'eval')]
  pplx = [('epoch', f'{DATASET_NAME}/{split}/pplx/avg')
          for split in ('minitrain', 'minival')]
  return ['training_loss', *acc, *pplx]
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to AI2D. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + +PREFIX = 'answer en ' +PROMPT = 'choose from:' +PROMPT_SEP = ' \t ' + + +def training_data(res, final_split, text_len=128): + """Creates training data config. + + See (internal link) + You can add more arguments beside `res`, but give them good defaults. + + Args: + res: The requested image resolution (eg 224). + final_split: whether to use all train data. + text_len: sequence length + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='ai2d', + # 12k training examples. 
def add_eval(c, res, text_len=128, **kw):
  """Adds the AI2D decoding evaluators to `c.evals`.

  Args:
    c: The config whose `evals` mapping receives the new entries.
    res: Image resolution used in the eval preprocessing.
    text_len: Sequence length; also used as the maximum decode length.
    **kw: Extra evaluator settings, merged over the defaults of every entry.
  """
  pp_ops = [
      f'decode|resize({res})|value_range(-1, 1)',
      f'strjoin("{PROMPT_SEP}", inkey="possible_answers", outkey="ansstr")',
      f'strfmt("{PREFIX} {{question}} {PROMPT} {{ansstr}}", outkey="prefix")',
      'copy(inkey="id",outkey="question_id")',
      combine_and_keep_eval(text_len, keep=('answer', 'question_id')),
  ]
  pp = '|'.join(pp_ops)

  splits = {
      'minitrain': 'train[:1024]',  # To gauge memorization.
      'minival': 'train[-1024:]',  # To tune hparams.
      'eval': 'test',  # To compute final publishable scores.
  }
  for name, split in splits.items():
    key = f'ai2d/{name}'
    data = {**training_data(res, True, text_len).data, 'split': split}
    c.evals[key] = dict(
        type='proj.paligemma.transfers.vqa',
        pred='decode',
        pred_kw={'max_decode_len': text_len},
        outfile=f'{{workdir}}/ai2d_{name}.json',
        # Model sees options in prompt and can match the case.
        to_lower=False,
        data=data,
        log_percent=1/8,
        tokenizer=TOKENIZER,
        pp_fn=pp,
    )
    c.evals[key].update(kw)
def sweep_best(add, arg=None):
  """Schedules one run per resolution with the best known hyper-params.

  Args:
    add: Callback invoked once per run with that run's settings as kwargs.
    arg: Optional config argument string, parsed with `bvcc.parse_arg`.
  """
  c = bvcc.parse_arg(arg, final_split=False)
  # 896 was not better than 448, so only these two resolutions are swept.
  for res in (224, 448):
    add(lr=1e-5, wd=1e-6, total_epochs=10, **bvcc.arg(res=res, **c))
def metrics(arg=None):  # pylint: disable=unused-argument
  """Lists the metric names associated with the AI2D config.

  Args:
    arg: Unused; accepted so the signature matches the other config hooks.

  Returns:
    A list starting with 'training_loss', followed by the average perplexity
    and accuracy metric names for the eval, minival and minitrain splits.
  """
  names = ['training_loss']
  for split in ('eval', 'minival', 'minitrain'):
    names.extend((f'ai2d/{split}/pplx/avg', f'ai2d/{split}/acc'))
  return names
def add_eval(c, res, text_len=32, **kw):
  """Adds the A-OK-VQA direct-answer decoding evaluators to `c.evals`.

  The normal VQA evaluator is reused; one entry is registered per split under
  the key `aokvqa_da/<name>`.

  Args:
    c: The config whose `evals` mapping receives the new entries.
    res: Image resolution used in the eval preprocessing.
    text_len: Sequence length; also used as the maximum decode length.
    **kw: Extra evaluator settings (e.g. `batch_size`), merged over the
      defaults of every entry.
  """
  pp = '|'.join([
      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
      'strfmt("answer en {question}", outkey="prefix")',
      'copy(inkey="direct_answers", outkey="answers")',
      combine_and_keep_eval(text_len, keep=('answers', 'question_id')),
  ])

  for freq, name, split in [
      (1/4, 'minitrain', 'train[:5%]'),  # To gauge memorization.
      (1/4, 'eval', 'val'),  # To tune hparams.
      (1.0, 'test', 'test'),  # To compute final predictions.
  ]:
    key = f'aokvqa_da/{name}'
    c.evals[key] = dict(
        type='proj.paligemma.transfers.vqa',
        pred='decode', pred_kw={'max_decode_len': text_len},
        outfile=f'{{workdir}}/aokvqa_da_{name}.json',
        data={**training_data(res, True, text_len).data, 'split': split},
        log_percent=freq, skip_first=freq == 1, tokenizer=TOKENIZER, pp_fn=pp)
    # Apply the overrides to the entry that was just registered. Looking it up
    # under any other prefix would reference an evaluator that does not exist.
    c.evals[key].update(kw)
def sweep_best(add, arg=None):
  """Schedules one run per resolution with the best known hyper-params.

  Args:
    add: Callback invoked once per run with that run's settings as kwargs.
    arg: Optional config argument string, parsed with `bvcc.parse_arg`.
  """
  c = bvcc.parse_arg(arg, final_split=False)
  # res=896 was tried and was not better, so it is left out of the sweep.
  for res in (224, 448):
    add(lr=5e-6, wd=0.0, **bvcc.arg(res=res, **c))
def metrics(arg=None):  # pylint: disable=unused-argument
  """Lists the metric names associated with the A-OK-VQA direct-answer config.

  The names use the `aokvqa_da/<name>` prefix under which this config
  registers its evaluators, and cover the two splits ('eval' and 'minitrain')
  that have both a decoding and a perplexity evaluator.

  Args:
    arg: Unused; accepted so the signature matches the other config hooks.

  Returns:
    A list starting with 'training_loss', followed by the average perplexity
    and accuracy metric names for each covered split.
  """
  m = ['training_loss']
  for split in ('eval', 'minitrain'):
    m.append(f'aokvqa_da/{split}/pplx/avg')
    m.append(f'aokvqa_da/{split}/acc')
  return m
def add_eval(c, res, text_len=128, **kw):
  """Adds the A-OK-VQA multiple-choice decoding evaluators to `c.evals`.

  The prompt lists the possible answers and the expected answer is the option
  at `multiple_choice_correct_idx`; one entry is registered per split under
  the key `aokvqa_mc/<name>`.

  Args:
    c: The config whose `evals` mapping receives the new entries.
    res: Image resolution used in the eval preprocessing.
    text_len: Sequence length; also used as the maximum decode length.
    **kw: Extra evaluator settings (e.g. `batch_size`), merged over the
      defaults of every entry.
  """
  pp = '|'.join([
      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
      f'strjoin("{PROMPT_SEP}", inkey="multiple_choice_possible_answers", outkey="ansstr")',
      f'strfmt("{PREFIX} {{question}} {PROMPT} {{ansstr}}", outkey="prefix")',
      'getidx(inkey="multiple_choice_possible_answers", index_key="multiple_choice_correct_idx", outkey="answer")',
      combine_and_keep_eval(text_len, keep=('answer', 'question_id')),
  ])

  for freq, name, split in [
      (1/4, 'minitrain', 'train[:5%]'),  # To gauge memorization.
      (1/4, 'eval', 'val'),  # To tune hparams.
      (1.0, 'test', 'test'),  # To compute final predictions.
  ]:
    key = f'aokvqa_mc/{name}'
    c.evals[key] = dict(
        type='proj.paligemma.transfers.vqa',
        pred='decode', pred_kw={'max_decode_len': text_len},
        outfile=f'{{workdir}}/aokvqa_mc_{name}.json',
        data={**training_data(res, True, text_len).data, 'split': split},
        log_percent=freq, skip_first=freq == 1, tokenizer=TOKENIZER, pp_fn=pp)
    # Apply the overrides to the entry that was just registered. Looking it up
    # under any other prefix would reference an evaluator that does not exist.
    c.evals[key].update(kw)
def metrics(arg=None):  # pylint: disable=unused-argument
  """Lists the metric names associated with the A-OK-VQA multiple-choice config.

  The names use the `aokvqa_mc/<name>` prefix under which this config
  registers its evaluators, and cover the 'eval' and 'minitrain' splits.

  Args:
    arg: Unused; accepted so the signature matches the other config hooks.

  Returns:
    A list starting with 'training_loss', followed by the average perplexity
    and accuracy metric names for each covered split.
  """
  m = ['training_loss']
  for split in ('eval', 'minitrain'):
    m.append(f'aokvqa_mc/{split}/pplx/avg')
    m.append(f'aokvqa_mc/{split}/acc')
  return m
def training_data(res, *, final_split=False, text_len=48):
  """Builds the input config for the weighted ChartQA training mixture.

  Args:
    res: The requested image resolution (eg 224).
    final_split: Train on all train+val data.
    text_len: sequence length.

  Returns:
    The ConfigDict for the input section, with one sub-config per dataset in
    `_DATASETS` and the mixture weights stored under `data`.
  """
  split = 'train+val' if final_split else 'train'
  pp_ops = [
      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
      'copy(inkey="question", outkey="prefix")',
      'copy(inkey="answer", outkey="suffix")',
      combine_and_keep_train(text_len),
  ]
  pp = '|'.join(pp_ops)

  c = bvcc.parse_arg('')  # Just make a configdict without extra import.
  # Mixture weights: dataset name -> dataset size.
  c.data = dict(zip(_DATASETS, _WEIGHTS))
  for ds in c.data:
    c[ds] = dict(
        shuffle_buffer_size=50_000,
        pp=pp,
        data=dict(name=ds, split=split),
    )
  return c
def add_eval_pplx(c, res, text_len=48):
  """Adds one perplexity evaluator per (split, dataset) pair to `c.evals`.

  Args:
    c: The config whose `evals` mapping receives the new entries.
    res: Image resolution forwarded to `training_data`.
    text_len: Sequence length forwarded to `training_data`.
  """
  # Use mostly the same settings as training.
  c_train = training_data(res, final_split=True, text_len=text_len)

  splits = (
      ('minitrain', 'train[:5%]'),  # To gauge memorization.
      ('minival', 'val'),  # To tune hparams.
      ('eval', 'test'),  # To compute final publishable scores.
  )
  for name, split in splits:
    for ds in _DATASETS:
      ds_cfg = c_train[ds]
      c.evals[f'{ds}/{name}/pplx'] = dict(
          type='proj.paligemma.perplexity',
          pred='logits',
          key='text',
          shift_labels=True,
          log_percent=0.05,  # Eval ~20x per run; it's cheap.
          data={**ds_cfg.data, 'split': split},
          pp_fn=ds_cfg.pp,
      )
def metrics(arg=None):  # pylint: disable=unused-argument
  """Lists the metric names associated with the ChartQA config.

  Args:
    arg: Unused; accepted so the signature matches the other config hooks.

  Returns:
    A list starting with 'training_loss', followed by the relaxed accuracy
    and average perplexity metric names of every dataset in `_DATASETS`,
    grouped by split.
  """
  names = ['training_loss']
  for split in ('eval', 'minival', 'minitrain'):
    for ds in _DATASETS:
      names.extend((f'{ds}/{split}/relaxed_acc', f'{ds}/{split}/pplx/avg'))
  return names
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to COCO-35L captions. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + +LANGUAGES = ( + 'ar', 'bn', 'cs', 'da', 'de', 'el', 'en', 'es', 'fa', 'fi', 'fil', 'fr', + 'he', 'hi', 'hr', 'hu', 'id', 'it', 'ja', 'ko', 'mi', 'nl', 'no', 'pl', + 'pt', 'ro', 'ru', 'sv', 'sw', 'te', 'th', 'tr', 'uk', 'vi', 'zh', +) + +LANGUAGES_XM3600 = ( + 'ar', 'bn', 'cs', 'da', 'de', 'el', 'en', 'es', 'fa', 'fi', 'fil', 'fr', + 'he', 'hi', 'hr', 'hu', 'id', 'it', 'ja', 'ko', 'mi', 'nl', 'no', 'pl', + 'pt', 'quz', 'ro', 'ru', 'sv', 'sw', 'te', 'th', 'tr', 'uk', 'vi', 'zh' +) + +# A subset for more frequent evals. +LANGUAGES_SUBSET = ('ar', 'bn', 'en', 'id', 'sw', 'tr', 'zh') + + +def training_data(res, lang=None, text_len=32, crop='rs'): + """Creates training data config. + + See (internal link) + You can add more arguments beside `res`, but give them good defaults. + + Args: + res: The requested image resolution (eg 224) + lang: language code + text_len: sequence length + crop: one of {'ic', 'rc', 'rs'} + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. 
def add_eval(c, res, text_len=32, langs=None, **kw):
  """Adds one COCO-35L captioning evaluator per language to `c.evals`.

  Args:
    c: The config whose `evals` mapping receives the new entries.
    res: Image resolution used in the eval preprocessing.
    text_len: Sequence length; also used as the maximum decode length.
    langs: Language codes to evaluate; falls back to `LANGUAGES` when empty.
    **kw: Extra evaluator settings, merged over the defaults of every entry.
  """
  for lang in (langs or LANGUAGES):
    # Frequent evals on a subset of representative languages, final eval on
    # all of them.
    frequent = lang in LANGUAGES_SUBSET
    key = f'coco35l/{lang}'
    c.evals[key] = dict(
        type='proj.paligemma.transfers.coco_caption',
        pred='decode',
        pred_kw={'max_decode_len': text_len},
        log_percent=0.25 if frequent else 1.0,
        skip_first=not frequent,
        tokenizer=TOKENIZER,
        data=dict(name='coco35l', split=f'dev_{lang}'),
        cache='none',
        pp_fn=_get_eval_pp(res, lang, text_len),
    )
    c.evals[key].update(kw)
+ freq = 0.25 if lang in LANGUAGES_SUBSET else 1.0 + + c.evals[f'xm3600/{lang}'] = dict( + type='proj.paligemma.transfers.coco_caption', + pred='decode', pred_kw={'max_decode_len': text_len}, + log_percent=freq, skip_first=(freq == 1.0), tokenizer=TOKENIZER, + data=dict( + name='xm3600', + split=lang, + ), + pp_fn=_get_eval_pp(res, lang, text_len) + ) + c.evals[f'xm3600/{lang}'].update(kw) + + +def add_eval_pplx(c, res, text_len=32): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, text_len=text_len) # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train_en[:2%]'), + ('minival', 'dev_en[:5%]'), + ('eval', 'dev_en'), + ]: + c.evals[f'coco35l/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.05, # Eval ~20x per run; it's cheap. + data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', crop='rs', res=224, eval_xm3600=True, beam_size=0) + + c.input = { + lang: training_data(c.res, lang=lang, crop=c.crop) + for lang in LANGUAGES + } + c.input.data = {lang: 1 for lang in LANGUAGES} + for k in c.input.data: + c.input[k].shuffle_buffer_size = 10_000 + + c.total_examples = 566_435 # We need to go a looot longer here. + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 1e-4 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval_pplx(c, c.res) + + if c.beam_size: + decode_kw = {'pred': 'beam_decode', 'pred_kw': {'beam_size': c.beam_size}} + else: + decode_kw = {} + + add_eval(c, c.res, batch_size=1024, **decode_kw) + if c.eval_xm3600: + add_eval_xm(c, c.res, batch_size=1024, **decode_kw) + + # Model section. 
+ c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + # c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def sweep_hyper(add): + """sweep over hyper-params.""" + for lr in (1e-5, 3e-6, 1e-6): + for wd in (0.0, 0.1*lr): + for ep in (1, 3, 5, 10, 20): + # One language COCO is 566_435 examples (5 captions, 100k examples). + add(lr=lr, wd=wd, total_examples=ep * 566_435, **bvcc.arg(res=224)) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, eval_xm3600=True) + ep = 566_435 + add(lr=1e-5, wd=1e-6, total_examples=5 * ep, **bvcc.arg(res=224, **c)) + add(lr=1e-5, wd=1e-6, total_examples=5 * ep, **bvcc.arg(res=448, **c)) + + +sweep = sweep_best # Choose which sweep to run. 
+
+
+def metrics(arg=None):  # pylint: disable=unused-argument
+  """Returns ('epoch', name) pairs for the per-language CIDEr metrics."""
+  c = bvcc.parse_arg(arg, eval_xm3600=True)
+  m = [('epoch', f'coco35l/{lang}/cider') for lang in LANGUAGES]
+  if c.eval_xm3600:  # add_eval_xm registers LANGUAGES_XM3600 (incl. 'quz').
+    m += [('epoch', f'xm3600/{lang}/cider') for lang in LANGUAGES_XM3600]
+  return m
diff --git a/big_vision/configs/proj/paligemma/transfers/cococap.py b/big_vision/configs/proj/paligemma/transfers/cococap.py
new file mode 100644
index 0000000000000000000000000000000000000000..95202bb3105665932ff554a7d6300e7b71c626b8
--- /dev/null
+++ b/big_vision/configs/proj/paligemma/transfers/cococap.py
@@ -0,0 +1,194 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""PaliGemma transfer to COCO captions.
+"""
+
+import big_vision.configs.common as bvcc
+from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER
+
+
+def training_data(res, *, final_split, text_len=32, crop='rs'):
+  """Creates training data config.
+
+  See (internal link)
+  You can add more arguments beside `res`, but give them good defaults.
+
+  Args:
+    res: The requested image resolution (eg 224).
+    final_split: Train on all train+restval data or train[:98%]+restval.
+    text_len: sequence length
+    crop: one of {'ic', 'rc', 'rs'}
+
+  Returns:
+    The ConfigDict for the input section.
+  """
+  c = bvcc.parse_arg('')  # Just make a configdict without extra import.
+ c.data = dict( + name='coco_captions', + split='train+restval' if final_split else 'train[:98%]+restval', + ) + + if crop == 'ic': + crop = f'inception_crop({res}, area_min=50)' + elif crop == 'rc': + crop = f'resize_small({res*8//7})|random_crop({res})' + elif crop == 'rs': + crop = f'resize({res})' + else: + raise ValueError(f'Unknown crop: {crop}') + + c.pp = '|'.join([ + 'flatten', + 'decode', crop, 'value_range(-1, 1)', + 'choice_no_replacement(inkey="captions/text", outkey="suffix")', + 'strfmt("caption en", outkey="prefix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, text_len=32, **kw): + """Captioning evaluator with cider/bleu-4/meteor/rouge/spice metrics.""" + # Input eval pp without ground truth text and random crop. + pp_eval = '|'.join([ + 'decode', f'resize({res})', 'value_range(-1, 1)', + 'strfmt("caption en", outkey="prefix")', + combine_and_keep_eval(text_len, keep=('image/id', 'captions')), + ]) + + for name, split in [ + ('minitrain', 'train[:2%]'), + ('minival', 'train[-2%:]'), + ('eval', 'val'), + ]: + c.evals[f'cococap/{name}'] = dict( + type='proj.paligemma.transfers.coco_caption', + pred='decode', pred_kw={'max_decode_len': text_len}, + log_percent=0.1, tokenizer=TOKENIZER, + data={'name': 'coco_captions', 'split': split}, + pp_fn='|'.join([ + 'flatten', 'copy("captions/text", "captions")', # GT for evaluator. + pp_eval, + ]), + ) + c.evals[f'cococap/{name}'].update(kw) + + c.evals['nocaps/eval'] = dict( + type='proj.paligemma.transfers.coco_caption', + pred='decode', pred_kw={'max_decode_len': text_len}, + log_percent=0.1, tokenizer=TOKENIZER, + data={'name': 'nocaps', 'split': 'val'}, + pp_fn='|'.join([ + 'copy("texts", "captions")', # GT for evaluator. 
+ pp_eval, + ]), + ) + c.evals['nocaps/eval'].update(kw) + + +def add_eval_pplx(c, res, text_len=32): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, final_split=True, text_len=text_len, crop='rs') # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train[:2%]'), + ('minival', 'train[-2%:]'), + ('eval', 'val'), + ]: + c.evals[f'cococap/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.05, # Eval ~20x per run; it's cheap. + data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', crop='rs', res=224, beam_size=2, final_split=False) + + c.input = training_data(c.res, final_split=c.final_split, crop=c.crop) + + c.total_epochs = 5 # One epoch of captions (each image has 5 captions). + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 1e-4 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval_pplx(c, c.res) + + if c.beam_size: + decode_kw = {'pred': 'beam_decode', 'pred_kw': {'beam_size': c.beam_size}} + else: + decode_kw = {} + + add_eval(c, c.res, batch_size=1024, **decode_kw) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. 
+ c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + # Note: wd=0.0 works as good. + add(lr=1e-5, wd=1e-6, total_epochs=5, **bvcc.arg(res=224, **c)) + add(lr=1e-5, wd=1e-6, total_epochs=5, **bvcc.arg(res=448, **c)) + + +sweep = sweep_best # Choose which sweep to run. + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('eval', 'minival', 'minitrain'): + m.append(('epoch', f'cococap/{split}/cider')) + m.append(('epoch', f'cococap/{split}/pplx/avg')) + return m diff --git a/big_vision/configs/proj/paligemma/transfers/common.py b/big_vision/configs/proj/paligemma/transfers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c584ff7f166d7eb9428dd0df039b65f4ecd7c2df --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/common.py @@ -0,0 +1,65 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common things across all transfer configs.""" + + +TOKENIZER = 'gemma(tokensets=("loc", "seg"))' + + +def tok(**kw): + """Creates the tokenization preprocessing string.""" + # Single entry point so that it's consistent everywhere and easier to switch. + kw.setdefault('model', TOKENIZER) + kw = ', '.join(f'{k}={repr(v)}' for k, v in kw.items()) + return f'tok({kw})' + + +def combine_and_keep_train(text_len, before=(), sep='\n'): + return '|'.join([ + *before, + tok(key='prefix', bos='yes'), + tok(key='suffix', eos='yes'), + tok(key='septok', text=sep), + # If masks confuse you, see (internal link) + 'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_loss=[0, 0, 1])', # pylint: disable=line-too-long + # For training, we +1 since the trainer removes EOS. + f'tolen({text_len+1}, pad_value=0, key="text")', # Value doesn't matter. + f'tolen({text_len+1}, pad_value=1, key="mask_ar")', + f'tolen({text_len+1}, pad_value=0, key="mask_loss")', + 'keep("image", "text", "mask_ar", "mask_loss")', + ]) + + +def combine_and_keep_eval(text_len, keep=tuple(), before=(), sep='\n'): + return '|'.join([ + *before, + # Same as training, except that suffix is now the empty string. + # Meaning, we create text as [prefix separator pad], + # and the mask accordingly as [0 0 1] (with repeats of respective lengths) + tok(key='prefix', bos='yes'), + tok(key='septok', text=sep), + # At eval time, there can be also a suffix key in the data. If so it is + # tokenized without EOS and decoding will continue from it. 
+ 'setdefault("suffix", "")', + tok(key='suffix', eos='no'), + # If masks confuse you, see (internal link) + 'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])', # pylint: disable=line-too-long + f'tolen({text_len}, pad_value=0, key="text")', # value doesn't matter. + f'tolen({text_len}, pad_value=1, key="mask_ar")', + f'tolen({text_len}, pad_value=0, key="mask_input")', + # And we need to keep everything that makes our evaluator happy. + 'keep(' + ', '.join(f'"{x}"' for x in ( + 'image', 'text', 'mask_ar', 'mask_input') + tuple(keep)) + ')', + ]) diff --git a/big_vision/configs/proj/paligemma/transfers/docvqa.py b/big_vision/configs/proj/paligemma/transfers/docvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb9ea9f660820d8b2c7255813560bc8dfab766a --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/docvqa.py @@ -0,0 +1,163 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to docvqa. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + + +def training_data(res, final_split, text_len=32): + """Creates training data config. + + See (internal link) + You can add more arguments beside `res`, but give them good defaults. 
+ + Args: + res: The requested image resolution (eg 224) + final_split: Train on all train+val data. + text_len: sequence length + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='docvqa', + split='train+val' if final_split else 'train[:-5%]', + ) + c.pp = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + 'copy(inkey="question", outkey="prefix")', + 'choice(inkey="answers", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, text_len=32, **kw): + """Add eval configs.""" + pp_eval = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + 'copy(inkey="question", outkey="prefix")', + combine_and_keep_eval(text_len, keep=('answers', 'question_id')), + ]) + + for freq, name, split in [ + (1/8, 'minitrain', 'train[:5%]'), + (1/8, 'minival', 'train[-5%:]'), + (1/8, 'eval', 'val'), + (1.0, 'test', 'test'), + ]: + c.evals[f'docvqa/{name}'] = dict( + type='proj.paligemma.transfers.vqa', + pred='decode', pred_kw={'max_decode_len': text_len}, + to_lower=True, + outfile=f'{{workdir}}/docvqa_{name}.json', out_question_key='questionId', + data={**training_data(res, True, text_len).data, 'split': split}, + log_percent=freq, skip_first=freq == 1, tokenizer=TOKENIZER, pp_fn=pp_eval) + c.evals[f'docvqa/{name}'].update(kw) + + +def add_eval_pplx(c, res, text_len=32): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, True, text_len) # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train[:5%]'), # To gauge memorization. + ('minival', 'train[-5%:]'), # To tune hparams. + ('eval', 'val'), # To compute final publishable scores. + ]: + c.evals[f'docvqa/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.05, # Eval ~20x per run; it's cheap. 
+ data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + # Based on http://(internal link)/ZCwPqz0b3tE and (internal link) + add(lr=1e-5, wd=1e-6, total_epochs=10, **bvcc.arg(res=224, **c)) + add(lr=1e-5, wd=1e-6, total_epochs=10, **bvcc.arg(res=448, **c)) + add(lr=1e-5, wd=1e-6, total_epochs=10, **bvcc.arg(res=896, **c)) + + +sweep = sweep_best # Choose which sweep to run. + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=896, final_split=False) + + c.input = training_data(c.res, c.final_split) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 10 + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 1e-5 + c.wd = 1e-6 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, batch_size=256) + add_eval_pplx(c, c.res) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. 
+ if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('eval', 'minival', 'minitrain'): + m.append(f'docvqa/{split}/anls') + m.append(f'docvqa/{split}/pplx/avg') + return m diff --git a/big_vision/configs/proj/paligemma/transfers/forkme.py b/big_vision/configs/proj/paligemma/transfers/forkme.py new file mode 100644 index 0000000000000000000000000000000000000000..3041fdaaa9c2a3218dafc42964a389a1a5196b47 --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/forkme.py @@ -0,0 +1,151 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Example config for finetuning PaliGemma to a task stored in the JSON-L file, designed to fit on four L4 GPU. + +Can be used as a starting point to finetune PaliGemma model. If you prefer to +use tfds-based data input, check out other transfer configs as examples. 
+ +Command to run this config: + +``` +env BV_GEMMA_DIR=ckpts/ python -m big_vision.trainers.proj.paligemma.train \ + --config big_vision/configs/proj/paligemma/transfers/forkme.py \ + --workdir workdirs/`date '+%m-%d_%H%M'` +``` +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + + +def training_data(res, text_len): + """Creates training data config.""" + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='bv:jsonl', + fname='gs://longcap100/data_train90.jsonl', + fopen_keys={'image': 'gs://longcap100/'}, + # See docstring in datasets/jsonl.py for further details. + # download_keys=['image'], # If jsonl contains external paths. + ) + c.pp = '|'.join([ + # Read and prepare the image by just resizing it: + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + # The texts are already prepared in `prefix` and `suffix` keys. + 'strfmt("caption en", outkey="prefix")', + combine_and_keep_train(text_len), + ]) + # Keep the whole dataset in RAM after first pass. Useful optimization for + # small/mid-size datasets, but risks a host OOM for large datasets. + c.cache_raw = True + return c + + +def add_eval_pplx(c, res, text_len): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_data = training_data(res, text_len) # Use mostly same settings as training. + c_data.pp = '|'.join([ + # Read and prepare the image by just resizing it: + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + # The texts are already prepared in `prefix` and `suffix` keys. 
+      'strfmt("caption en", outkey="prefix")',
+      combine_and_keep_eval(text_len),
+  ])
+
+  c.evals['val/pplx'] = dict(
+      type='proj.paligemma.perplexity', pred='logits',
+      key='text', shift_labels=True,
+      log_percent=1/10,
+      data={**c_data.data,
+            'fname': 'gs://longcap100/data_val10.jsonl',
+            },
+      pp_fn=c_data.pp
+  )
+
+
+def add_eval_store(c, res, text_len=32):
+  """Adds the `storepreds` evaluator (pred='decode') reading data_val10.jsonl."""
+  c_data = training_data(res, text_len)  # Use mostly same settings as training.
+  c_data.pp = '|'.join([
+      # Read and prepare the image by just resizing it:
+      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
+      # The texts are already prepared in `prefix` and `suffix` keys.
+      'strfmt("caption en", outkey="prefix")',
+      combine_and_keep_eval(text_len, keep=('id',)),
+  ])
+
+  c.evals['val/store'] = dict(
+      type='proj.paligemma.transfers.storepreds',
+      pred='decode', pred_kw={'max_decode_len': text_len},
+      log_percent=0.5, tokenizer=TOKENIZER,
+      data={**c_data.data,
+            'fname': 'gs://longcap100/data_val10.jsonl',
+            },
+      pp_fn=c_data.pp,
+  )
+
+
+def get_config(arg=None):
+  """Config for training."""
+  # You probably do NOT want to add settings here. The `arg` way of settings is
+  # really only for things you'd want to sweep and which affect MULTIPLE config
+  # settings at once or go into the pp string.
+  c = bvcc.parse_arg(arg, res=224, text_len=128, batch_size=32,
+                     freeze_vit=False, freeze_llm=False,
+                     run_local=False)
+
+  c.input = training_data(c.res, c.text_len)
+
+  # Instead of epochs, you can also use `total_examples` or `total_steps`.
+  c.total_epochs = 15
+  c.input.batch_size = c.batch_size
+  c.optax_name = 'scale_by_adam'
+  c.lr = 1e-5
+  c.wd = 3e-7
+  c.grad_clip_norm = 1.0
+  c.label_smoothing = 0.0
+
+  # Learning-rate schedule. Probably is fine like this.
+ sched = dict(decay_type='cosine', warmup_percent=0.05) + c.schedule = [ + ('img/.*', None if c.freeze_vit else sched), + ('llm/.*', None if c.freeze_llm else sched), + ] + + c.evals = {} + add_eval_pplx(c, c.res, c.text_len) + add_eval_store(c, c.res, c.text_len) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + c.input.shuffle_buffer_size = 1000 + c.log_training_steps = 1 + c.ckpt_steps = 200 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + c.seed = 0 + + return c diff --git a/big_vision/configs/proj/paligemma/transfers/gqa.py b/big_vision/configs/proj/paligemma/transfers/gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..531b0964d581c5dac68252cffae4c5eb44f3cffb --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/gqa.py @@ -0,0 +1,197 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to GQA (https://arxiv.org/abs/1902.09506). 
+""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + +XGQA_LANGUAGES = ('bn', 'de', 'en', 'id', 'ko', 'pt', 'ru', 'zh') + + +def training_data(res, *, final_split, prefix, text_len=32): + """Creates training data config. + + See (internal link) + You can add more arguments beside `res`, but give them good defaults. + + Args: + res: The requested image resolution (eg 224). + final_split: Whether to train on train+val. + prefix: The prefix to use for the input. E.g. "answer en {question}" + text_len: sequence length. + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='gqa', + split='train_balanced+val_balanced' if final_split else 'train_balanced', + ) + c.pp = '|'.join([ + f'decode|resize({res})|value_range(-1, 1)', + f'strfmt("{prefix}", outkey="prefix")', + 'copy(inkey="answer", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, *, text_len=32, prefix, **kw): + """GQA evaluators.""" + c_train = training_data(res, final_split=True, prefix=prefix, text_len=text_len) + + pp = '|'.join([ + f'decode|resize({res})|value_range(-1, 1)', + 'copy(inkey="example_id", outkey="question_id")', + # GQA: both questions and answers are always in english. + # xGQA: questions in different languages. Answers always in english. + f'strfmt("{prefix}", outkey="prefix")', + combine_and_keep_eval(text_len, keep=('answer', 'question_id')), + ]) + + for freq, name, split, skip_first in [ + # TODO: adjust the proportion of dataset seen in these minivals + # based speed on hardware. + (1/8, 'minitrain', 'train_balanced[:10000]', False), # To gauge memorization. + (1/8, 'val_balanced', 'val_balanced', True), # To tune hparams. + (1.0, 'testdev_balanced', 'testdev_balanced', True), # To compute final publishable scores. 
+  ]:
+    c.evals[f'gqa/{name}/decode'] = dict(
+        type='proj.paligemma.transfers.vqa',
+        pred='decode', pred_kw={'max_decode_len': text_len},
+        outfile=f'{{workdir}}/gqa_{name}.json',
+        out_question_key='question_id', out_answer_key='prediction',
+        data={**c_train.data, 'split': split},
+        log_percent=freq, skip_first=skip_first, tokenizer=TOKENIZER, pp_fn=pp)
+    c.evals[f'gqa/{name}/decode'].update(kw)
+
+  # Add XGQA evaluators. Zero shot since the model is trained only in GQA (en).
+  for lang in XGQA_LANGUAGES:
+    c.evals[f'xgqa/test_zs_{lang}/decode'] = dict(
+        type='proj.paligemma.transfers.vqa',
+        pred='decode', pred_kw={'max_decode_len': text_len},
+        outfile=f'{{workdir}}/xgqa_test_{lang}.json',
+        data=dict(
+            name='xgqa',
+            split=f'test_zs_{lang}',  # Zero-shot split
+        ),
+        log_percent=1/8, tokenizer=TOKENIZER, pp_fn=pp)
+    c.evals[f'xgqa/test_zs_{lang}/decode'].update(kw)
+
+
+def add_eval_pplx(c, res, *, text_len=32, prefix):
+  """Perplexity evaluator to test runs before implementing the real deal."""
+  c_train = training_data(res, final_split=True, text_len=text_len, prefix=prefix)
+  for name, split in [
+      ('minitrain', 'train_balanced[:5%]'),  # To gauge memorization.
+      ('minival', 'val_balanced[:5%]'),  # To tune hparams.
+  ]:
+    c.evals[f'gqa/{name}/pplx'] = dict(
+        type='proj.paligemma.perplexity', pred='logits',
+        key='text', shift_labels=True,
+        log_percent=0.05,  # Eval ~20x per run; it's cheap.
+        data={**c_train.data, 'split': split},
+        pp_fn=c_train.pp,
+    )
+
+
+def sweep_best(add, arg=None):
+  """Train with best hyper-params."""
+  c = bvcc.parse_arg(arg, final_split=False)
+  # Based on (internal link), (internal link), (internal link).
+  # TODO: Is there a more comprehensive sweep and can we use
+  # freeze_vit=False for all resolutions (and more common in other configs)?
+ add(lr=1e-5, wd=0.0, **bvcc.arg(res=224, freeze_vit=False, **c)) + add(lr=1e-5, wd=0.0, **bvcc.arg(res=448, freeze_vit=True, **c)) + # Not better: add(lr=1e-5, wd=0.0, **bvcc.arg(res=896, freeze_vit=True, **c)) + + +sweep = sweep_best # Choose which sweep to run. + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False, + freeze_vit=True, freeze_llm=False, + prefix='answer en {question}') + + c.name = '' + c.input = training_data(c.res, final_split=c.final_split, prefix=c.prefix) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 1 + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 1e-5 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + + # Learning-rate schedule. Probably is fine like this. + sched = dict(decay_type='cosine', warmup_percent=0.05) + c.schedule = [ + ('img/.*', None if c.freeze_vit else sched), + ('llm/.*', None if c.freeze_llm else sched), + ] + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, batch_size=1024, prefix=c.prefix) + add_eval_pplx(c, c.res, prefix=c.prefix) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. 
+  if c.mode in ('runlocal', 'mock'):
+    c.input.shuffle_buffer_size = None
+    for ev in c.evals.values():
+      ev.data.split = ev.data.split.split('[')[0] + '[:16]'
+
+  if c.mode == 'runlocal':
+    c.log_training_steps = 1
+    c.input.batch_size = 2
+
+  c.seed = 0
+  return c
+
+
+def metrics(arg=None):  # pylint: disable=unused-argument
+  """Returns the names of the metrics to report for GQA runs."""
+  # `arg` is unused; it is accepted to match metrics(arg=None) in sibling configs.
+  m = ['training_loss', 'gqa/minitrain/pplx/avg', 'gqa/minival/pplx/avg']
+  m.append('gqa/minitrain/decode/acc')
+  m.append('gqa/val_balanced/decode/acc')
+  m.append('gqa/testdev_balanced/decode/acc')
+  return m
diff --git a/big_vision/configs/proj/paligemma/transfers/infovqa.py b/big_vision/configs/proj/paligemma/transfers/infovqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ea80adfbcc1acb82c9307f107f110bd648f896
--- /dev/null
+++ b/big_vision/configs/proj/paligemma/transfers/infovqa.py
@@ -0,0 +1,172 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""PaliGemma transfer to InfoVQA.
+"""
+
+import big_vision.configs.common as bvcc
+from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER
+
+
+def training_data(res, final_split, text_len=48):
+  """Creates training data config.
+
+  See Colab:
+  http://(internal link)#scrollTo=0yIyusCLhdDy
+
+  You can add more arguments beside `res`, but give them good defaults.
+
+  Args:
+    res: The requested image resolution (eg 224).
+ final_split: Whether to use all train data. + text_len: The maximum text length (in tokens). + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='infovqa', + split='train+val' if final_split else 'train', + ) + c.pp = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + 'lower(key="answers")', + 'strfmt("answer en {question}", outkey="prefix")', + 'choice_no_replacement(inkey="answers", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, text_len=48, **kw): + """VQA evaluators.""" + pp = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + 'lower(key="answers")', + 'strfmt("answer en {question}", outkey="prefix")', + combine_and_keep_eval(text_len, keep=('answers', 'question_id')), + ]) + + for freq, name, split in [ + (0.1, 'minitrain', 'train[:5%]'), # To gauge memorization. ~1.2k samples. + (0.1, 'minival', 'val'), # To tune hparams. ~1.2k samples. + (1.0, 'test', 'test'), # Also stores predictions for the test set (3.3k). + ]: + c.evals[f'infovqa/{name}'] = dict( + type='proj.paligemma.transfers.vqa', + pred='decode', pred_kw={'max_decode_len': text_len}, + # Answers in reference evaluation are converted to lower case. + to_lower=True, + # Test server expects two fields: 'questionId' and 'answer'. + out_question_key='questionId', out_answer_key='answer', + outfile=f'{{workdir}}/infovqa_{name}.json', + data={**training_data(res, False, text_len).data, 'split': split}, + log_percent=freq, skip_first=freq == 1, tokenizer=TOKENIZER, pp_fn=pp) + c.evals[f'infovqa/{name}'].update(kw) + + +def add_eval_pplx(c, res, text_len=48): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, False, text_len) # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train[:5%]'), # To gauge memorization. 
+ ('minival', 'val'), # To tune hparams. + ]: + c.evals[f'infovqa/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.05, # Eval ~20x per run; it's cheap. + data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + # Based on http://(internal link)/ZCwPqz0b3tE and (internal link) + # In InfoVQA: max prefix: 51, suffix: 42, prefix+sep+suffix: 60. + c = bvcc.parse_arg(arg, mode='xm', final_split=False) + add(lr=1e-5, wd=1e-6, total_epochs=3, **bvcc.arg(res=224, **c)) + add(lr=1e-5, wd=1e-6, total_epochs=3, **bvcc.arg(res=448, **c)) + add(lr=3e-6, wd=3e-7, total_epochs=3, **bvcc.arg(res=896, **c)) + + +sweep = sweep_best # Choose which sweep to run. + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=896, text_len=48, final_split=False) + + c.input = training_data(c.res, c.final_split, c.text_len) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 3 + c.input.batch_size = {224: 256, 448: 128, 896: 32}[c.res] + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 3e-6 + c.wd = 3e-7 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.4 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, c.text_len, batch_size=256) + add_eval_pplx(c, c.res, c.text_len) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. 
+ c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + del c.total_epochs + c.total_steps = 10 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('minival', 'minitrain'): + m.append(f'infovqa/{split}/anls') + m.append(f'infovqa/{split}/acc') + m.append(f'infovqa/{split}/pplx/avg') + return m diff --git a/big_vision/configs/proj/paligemma/transfers/msrvtt_cap.py b/big_vision/configs/proj/paligemma/transfers/msrvtt_cap.py new file mode 100644 index 0000000000000000000000000000000000000000..6b81ae293deb99c85aac40366f41f83a555f2fed --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/msrvtt_cap.py @@ -0,0 +1,210 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to MSRVTT Video captioning. 
+ +IMPORTANT: This config is based on an unreleased version of DeepMind Video +Readers (DMVR). Users can either set up DMVR using the open source code from +GitHub (see below for details), or add their own data loader of choice. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + +TEXT_LEN = 32 +DATASET_NAME = 'msrvtt' +# Numbers might need to be updated due to wipeout. Current from 2024-04-28 +SPLIT_SIZE = {'train': 4663, 'valid': 316, 'test': 2094, 'test_v2': 2271} + + +def training_data(res, *, final_split, num_frames, stride): + """Creates training data config. + + Args: + res: The requested image resolution (eg 224). + final_split: Train on all train+valid data. + num_frames: number of sampled frames per video. + stride: stride at which the frames are sampled. + + Returns: + The ConfigDict for the input section. + """ + pp = '|'.join([ + # prepare the frames by decoding, resizing, replicating, sampling: + f'video_decode({res})|video_replicate_img({num_frames},{num_frames})', + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + # pick one caption at random during training + 'strfmt("caption en", outkey="prefix")', + 'video_choice(inkey="caption/string", outkey="suffix")', + combine_and_keep_train(TEXT_LEN), + ]) + + c = bvcc.parse_arg('') + c.data = {} + splits = ['train', 'valid'] if final_split else ['train'] + raise NotImplementedError('Please implement a video reader of choice!') + # For example DMVR https://github.com/google-deepmind/dmvr + # The reader should support the following arguments: + # - name: Name of the reader. + # - dataset_name: Name of the data set. + # - split: Data set split. + # - num_frames: Number of frames sampled from the video. + # - stride: Stride at which the video frames are sampled. 
+ # - deterministic_fs: Whether to sample the frames starting at the first + # frame or whether an offset should be chosen at random (if there are more + # frames than num_frames * stride) + # - first_k_shards: Whether to only use the first k shards of the data + # (optional but useful for speeding up intermediate evaluations). + for split in splits: + c.data[split] = SPLIT_SIZE[split] + c[split] = {'pp': pp} + c[split].data = dict( + # PLEASE ADD YOUR READER HERE: + name='', + dataset_name=DATASET_NAME, split=split, + num_frames=num_frames, stride=stride, + deterministic_fs=False) + return c + + +def add_eval(c, res, num_frames, stride): # pylint: disable=unused-argument + """Captioning evaluator.""" + c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride) + + pp = '|'.join([ + f'video_decode({res})|video_replicate_img({num_frames},{num_frames})', + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + 'strfmt("caption en", outkey="prefix")', + 'strfmt("{example/video_id}[{clip/start/timestamp}-{clip/end/timestamp}]", outkey="image/id")', + 'copy("caption/string", "captions")', + combine_and_keep_eval(TEXT_LEN, keep=('image/id', 'captions')), + ]) + + for freq, name, split, first_k_shards, skip_first_eval in [ + (1/8, 'minitrain', 'train', 2, False), # To gauge memorization. + (1/4, 'minival', 'valid', 2, False), # To monitor val progress. + (1, 'val', 'valid', None, False), # To tune hparams. 
+ (1, 'eval', 'test', None, False), # final metric + ]: + c.evals[f'msrvtt/{name}'] = dict( + type='proj.paligemma.transfers.coco_caption', + pred='decode', pred_kw={'max_decode_len': TEXT_LEN}, + data={**c_train.train.data, 'split': split, + 'first_k_shards': first_k_shards, + 'deterministic_fs': True}, + log_percent=freq, tokenizer=TOKENIZER, + pp_fn=pp, skip_first=skip_first_eval) + + +def add_eval_pplx(c, res, num_frames, stride): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride) + + for name, split, first_k_shards in [ + ('minitrain', 'train', 2), # To gauge memorization. + ('minival', 'valid', 2), + ]: + c.evals[f'msrvtt/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=1/8, # Not too cheap, do 10x per run. + data={**c_train.train.data, 'split': split, + 'first_k_shards': first_k_shards, + 'deterministic_fs': True}, + pp_fn=c_train.train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + add(lr=1e-5, wd=0.0, total_epochs=20, **bvcc.arg(res=224, freeze_vit=True, **c)) + + +sweep = sweep_best + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', num_frames=16, stride=9, res=224, + freeze_vit=False, freeze_llm=False, final_split=False) + + c.input = training_data( + c.res, final_split=c.final_split, + num_frames=c.num_frames, stride=c.stride) + + c.total_epochs = 10 + c.input.batch_size = 128 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 1e-5 + c.wd = 1e-6 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + + # Learning-rate schedule. + sched = dict(decay_type='cosine', warmup_percent=0.05) + c.schedule = [ + ('img/.*', None if c.freeze_vit else sched), + ('llm/.*', None if c.freeze_llm else sched), + ] + + # Add evaluators. 
+ c.evals = {} + add_eval(c, c.res, c.num_frames, c.stride) + add_eval_pplx(c, c.res, c.num_frames, c.stride) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + for split in c.input.data.keys(): + c.input[split].shuffle_buffer_size = 10_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops', + 'proj.paligemma.video'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + for split in c.input.data.keys(): + c.input[split].shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.first_k_shards = 1 + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('minitrain', 'minival', 'val', 'eval'): + m.append(('epoch', f'{DATASET_NAME}/{split}/cider')) + for split in ('minitrain', 'minival'): + m.append(('epoch', f'{DATASET_NAME}/{split}/pplx/avg')) + return m + diff --git a/big_vision/configs/proj/paligemma/transfers/msrvtt_qa.py b/big_vision/configs/proj/paligemma/transfers/msrvtt_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..08df9c56245b27db1d31a8f07de1b5967381e78c --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/msrvtt_qa.py @@ -0,0 +1,213 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to MSRVTT Video QA. + +IMPORTANT: This config is based on an unreleased version of DeepMind Video +Readers (DMVR). Users can either set up DMVR using the open source code from +GitHub (see below for details), or add their own data loader of choice. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + +TEXT_LEN = 64 +DATASET_NAME = 'msrvtt_qa' +# Numbers might need to be updated due to wipeout. Current from 2024-04-28 +SPLIT_SIZE = {'train': 114080, 'valid': 7936, 'test': 51680} + + +def training_data(res, *, final_split, num_frames, stride): + """Creates training data config. + + Args: + res: The requested image resolution (eg 224). + final_split: Train on all train+valid data. + num_frames: number of sampled frames per video. + stride: stride at which the frames are sampled. + + Returns: + The ConfigDict for the input section. + """ + pp = '|'.join([ + # prepare the frames by decoding, resizing, replicating, sampling: + f'video_decode({res})|video_replicate_img({num_frames},{num_frames})', + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + # only one question/answer per example. 
+ 'reshape([], key="question")|reshape([], key="answer")', + 'strfmt("answer en {question}", outkey="prefix")', + 'copy("answer", "suffix")', + combine_and_keep_train(TEXT_LEN), + ]) + + c = bvcc.parse_arg('') + c.data = {} + splits = ['train', 'valid'] if final_split else ['train'] + raise NotImplementedError('Please implement a video reader of choice!') + # For example DMVR https://github.com/google-deepmind/dmvr + # The reader should support the following arguments: + # - name: Name of the reader. + # - dataset_name: Name of the data set. + # - split: Data set split. + # - num_frames: Number of frames sampled from the video. + # - stride: Stride at which the video frames are sampled. + # - deterministic_fs: Whether to sample the frames starting at the first + # frame or whether an offset should be chosen at random (if there are more + # frames than num_frames * stride) + # - first_k_shards: Whether to only use the first k shards of the data + # (optional but useful for speeding up intermediate evaluations). + for split in splits: + c.data[split] = SPLIT_SIZE[split] + c[split] = {'pp': pp} + c[split].data = dict( + # PLEASE ADD YOUR READER HERE: + name='', + dataset_name=DATASET_NAME, split=split, + num_frames=num_frames, stride=stride, + deterministic_fs=False) + return c + + +def add_eval(c, res, num_frames, stride): # pylint: disable=unused-argument + """QA evaluator.""" + c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride) + + pp = '|'.join([ + # prepare the frames by decoding, resizing, replicating, sampling: + f'video_decode({res})|video_replicate_img({num_frames},{num_frames})', + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + # only one question/answer per example. 
+ 'reshape([], key="question")|reshape([], key="answer")', + 'strfmt("answer en {question}", outkey="prefix")', + 'strfmt("{id}#{example/video_id}: {question}", "question_id")', + combine_and_keep_eval(TEXT_LEN, keep=('question_id', 'answer')), + ]) + + for freq, name, split, first_k_shards, skip_first_eval in [ + (1/8, 'minitrain', 'train', 2, False), # To gauge memorization. + (1/4, 'minival', 'valid', 2, False), # To monitor val progress. + (1, 'val', 'valid', None, False), # To tune hparams. + (1, 'eval', 'test', None, False), # final metric + ]: + c.evals[f'msrvtt_qa/{name}'] = dict( + type='proj.paligemma.transfers.vqa', + pred='decode', pred_kw={'max_decode_len': TEXT_LEN}, + data={**c_train.train.data, 'split': split, + 'first_k_shards': first_k_shards, + 'deterministic_fs': True}, + log_percent=freq, tokenizer=TOKENIZER, + pp_fn=pp, skip_first=skip_first_eval) + + +def add_eval_pplx(c, res, num_frames, stride): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride) + + for name, split, first_k_shards in [ + ('minitrain', 'train', 2), # To gauge memorization. + ('minival', 'valid', 2), # To tune hparams. + ]: + c.evals[f'msrvtt_qa/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=1/8, # Not too cheap, do 10x per run. 
+ data={**c_train.train.data, 'split': split, + 'first_k_shards': first_k_shards, + 'deterministic_fs': True}, + pp_fn=c_train.train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + add(lr=1e-5, wd=0.0, total_epochs=1, **bvcc.arg(freeze_vit=True, res=224, **c)) + + +sweep = sweep_best + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', num_frames=16, stride=9, res=224, + freeze_vit=True, freeze_llm=False, final_split=False) + + c.input = training_data( + c.res, final_split=c.final_split, + num_frames=c.num_frames, stride=c.stride) + + c.total_epochs = 1 + c.input.batch_size = 128 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 0.00001 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + + # Learning-rate schedule. + sched = dict(decay_type='cosine', warmup_percent=0.05) + c.schedule = [ + ('img/.*', None if c.freeze_vit else sched), + ('llm/.*', None if c.freeze_llm else sched), + ] + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, c.num_frames, c.stride) + add_eval_pplx(c, c.res, c.num_frames, c.stride) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + for split in c.input.data.keys(): + c.input[split].shuffle_buffer_size = 10_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops', + 'proj.paligemma.video'] + + # Update configs for quicker local runs and avoid swapping. 
+ if c.mode in ('runlocal', 'mock'): + for split in c.input.data.keys(): + c.input[split].shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.first_k_shards = 1 + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('minitrain', 'minival', 'val', 'eval'): + m.append(('epoch', f'{DATASET_NAME}/{split}/acc')) + for split in ('minitrain', 'minival'): + m.append(('epoch', f'{DATASET_NAME}/{split}/pplx/avg')) + return m + diff --git a/big_vision/configs/proj/paligemma/transfers/msvd_qa.py b/big_vision/configs/proj/paligemma/transfers/msvd_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..0832ad7185c714056bb39c7390c711a5e03d3b48 --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/msvd_qa.py @@ -0,0 +1,214 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to MSVD Video QA. + +IMPORTANT: This config is based on an unreleased version of DeepMind Video +Readers (DMVR). Users can either set up DMVR using the open source code from +GitHub (see below for details), or add their own data loader of choice. 
+""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + +TEXT_LEN = 64 +DATASET_NAME = 'msvd_qa' +# Numbers might need to be updated due to wipeout. Current from 2024-04-28 +SPLIT_SIZE = {'train': 24670, 'valid': 5107, 'test': 10136} # 2024-04-28 + + +def training_data(res, *, final_split, num_frames=8, stride=None): + """Creates training data config. + + Args: + res: The requested image resolution (eg 224). + final_split: Train on all train+valid data. + num_frames: number of sampled frames per video. + stride: stride at which the frames are sampled. + + Returns: + The ConfigDict for the input section. + """ + pp = '|'.join([ + # prepare the frames by decoding, resizing, replicating, sampling: + f'video_decode({res})|video_replicate_img({num_frames},{num_frames})', + # ensure_shape is needed for tf to figure out the tensor shape + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + # only one question/answer per example. + 'reshape([], key="question")|reshape([], key="answer")', + 'strfmt("answer en {question}", outkey="prefix")', + 'copy("answer", "suffix")', + combine_and_keep_train(TEXT_LEN), + ]) + + c = bvcc.parse_arg('') + c.data = {} + splits = ['train', 'valid'] if final_split else ['train'] + raise NotImplementedError('Please implement a video reader of choice!') + # For example DMVR https://github.com/google-deepmind/dmvr + # The reader should support the following arguments: + # - name: Name of the reader. + # - dataset_name: Name of the data set. + # - split: Data set split. + # - num_frames: Number of frames sampled from the video. + # - stride: Stride at which the video frames are sampled. 
+ # - deterministic_fs: Whether to sample the frames starting at the first + # frame or whether an offset should be chosen at random (if there are more + # frames than num_frames * stride) + # - first_k_shards: Whether to only use the first k shards of the data + # (optional but useful for speeding up intermediate evaluations). + for split in splits: + c.data[split] = SPLIT_SIZE[split] + c[split] = {'pp': pp} + c[split].data = dict( + # PLEASE ADD YOUR READER HERE: + name='', + dataset_name=DATASET_NAME, split=split, + num_frames=num_frames, stride=stride, + deterministic_fs=False) + return c + + +def add_eval(c, res, num_frames=8, stride=None): + """QA evaluator.""" + c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride) + + pp = '|'.join([ + # prepare the frames by decoding, resizing, replicating, sampling: + f'video_decode({res})|video_replicate_img({num_frames},{num_frames})', + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + # only one question/answer per example. + 'reshape([], key="question")|reshape([], key="answer")', + 'strfmt("answer en {question}", outkey="prefix")', + 'strfmt("{id}#{example/video_id}[{clip/start/timestamp}-{clip/end/timestamp}]: {question}", outkey="question_id")', + combine_and_keep_eval(TEXT_LEN, keep=('question_id', 'answer')), + ]) + + for freq, name, split, first_k_shards, skip_first_eval in [ + (1/8, 'minitrain', 'train', 1, False), # To gauge memorization. + (1/4, 'minival', 'valid', 1, False), # To monitor val progress. + (1, 'val', 'valid', None, False), # To tune hparams. 
+ (1, 'eval', 'test', None, False), # final metric + ]: + c.evals[f'{DATASET_NAME}/{name}'] = dict( + type='proj.paligemma.transfers.vqa', + pred='decode', pred_kw={'max_decode_len': TEXT_LEN}, + data={**c_train.train.data, 'split': split, + 'first_k_shards': first_k_shards, + 'deterministic_fs': True}, + log_percent=freq, tokenizer=TOKENIZER, + pp_fn=pp, skip_first=skip_first_eval) + + +def add_eval_pplx(c, res, num_frames=8, stride=None): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride) + + for name, split, first_k_shards in [ + ('minitrain', 'train', 1), # To gauge memorization. + ('minival', 'valid', 1), # To tune hparams. + ]: + c.evals[f'{DATASET_NAME}/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=1/8, # Not too cheap, do 10x per run. + data={**c_train.train.data, 'split': split, + 'first_k_shards': first_k_shards, + 'deterministic_fs': True}, + pp_fn=c_train.train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + add(lr=3e-6, wd=3e-7, total_epochs=1, **bvcc.arg(res=224, **c)) + + +sweep = sweep_best + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', num_frames=8, stride=26, res=224, + freeze_vit=False, freeze_llm=False, final_split=False) + + c.input = training_data( + c.res, final_split=c.final_split, + num_frames=c.num_frames, stride=c.stride) + + c.total_epochs = 3 + c.input.batch_size = 128 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 3e-6 + c.wd = 3e-7 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + + # Learning-rate schedule. + sched = dict(decay_type='cosine', warmup_percent=0.05) + c.schedule = [ + ('img/.*', None if c.freeze_vit else sched), + ('llm/.*', None if c.freeze_llm else sched), + ] + + # Add evaluators. 
+ c.evals = {} + add_eval(c, c.res, c.num_frames, c.stride) + add_eval_pplx(c, c.res, c.num_frames, c.stride) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + for split in c.input.data.keys(): + c.input[split].shuffle_buffer_size = 10_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops', + 'proj.paligemma.video'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + for split in c.input.data.keys(): + c.input[split].shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.first_k_shards = 1 + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('minitrain', 'minival', 'val', 'eval'): + m.append(('epoch', f'{DATASET_NAME}/{split}/acc')) + for split in ('minitrain', 'minival'): + m.append(('epoch', f'{DATASET_NAME}/{split}/pplx/avg')) + return m + diff --git a/big_vision/configs/proj/paligemma/transfers/nlvr2.py b/big_vision/configs/proj/paligemma/transfers/nlvr2.py new file mode 100644 index 0000000000000000000000000000000000000000..3e76a158ba26986e890e9ab6b2e38d0499ebc34b --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/nlvr2.py @@ -0,0 +1,169 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to NLVR2 captions, including evaluation on MaRVL. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + +LANGS = ('id', 'sw', 'ta', 'tr', 'zh') + + +def training_data(res, *, final_split, text_len=64): + """Creates training data config. + + See (internal link) + You can add more arguments beside `res`, but give them good defaults. + + Args: + res: The requested image resolution (eg 224). + final_split: Train on all train+dev data. + text_len: sequence length. + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='nlvr2', + split='train+dev' if final_split else 'train', + ) + + num_frames = 2 + c.pp = '|'.join([ + f'resize({res}, key="image_left")|resize({res}, key="image_right")', + 'stack_images(inkeys=["image_left", "image_right"], outkey="image")', + 'value_range(-1, 1)', + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + 'strfmt("answer en {sentence}", outkey="prefix")', + 'copy(inkey="label", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, text_len=64, **kw): + """QA evaluator.""" + # Input eval pp without ground truth text and random crop. 
+ num_frames = 2 + pp_eval = '|'.join([ + f'resize({res}, key="image_left")|resize({res}, key="image_right")', + 'stack_images(inkeys=["image_left", "image_right"], outkey="image")', + 'value_range(-1, 1)', + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + 'strfmt("answer en {sentence}", outkey="prefix")', + 'copy(inkey="label", outkey="answer")', + 'copy(inkey="example_id", outkey="question_id")', + combine_and_keep_eval(text_len, keep=('answer', 'question_id')), + ]) + + for name, split in [ + ('minitrain', 'train[:10000]'), + ('dev', 'dev'), + ('test', 'test'), + ]: + c.evals[f'nlvr2/{name}'] = dict( + type='proj.paligemma.transfers.vqa', + pred='decode', pred_kw={'max_decode_len': text_len}, + outfile=f'{{workdir}}/nlvr2_{name}.json', + data={**training_data(res, final_split=True, text_len=text_len).data, 'split': split}, + log_percent=0.1, tokenizer=TOKENIZER, pp_fn=pp_eval) + c.evals[f'nlvr2/{name}'].update(kw) + + for lang in LANGS: + c.evals[f'marvl/test_{lang}'] = dict( + type='proj.paligemma.transfers.vqa', + pred='decode', pred_kw={'max_decode_len': text_len}, + outfile=f'{{workdir}}/marvl_test_{lang}.json', + data=dict( + name='marvl', + split=f'test_{lang}', + ), + log_percent=0.1, tokenizer=TOKENIZER, pp_fn=pp_eval) + c.evals[f'marvl/test_{lang}'].update(kw) + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False) + + c.input = training_data(c.res, final_split=c.final_split) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 3 + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 1e-4 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, batch_size=1024) + + # Model section. 
def metrics(arg=None):  # pylint: disable=unused-argument
  """Lists the metric names of interest for the NLVR2/MaRVL transfer.

  Args:
    arg: Unused.

  Returns:
    A list with 'training_loss', then the accuracy metric of every NLVR2
    evaluator, then the accuracy metric of every MaRVL test language.
  """
  m = ['training_loss']
  # Must mirror the splits that `add_eval` registers. There is no 'minival'
  # evaluator for NLVR2, so listing it would name a metric that never exists.
  m += [f'nlvr2/{split}/acc' for split in ('minitrain', 'dev', 'test')]
  m += [f'marvl/test_{lang}/acc' for lang in LANGS]
  return m
def training_data(res, *, final_split, text_len=32):
  """Builds the input-section config for OCR-VQA fine-tuning.

  Args:
    res: The requested image resolution (eg 224).
    final_split: If true, train on all of train+val; otherwise train on train
      plus val[20_000:].
    text_len: Sequence length handed to `combine_and_keep_train`.

  Returns:
    The ConfigDict for the input section, with `data` and `pp` set.
  """
  # The dataset is "unbatched": every image-question pair is its own example.
  # So there are ~800k training examples but only ~200k unique images, with on
  # average 4 highly regular questions each:
  #   - What is the title of this book?
  #   - What type of book is this? OR What is the genre of this book?
  #   - Who wrote this book? OR Who is the author of this book?
  #   - Is this book related to [GENRE]? OR Is this a [GENRE] book? "yes"
  #   - Same but with answer "no"
  # An obvious training improvement would be to randomize the negative genre
  # question more using a custom pp op.
  if final_split:
    split = 'train+val'
  else:
    # Val is 100k, we don't need that much!
    split = 'train + val[20_000:]'

  cfg = bvcc.parse_arg('')  # Just make a configdict without extra import.
  cfg.data = {'name': 'ocrvqa_id', 'split': split}

  pp_ops = (
      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
      'strfmt("answer en {question}", outkey="prefix")',
      'copy(inkey="answer", outkey="suffix")',
      combine_and_keep_train(text_len),
  )
  cfg.pp = '|'.join(pp_ops)
  return cfg
def sweep_best(add, arg=None):
  """Adds one run per resolution using the best hyper-params.

  Args:
    add: Callback invoked once per run with that run's keyword settings.
    arg: Forwarded to `bvcc.parse_arg`, with `final_split` defaulting to False.
  """
  base = bvcc.parse_arg(arg, final_split=False)
  # (resolution, learning rate) of each run, in submission order.
  for res, lr in ((224, 3e-6), (448, 3e-6), (896, 1e-5)):
    add(**bvcc.arg(res=res, **base), lr=lr)
def metrics(arg=None):  # pylint: disable=unused-argument
  """Lists the metric names of interest for the OCR-VQA transfer.

  Args:
    arg: Unused.

  Returns:
    'training_loss' followed by the accuracy metric of each eval split.
  """
  splits = ('eval', 'minival', 'minitrain')
  return ['training_loss'] + [f'ocrvqa/{split}/acc' for split in splits]
def add_eval(c, res, text_len=32, **kw):
  """Registers the OK-VQA decoding evaluators (the normal VQA evaluator type).

  Args:
    c: Config whose `evals` mapping receives the evaluator entries.
    res: Image resolution used in the eval preprocessing.
    text_len: Handed to `combine_and_keep_eval` and used as `max_decode_len`.
    **kw: Extra settings applied on top of every evaluator entry.
  """
  pp_ops = [
      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
      'strfmt("answer en {question}", outkey="prefix")',
      combine_and_keep_eval(text_len, keep=('answers', 'question_id')),
  ]
  pp = '|'.join(pp_ops)

  splits = {
      'minitrain': 'train[:5%]',  # To gauge memorization.
      'minival': 'train[-10%:]',  # To tune hparams.
      'eval': 'val',  # To compute final publishable scores.
  }
  for name, split in splits.items():
    key = f'okvqa/{name}'
    # Same dataset as training, only the split differs.
    train_data = training_data(res, True, text_len).data
    c.evals[key] = {
        'type': 'proj.paligemma.transfers.vqa',
        'pred': 'decode',
        'pred_kw': {'max_decode_len': text_len},
        'outfile': f'{{workdir}}/okvqa_{name}.json',
        'data': {**train_data, 'split': split},
        'log_percent': 1/4,
        'tokenizer': TOKENIZER,
        'pp_fn': pp,
    }
    c.evals[key].update(kw)
def sweep_best(add, arg=None):
  """Adds one run per resolution using the best hyper-params.

  Args:
    add: Callback invoked once per run with that run's keyword settings.
    arg: Forwarded to `bvcc.parse_arg`, with `final_split` defaulting to False.
  """
  base = bvcc.parse_arg(arg, final_split=False)
  # Resolution 896 was tried and was not better, so it is not swept.
  for res in (224, 448):
    add(**bvcc.arg(res=res, **base))
def metrics(arg=None):  # pylint: disable=unused-argument
  """Lists the metric names of interest for the OK-VQA transfer.

  Args:
    arg: Unused.

  Returns:
    'training_loss' followed, per eval split, by its average perplexity
    metric and then its accuracy metric.
  """
  names = ['training_loss']
  for split in ('eval', 'minival', 'minitrain'):
    names += [f'okvqa/{split}/pplx/avg', f'okvqa/{split}/acc']
  return names
def get_config(arg=None):
  """Builds the eval-only PaliGemma config for POPE.

  Training is configured for zero steps, so only the evaluators registered by
  `add_eval` do any work.

  Args:
    arg: Forwarded to `bvcc.parse_arg` together with the defaults mode='xm',
      res=224, text_len=48 and prefix='{question}'.

  Returns:
    The assembled config.
  """
  c = bvcc.parse_arg(arg, mode='xm', res=224, text_len=48, prefix='{question}')

  c.name = ''  # Help to track experiments.
  c.input = training_data(c.res, text_len=c.text_len, prefix=c.prefix)

  # Make the config eval-only by setting some dummies.
  c.total_steps = 0
  c.input.batch_size = 256
  c.optax_name = 'identity'
  c.lr = 0.0

  # Add evaluators.
  c.evals = {}
  add_eval(c, c.res, text_len=c.text_len, prefix=c.prefix)

  # Model section.
  c.model_name = 'proj.paligemma.paligemma'
  c.model = {}
  c.model.img = {'variant': 'So400m/14', 'pool_type': 'none', 'scan': True}
  c.model.llm = {'vocab_size': 256_000 + 1024 + 128, 'dropout': 0.0}
  c.model_init = f'pt_{c.res}'

  # FSDP strategy.
  c.mesh = [('data', -1)]
  c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
  c.sharding_rules = [('act_batch', ('data',))]

  # These probably do not need any change/tuning
  c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops']

  # Update configs for quicker local runs and avoid swapping: every evaluator
  # is cut down to the first 16 examples of its (unsliced) split.
  if c.mode in ('runlocal', 'mock'):
    for ev in c.evals.values():
      base_split = ev.data.split.split('[')[0]
      ev.data.split = base_split + '[:16]'

  return c
def training_data(res, text_len=48, crop='rs'):
  """Builds the input-section config for RefCOCO segmentation fine-tuning.

  Args:
    res: The requested image resolution (eg 224).
    text_len: Sequence length handed to `combine_and_keep_train`.
    crop: How to get to `res`. 'rs' is a plain resize; 'zic_mild' zooms out,
      takes an inception-style box crop of image, mask and bbox, then resizes.

  Returns:
    The ConfigDict for the input section, with `data` and `pp` set.

  Raises:
    ValueError: If `crop` is neither 'rs' nor 'zic_mild'.
  """
  cfg = bvcc.parse_arg('')  # Just make a configdict without extra import.
  cfg.data = {'name': 'ref_coco_bv/refcocox_combined:1.4.0', 'split': 'train'}

  resize_op = f'resize({res})'
  if crop == 'zic_mild':
    crop_ops = '|'.join([
        'zoomout(max_f=1.5, key="image", bboxkey="objects/bbox", auxkeys=["objects/mask"])',
        f'inception_box(area=(0.1,1.0), aspect=({3/4},{4/3}), min_obj_cover=1.0, bboxkey="objects/bbox")',
        'box_crop_bbox',
        'box_crop_img(key="objects/mask")',
        'box_crop_img(key="image")',
        resize_op,
    ])
  elif crop == 'rs':
    crop_ops = resize_op
  else:
    raise ValueError(crop)

  pp_ops = [
      'flatten',
      'choice_no_replacement(key=["objects/mask", "objects/bbox", "objects/refs/sentence"])',
      'choice(key=["objects/refs/sentence"])',
      'decode',
      crop_ops,
      'value_range(-1, 1)',
      'refcoco_mask2str',
      combine_and_keep_train(text_len),
  ]
  cfg.pp = '|'.join(pp_ops)
  return cfg
def add_eval_pplx(c, res, text_len=48):
  """Registers perplexity evaluators that reuse the training input settings.

  Args:
    c: Config whose `evals` mapping receives the evaluator entries.
    res: Image resolution handed to `training_data`.
    text_len: Sequence length handed to `training_data`.
  """
  train_cfg = training_data(res, text_len)  # Use mostly same settings as training.
  splits = (
      ('minitrain', 'train[:5%]'),  # 1_220
      ('val', 'validation'),  # 2_738
  )
  for name, split in splits:
    c.evals[f'refcoco_seg/{name}/pplx'] = {
        'type': 'proj.paligemma.perplexity',
        'pred': 'logits',
        'key': 'text',
        'shift_labels': True,
        'log_percent': 0.05,  # Eval ~20x per run; it's cheap.
        'data': {**train_cfg.data, 'split': split},
        'pp_fn': train_cfg.pp,
    }
def metrics(arg=None):  # pylint: disable=unused-argument
  """Lists the metric names of interest for the RefCOCO segmentation transfer.

  Args:
    arg: Unused.

  Returns:
    'training_loss', then three metrics for the RefCOCO val split, then the
    mIoU metric of each remaining RefCOCO / RefCOCO+ / RefCOCOg split.
  """
  val_metrics = ('miou', 'boxacc/0.5', 'invalid')
  # Suffixes appended directly to 'refcoco'; the p/ and g/ ones yield the
  # 'refcocop' and 'refcocog' evaluator names.
  other_splits = ('/testA', '/testB', 'p/val', 'p/testA', 'p/testB', 'g/val', 'g/test')

  names = ['training_loss']
  names.extend(f'seg/refcoco/val/{metric}' for metric in val_metrics)
  names.extend(f'seg/refcoco{split}/miou' for split in other_splits)
  return names
def training_data(res, *, final_split, text_len=32):
  """Builds the input-section config for RSVQA-HR fine-tuning.

  Args:
    res: The requested image resolution (eg 224).
    final_split: If true, train on 'train + val'; otherwise on 'train' only.
    text_len: The maximum text length (in tokens), handed to
      `combine_and_keep_train`.

  Returns:
    The ConfigDict for the input section, with `data` and `pp` set.
  """
  variant = 'nonum' if ONLY_NON_NUMERIC_ANSWERS else 'all'
  split = 'train + val' if final_split else 'train'

  cfg = bvcc.parse_arg('')  # Just make a configdict without extra import.
  cfg.data = {'name': f'rsvqa_hr/{variant}', 'split': split}

  pp_ops = [
      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
      # Answers in reference evaluation are converted to lower case.
      # See https://rrc.cvc.uab.es/?ch=17&com=tasks or paper.
      'lower(key="answers")',
      'strfmt("answer en {question}", outkey="prefix")',
      'choice_no_replacement(inkey="answers", outkey="suffix")',
      combine_and_keep_train(text_len),
  ]
  cfg.pp = '|'.join(pp_ops)
  return cfg
def add_eval_pplx(c, res, text_len=32):
  """Registers perplexity evaluators that reuse the training input settings.

  Args:
    c: Config whose `evals` mapping receives the evaluator entries.
    res: Image resolution handed to `training_data`.
    text_len: The maximum text length handed to `training_data`.
  """
  train_cfg = training_data(res, final_split=True, text_len=text_len)
  splits = {'minitrain': 'train[:1280]', 'minival': 'val'}
  for name, split in splits.items():
    c.evals[f'rsvqa_hr/{name}/pplx'] = {
        'type': 'proj.paligemma.perplexity',
        'pred': 'logits',
        'key': 'text',
        'shift_labels': True,
        'log_percent': 0.1,
        # Same dataset and preprocessing as training, only the split differs.
        'data': {**train_cfg.data, 'split': split},
        'pp_fn': train_cfg.pp,
    }
+ + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False) + + c.input = training_data(c.res, final_split=c.final_split) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 1 + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 1e-5 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, batch_size=1024) + add_eval_pplx(c, c.res) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. 
def metrics(arg=None):  # pylint: disable=unused-argument
  """Lists the metric names of interest for the RSVQA-HR transfer.

  Args:
    arg: Unused.

  Returns:
    'training_loss' followed, per eval split, by its answer metrics and, for
    the splits that have a perplexity evaluator, its average perplexity.
  """
  answer_metrics = ('acc_any', 'acc_avg', 'acc_avg_nonum', 'anls')
  # `add_eval_pplx` only registers perplexity evaluators for these splits, so
  # 'test' and 'test2' have no 'pplx/avg' metric to list.
  pplx_splits = ('minitrain', 'minival')

  m = ['training_loss']
  for split in ('minival', 'test', 'test2', 'minitrain'):
    m += [f'rsvqa_hr/{split}/{name}' for name in answer_metrics]
    if split in pplx_splits:
      m.append(f'rsvqa_hr/{split}/pplx/avg')
  return m
+ + +def training_data(res, *, final_split, text_len=32): + """Creates training data config. + + See Colab: + http://(internal link)#scrollTo=1jZ-9FMPVD-q + + You can add more arguments beside `res`, but give them good defaults. + + Args: + res: The requested image resolution (eg 224). + final_split: Whether to use all of the validation data. + text_len: The maximum text length (in tokens). + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='rsvqa_lr/nonum' if ONLY_NON_NUMERIC_ANSWERS else 'rsvqa_lr/all', + split='train + val' if final_split else 'train', + ) + c.pp = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + # Answers in reference evaluation are converted to lower case. + # See https://rrc.cvc.uab.es/?ch=17&com=tasks or paper. + 'lower(key="answers")', + 'strfmt("answer en {question}", outkey="prefix")', + 'choice_no_replacement(inkey="answers", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, text_len=32, **kw): + """Add evaluators.""" + c_train = training_data(res, final_split=True, text_len=text_len) # Use mostly same settings as training. + + pp = '|'.join([ + f'decode|resize({res})|value_range(-1, 1)', + # Answers in reference evaluation are converted to lower case. + # See https://rrc.cvc.uab.es/?ch=17&com=tasks or paper. + 'lower(key="answers")', + 'strfmt("answer en {question}", outkey="prefix")', + combine_and_keep_eval(text_len, keep=('answers', 'question_type', 'question_id')), + ]) + + for freq, name, split in [ + (0.1, 'minitrain', 'train[:1280]'), + (0.1, 'minival', 'val'), + (0.1, 'test', 'test'), + ]: + c.evals[f'rsvqa_lr/{name}'] = dict( + type='proj.paligemma.transfers.rsvqa', + pred='decode', pred_kw={'max_decode_len': text_len}, + # Answers in reference evaluation are converted to lower case. + # See https://rrc.cvc.uab.es/?ch=17&com=tasks or paper. 
+ to_lower=True, + outfile=f'{{workdir}}/rsvqa_lr_{name}.json', + data={**c_train.data, 'split': split}, + log_percent=freq, tokenizer=TOKENIZER, pp_fn=pp) + c.evals[f'rsvqa_lr/{name}'].update(kw) + + +def add_eval_pplx(c, res, text_len=32): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, final_split=True, text_len=text_len) # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train[:1280]'), + ('minival', 'val'), + ]: + c.evals[f'rsvqa_lr/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.1, + data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + # Note: best performance achieved at 224. + add(lr=3e-6, wd=0.0, total_epochs=3, **bvcc.arg(res=224, **c)) + add(lr=3e-6, wd=0.0, total_epochs=3, **bvcc.arg(res=448, **c)) + # add(lr=3e-6, wd=0.0, total_epochs=3, **bvcc.arg(res=896, **c)) + + +sweep = sweep_best # Choose which sweep to run. + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False) + + c.input = training_data(c.res, final_split=c.final_split) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 3 + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 3e-6 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.2 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, batch_size=1024) + add_eval_pplx(c, c.res) + + # Model section. 
+ c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + del c.total_epochs + c.total_steps = 10 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('minival', 'test', 'minitrain'): + m.append(f'rsvqa_lr/{split}/acc_any') + m.append(f'rsvqa_lr/{split}/acc_avg') + m.append(f'rsvqa_lr/{split}/acc_avg_nonum') + m.append(f'rsvqa_lr/{split}/anls') + m.append(f'rsvqa_lr/{split}/pplx/avg') + return m diff --git a/big_vision/configs/proj/paligemma/transfers/scicap.py b/big_vision/configs/proj/paligemma/transfers/scicap.py new file mode 100644 index 0000000000000000000000000000000000000000..51618033ef6d70db125b5e5cef092cc335cae9ae --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/scicap.py @@ -0,0 +1,167 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to SciCap. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + +# This feature name is also hardcoded in the coco_captions evaluator. +CAPTION_FEATURE = 'caption/lowercase_and_token_and_remove_figure_index' + + +def training_data(res, *, final_split, text_len=96): + """Creates training data config. + + See (internal link) + You can add more arguments beside `res`, but give them good defaults. + + Args: + res: The requested image resolution (e.g. 224). + final_split: Train on all train+val data. + text_len: sequence length. + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='scicap/first_sentence_subfig_no', + split='train+val' if final_split else 'train', + ) + c.pp = '|'.join([ + 'decode', f'resize({res})', 'value_range(-1, 1)', + 'strfmt("caption en", outkey="prefix")', + f'copy(inkey="{CAPTION_FEATURE}", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, text_len=96, **kw): + """SciCap evaluator (BLEU-4).""" + pp = '|'.join([ + f'reshape([1], inkey="{CAPTION_FEATURE}", outkey="captions")', # GT for evaluator. 
+ 'decode', f'resize({res})', 'value_range(-1, 1)', + 'strfmt("caption en", outkey="prefix")', + combine_and_keep_eval(text_len, keep=('image/id', 'captions')), + ]) + + for name, split in [ + ('minitrain', 'train[:5%]'), # To gauge memorization. + ('val', 'val'), # To tune hparams + ('test', 'test'), + ]: + c.evals[f'scicap/{name}'] = dict( + type='proj.paligemma.transfers.coco_caption', + pred='decode', pred_kw={'max_decode_len': text_len}, + metrics=('cider', 'bleu-4'), + data=dict( + name='scicap/first_sentence_subfig_no', + split=split, + ), + log_percent=0.3, tokenizer=TOKENIZER, pp_fn=pp) + c.evals[f'scicap/{name}'].update(kw) + + +def add_eval_pplx(c, res, text_len=96): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, final_split=True, text_len=text_len) # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train[:5%]'), # To gauge memorization. + ('val', 'val'), # To tune hparams. + ('test', 'test'), + ]: + c.evals[f'scicap/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.05, # Eval ~20x per run; it's cheap. + data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + add(**bvcc.arg(res=224, **c)) + add(**bvcc.arg(res=448, **c)) + # 896 is too slow and not better than 448: (internal link) + + +sweep = sweep_best + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False) + + c.input = training_data(c.res, final_split=c.final_split) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. 
+ c.total_epochs = 80 + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 3e-5 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.1 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, batch_size=1024) + add_eval_pplx(c, c.res) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.1) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('test', 'val', 'minitrain'): + m.append(f'scicap/{split}/pplx/avg') + m.append(f'scicap/{split}/cider') + return m diff --git a/big_vision/configs/proj/paligemma/transfers/science_qa.py b/big_vision/configs/proj/paligemma/transfers/science_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8777b63f707a98b922be9a6eb6ee90834d27de --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/science_qa.py @@ -0,0 +1,225 @@ +# Copyright 2024 Big Vision Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to science_qa. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + + +def training_data(res, *, final_split, text_len=512, qfmt='QCM', afmt='A'): + """Creates training data config. + + See + (internal link) + You can add more arguments beside `res`, but give them good defaults. + implemented based on : + https://github.com/lupantech/ScienceQA/blob/main/models/base_prompt.py + default prompt format baseline: QCM -> A (see paper) + Args: + res: The requested image resolution (eg 224). + final_split: Train on all train+val data. + text_len: sequence length. + qfmt: see config_prompt_format. + afmt: see config_prompt_format. + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. 
+ c.data = dict( + name='science_qa', + split='train+val' if final_split else 'train', + ) + qfmt, afmt = config_prompt_format(qfmt, afmt) + c.pp = '|'.join([ + # Read and prepare the image by just resizing it: + f'decode|resize({res})|value_range(-1, 1)', + 'drop("indexed_choices","indexed_answer")', + "sci_qa_choices_shuffle(choice_str_inkey='choices', ans_inkey='answer')", + f'strfmt("{qfmt}", outkey="prefix")', + f'strfmt("{afmt}", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def config_prompt_format(qfmt='QCM', afmt='A'): + """Configure prompt format, default: QCM -> A. + + See https://github.com/lupantech/ScienceQA/blob/main/models/base_prompt.py + Args: + qfmt: input prompt format -> the question text (Q), the context text + (C), and multiple options (M) + afmt: out prompt format -> A = answer, AE = answer with + explanation, ALE = answer with lecture and explanation. + + Returns: + prompt format string for training data config to digest + """ + # TODO: b/keranrong - The \\nAnswer: is useless for our model. + if qfmt == 'simple': + qfmt = '{question}\\nOptions: {indexed_choices}' + elif qfmt == 'QM': + qfmt = 'Question: {question}\\nOptions: {indexed_choices}\\nAnswer:' + elif qfmt == 'CQM': + qfmt = 'Context: {hint}\\nQuestion: {question}\\nOptions: {indexed_choices}\\nAnswer:' + elif qfmt == 'QCM': + qfmt = 'Question: {question}\\nContext: {hint}\\nOptions: {indexed_choices}\\nAnswer:' + else: + raise ValueError(qfmt) + + if afmt == 'simple': + afmt = '{indexed_answer}' + elif afmt == 'A': + afmt = 'The answer is {indexed_answer}.' + elif afmt == 'AL': + afmt = 'The answer is {indexed_answer}. BECAUSE: {solution}' + elif afmt == 'AE': + afmt = 'The answer is {indexed_answer}. BECAUSE: {lecture}' + elif afmt == 'ALE': + afmt = 'The answer is {indexed_answer}. 
BECAUSE: {lecture} {solution}' + else: + raise ValueError(afmt) + + return qfmt, afmt + + +def add_eval(c, res, text_len=512, qfmt='QCM', afmt='A', **kw): + """Science QA evaluators.""" + prefix, suffix = config_prompt_format(qfmt, afmt) + pp = '|'.join([ + f'decode|resize({res})|value_range(-1, 1)', + f'strfmt("{prefix}", outkey="prefix")', + f'strfmt("{suffix}", outkey="answer")', + 'copy(inkey="_id",outkey="question_id")', + combine_and_keep_eval(text_len, keep=('answer', 'question_id')), + ]) + + for name, split in [ + ('minitrain', 'train[:5%]'), # To gauge memorization. + ('minival', 'val'), # To tune hparams. + ('eval', 'test'), + ]: + c.evals[f'science_qa/{name}'] = dict( + type='proj.paligemma.transfers.science_qa', + pred='decode', pred_kw={'max_decode_len': text_len}, + outfile=f'{{workdir}}/science_{name}.json', + data={**training_data(res, final_split=True, text_len=text_len, qfmt=qfmt, afmt=afmt).data, 'split': split}, + log_percent=1/8, tokenizer=TOKENIZER, pp_fn=pp) + c.evals[f'science_qa/{name}'].update(kw) + + +def add_eval_pplx(c, res, text_len=512, qfmt='QCM', afmt='A'): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, final_split=True, text_len=text_len, qfmt=qfmt, afmt=afmt) # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train[:5%]'), # To gauge memorization. + ('minival', 'val'), # To tune hparams. + ]: + c.evals[f'science_qa/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.05, # Eval ~20x per run; it's cheap. + data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + # Based on sweep xids/98045006 (science_qa/eval/acc). 
+ # TODO: b/keranrong - try getting rid of freezing + add(lr=1e-5, wd=0, **bvcc.arg(freeze_vit=True, res=224, **c)) + add(lr=1e-5, wd=0, **bvcc.arg(freeze_vit=True, res=448, **c)) + + +sweep = sweep_best # Choose which sweep to run. + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=448, final_split=False, freeze_vit=False, freeze_llm=False, qfmt='QCM', afmt='A') + + c.input = training_data(c.res, final_split=c.final_split, qfmt=c.qfmt, afmt=c.afmt) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 20 + c.input.batch_size = 128 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 5e-5 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + + # Learning-rate schedule. Probably is fine like this. + sched = dict(decay_type='cosine', warmup_percent=0.05) + c.schedule = [ + ('img/.*', None if c.freeze_vit else sched), + ('llm/.*', None if c.freeze_llm else sched), + ] + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, qfmt=c.qfmt, afmt=c.afmt, batch_size=1024) + add_eval_pplx(c, c.res, qfmt=c.qfmt, afmt=c.afmt) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = [ + 'ops_general', + 'ops_image', + 'ops_text', + 'proj.paligemma.ops', + 'proj.paligemma.sciqa_ops', + ] + + # Update configs for quicker local runs and avoid swapping. 
+ if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('eval', 'minival', 'minitrain'): + m.append(f'science_qa/{split}/pplx/avg') + m.append(f'science_qa/{split}/acc') + return m diff --git a/big_vision/configs/proj/paligemma/transfers/screen2words.py b/big_vision/configs/proj/paligemma/transfers/screen2words.py new file mode 100644 index 0000000000000000000000000000000000000000..da0665291e0ed1bab103429694bda53034348f57 --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/screen2words.py @@ -0,0 +1,164 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to Screen2words. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + + +def training_data(res, text_len=24, final_split=False): + """Creates training data config. + + See (internal link) + You can add more arguments beside `res`, but give them good defaults. + + Args: + res: The requested image resolution (e.g. 224). + text_len: sequence length + final_split: Train also on val or not. 
+ + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='screen2_words', + split='train+dev' if final_split else 'train', + ) + c.pp = '|'.join([ + 'decode', f'resize({res})', 'value_range(-1, 1)', + 'strfmt("caption en", outkey="prefix")', + 'copy(inkey="image/id", outkey="_id")', + 'choice_no_replacement(inkey="summary", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, text_len=24, **kw): + """Screen2words evaluator (CIDER).""" + pp = '|'.join([ + 'copy("summary", "captions")', # GT for evaluator. + 'decode', f'resize({res})', 'value_range(-1, 1)', + 'strfmt("caption en", outkey="prefix")', + combine_and_keep_eval(text_len, keep=('image/id', 'captions')), + ]) + + for name, split in [ + ('minitrain', 'train[:5%]'), # To gauge memorization. + ('dev', 'dev'), # To tune hparams + ('test', 'test'), + ]: + c.evals[f'screen2words/{name}'] = dict( + type='proj.paligemma.transfers.coco_caption', + pred='decode', pred_kw={'max_decode_len': text_len}, + data=dict( + name='screen2_words', + split=split, + ), + log_percent=0.1, tokenizer=TOKENIZER, pp_fn=pp) + c.evals[f'screen2words/{name}'].update(kw) + + +def add_eval_pplx(c, res, text_len=24): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, text_len) # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train[:5%]'), # To gauge memorization. + ('dev', 'dev'), # To tune hparams. + ('test', 'test'), + ]: + c.evals[f'screen2words/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.05, # Eval ~20x per run; it's cheap. 
+ data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + add(**bvcc.arg(res=224, **c)) + add(**bvcc.arg(res=448, **c)) + # 896 not better than 448. + + +sweep = sweep_best # Choose which sweep to run. + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False) + + c.input = training_data(c.res, final_split=c.final_split) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 10 + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 1e-5 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.2 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, batch_size=1024) + add_eval_pplx(c, c.res) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.3) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. 
+ if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('test', 'dev', 'minitrain'): + m.append(('epoch', f'screen2words/{split}/cider')) + m.append(('epoch', f'screen2words/{split}/pplx/avg')) + return m diff --git a/big_vision/configs/proj/paligemma/transfers/stvqa.py b/big_vision/configs/proj/paligemma/transfers/stvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3be64e82cad0c29960fa06959340b0ce2c1231 --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/stvqa.py @@ -0,0 +1,175 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to Scene Text VQA (ST-VQA). +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + + +def training_data(res, final_split, text_len=48): + """Creates training data config. + + See Colab: + http://(internal link)#scrollTo=1jZ-9FMPVD-q + + You can add more arguments beside `res`, but give them good defaults. + + Args: + res: The requested image resolution (eg 224). 
+ final_split: Whether to use all train data. + text_len: The maximum text length (in tokens). + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='stvqa', + split='train+val' if final_split else 'train', + ) + c.pp = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + # Answers in reference evaluation are converted to lower case. + # See https://rrc.cvc.uab.es/?ch=17&com=tasks or paper. + 'lower(key="answers")', + 'strfmt("answer en {question}", outkey="prefix")', + 'choice_no_replacement(inkey="answers", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, text_len=48, **kw): + """Add evaluators.""" + pp = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + # Answers in reference evaluation are converted to lower case. + # See https://rrc.cvc.uab.es/?ch=17&com=tasks or paper. + 'lower(key="answers")', + 'strfmt("answer en {question}", outkey="prefix")', + combine_and_keep_eval(text_len, keep=('answers', 'question_id')), + ]) + + for freq, name, split in [ + (0.1, 'minitrain', 'train[:5%]'), # To gauge memorization. ~2.5k samples. + (0.1, 'minival', 'val'), # Pseudo-test. + (1.0, 'test', 'test'), # Also stores predictions for the test set (4.1k). + ]: + c.evals[f'stvqa/{name}'] = dict( + type='proj.paligemma.transfers.vqa', + pred='decode', pred_kw={'max_decode_len': text_len}, + # Answers in reference evaluation are converted to lower case. + # See https://rrc.cvc.uab.es/?ch=17&com=tasks or paper. 
+ to_lower=True, + outfile=f'{{workdir}}/stvqa_{name}.json', + data={**training_data(res, False, text_len).data, 'split': split}, + log_percent=freq, skip_first=freq == 1, tokenizer=TOKENIZER, pp_fn=pp) + c.evals[f'stvqa/{name}'].update(kw) + + +def add_eval_pplx(c, res, text_len=48): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, False, text_len) # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train[:5%]'), # To gauge memorization. + ('minival', 'val'), # To tune hparams. + ]: + c.evals[f'stvqa/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.05, # Eval ~20x per run; it's cheap. + data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + # Based on http://(internal link)/ZCwPqz0b3tE and (internal link) + c = bvcc.parse_arg(arg, mode='xm', final_split=False) + add(lr=1e-5, wd=1e-6, total_epochs=3, **bvcc.arg(res=224, **c)) + add(lr=1e-5, wd=1e-6, total_epochs=3, **bvcc.arg(res=448, **c)) + add(lr=3e-6, wd=3e-7, total_epochs=3, **bvcc.arg(res=896, **c)) + + +sweep = sweep_best # Choose which sweep to run. + + +def get_config(arg=None): + """Config for training.""" + # max prefix: 29, suffix: 50, prefix+sep+suffix: 59 + c = bvcc.parse_arg(arg, mode='xm', res=896, text_len=48, final_split=False) + + c.input = training_data(c.res, c.final_split, c.text_len) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 5 + c.input.batch_size = {224: 256, 448: 128, 896: 32}[c.res] + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 3e-6 + c.wd = 3e-7 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.1 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. 
+ c.evals = {} + add_eval(c, c.res, c.text_len, batch_size=256) + add_eval_pplx(c, c.res, c.text_len) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + del c.total_epochs + c.total_steps = 10 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('minival', 'minitrain'): + m.append(f'stvqa/{split}/anls') + m.append(f'stvqa/{split}/acc') + m.append(f'stvqa/{split}/pplx/avg') + return m diff --git a/big_vision/configs/proj/paligemma/transfers/tallyqa.py b/big_vision/configs/proj/paligemma/transfers/tallyqa.py new file mode 100644 index 0000000000000000000000000000000000000000..738a95226e4fe31a558799e39a03e3c67dac0300 --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/tallyqa.py @@ -0,0 +1,191 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to TallyQA. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + + +def training_data(res, text_len=32): + """Creates training data config. + + See (internal link) + You can add more arguments beside `res`, but give them good defaults. + + Args: + res: The requested image resolution (eg 224) + text_len: sequence length + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='tallyqa', + split='train', + ) + + c.pp = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + 'strfmt("answer en {question}", outkey="prefix")', + 'strfmt("{answer}", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def countbenchqa_eval_data(res, text_len=32): + """Creates eval data config for CountBenchQA.""" + c = bvcc.parse_arg('') # Just make a configdict without extra import. 
+ c.data = dict( + name='countbenchqa', + split='huggingface', + ) + c.pp = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + 'strfmt("answer en {question}", outkey="prefix")', + 'strfmt("{number}", outkey="answer")', + combine_and_keep_eval(text_len, keep=('answer',)), + ]) + return c + + +def add_eval(c, res, text_len=32, **kw): + """Add eval configs.""" + tallyqa_pp_eval = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + 'strfmt("answer en {question}", outkey="prefix")', + 'strfmt("{answer}", outkey="answer")', + combine_and_keep_eval(text_len, keep=('answer', 'issimple')), + ]) + + for freq, name, split in [ + (0.1, 'minitrain', 'train[:5%]'), + # (0.1, 'minival', 'train[-5%:]'), + (1/4, 'eval', 'test'), + ]: + c.evals[f'tallyqa/{name}'] = dict( + type='proj.paligemma.transfers.tallyqa', + pred='decode', pred_kw={'max_decode_len': text_len}, + data={**training_data(res, text_len).data, 'split': split}, + log_percent=freq, tokenizer=TOKENIZER, pp_fn=tallyqa_pp_eval) + c.evals[f'tallyqa/{name}'].update(kw) + + # CountBenchQA eval. We use the TallyQA eval for this but just pass in + # different data. + c.evals['countbenchqa/eval'] = dict( + type='proj.paligemma.transfers.tallyqa', + pred='decode', pred_kw={'max_decode_len': text_len}, + data=countbenchqa_eval_data(res, text_len).data, + log_percent=0.1, # This is a very small and cheap eval set. + tokenizer=TOKENIZER, + pp_fn=countbenchqa_eval_data(res, text_len).pp) + c.evals['countbenchqa/eval'].update(kw) + + +def add_eval_pplx(c, res, text_len=32): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, text_len) # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train[:5%]'), # To gauge memorization. + # ('minival', 'train[-5%:]'), # To tune hparams. + ('eval', 'test'), # To compute final publishable scores. 
+ ]: + c.evals[f'tallyqa/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.1, # Eval ~10x per run; it's cheap. + data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def sweep_best(add, arg=None): # pylint: disable=unused-argument + """Train with best hyper-params.""" + add(total_epochs=2, lr=1e-5, wd=0.00, **bvcc.arg(res=224)) + add(total_epochs=2, lr=1e-5, wd=1e-6, **bvcc.arg(res=448)) + add(total_epochs=2, lr=7e-6, wd=7e-7, **bvcc.arg(res=896)) + + +sweep = sweep_best # Choose which sweep to run. + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=224) + + c.input = training_data(c.res) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 2 + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 1e-5 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, batch_size=256) + add_eval_pplx(c, c.res) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. 
+ if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + """Returns a list of metric names.""" + return [ + 'training_loss', + 'countbenchqa/eval/acc', + 'tallyqa/minitrain/pplx/avg', + 'tallyqa/eval/pplx/avg', + 'tallyqa/eval/acc', + 'tallyqa/eval/acc/complex', + 'tallyqa/eval/acc/simple', + ] diff --git a/big_vision/configs/proj/paligemma/transfers/textcaps.py b/big_vision/configs/proj/paligemma/transfers/textcaps.py new file mode 100644 index 0000000000000000000000000000000000000000..7249df806010fe0fd0ec80697df44a9f1816c5b4 --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/textcaps.py @@ -0,0 +1,181 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to TextCaps captioning task. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + + +def training_data(res, *, final_split, text_len=32, crop='rs'): + """Creates training data config. + + See (internal link) + You can add more arguments beside `res`, but give them good defaults. 
+ + Args: + res: The requested image resolution (eg 224). + final_split: Train on all train data or train[:98%]. + text_len: sequence length. + crop: one of {'ic', 'rc', 'rs'}. + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='text_caps', + split='train' if final_split else 'train[:98%]', + ) + + if crop == 'ic': + crop = f'inception_crop({res}, area_min=50)' + elif crop == 'rc': + crop = f'resize_small({res*8//7})|random_crop({res})' + elif crop == 'rs': + crop = f'resize({res})' + else: + raise ValueError(f'Unknown crop: {crop}') + + c.pp = '|'.join([ + 'flatten', + 'decode', crop, 'value_range(-1, 1)', + 'choice_no_replacement(inkey="texts", outkey="suffix")', + 'strfmt("caption en", outkey="prefix")', + 'lower(key="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, text_len=32, **kw): + """Captioning evaluator with cider/bleu-4/meteor/rouge/spice metrics.""" + # Input eval pp without ground truth text and random crop. + pp_eval = '|'.join([ + 'decode', f'resize({res})', 'value_range(-1, 1)', + 'flatten', 'copy("texts", "captions")', # GT for evaluator. + 'strfmt("caption en", outkey="prefix")', + combine_and_keep_eval(text_len, keep=('image/id', 'captions')), + ]) + + for name, split in [ + ('minitrain', 'train[:2%]'), + ('minival', 'train[-2%:]'), + ('eval', 'val'), + ]: + c.evals[f'textcaps/{name}'] = dict( + type='proj.paligemma.transfers.coco_caption', + pred='decode', pred_kw={'max_decode_len': text_len}, + log_percent=0.1, tokenizer=TOKENIZER, + data={'name': 'text_caps', 'split': split}, + pp_fn=pp_eval, + ) + c.evals[f'textcaps/{name}'].update(kw) + + +def add_eval_pplx(c, res, text_len=32): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, final_split=True, text_len=text_len, crop='rs') # Use mostly same settings as training. 
+ for name, split in [ + ('minitrain', 'train[:2%]'), + ('minival', 'train[-2%:]'), + ('eval', 'val'), + ]: + c.evals[f'textcaps/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.05, # Eval ~20x per run; it's cheap. + data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', crop='rs', res=224, beam_size=3, final_split=False) + + c.input = training_data(c.res, final_split=c.final_split, crop=c.crop) + + c.total_epochs = 5 # Note each example has 5 captions. + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 1e-5 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval_pplx(c, c.res) + + if c.beam_size: + decode_kw = {'pred': 'beam_decode', 'pred_kw': {'beam_size': c.beam_size}} + else: + decode_kw = {} + + add_eval(c, c.res, batch_size=1024, **decode_kw) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. 
+ if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + # Note: wd=0.0 probably works as good. + add(lr=1e-5, wd=1e-6, total_epochs=5, **bvcc.arg(res=224, **c)) + add(lr=1e-5, wd=1e-6, total_epochs=5, **bvcc.arg(res=448, **c)) + + +sweep = sweep_best # Choose which sweep to run. + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('eval', 'minival', 'minitrain'): + m.append(('epoch', f'textcaps/{split}/cider')) + m.append(('epoch', f'textcaps/{split}/pplx/avg')) + return m diff --git a/big_vision/configs/proj/paligemma/transfers/textvqa.py b/big_vision/configs/proj/paligemma/transfers/textvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..0ddbea6e01221a390012f362da7e9375329bfb7f --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/textvqa.py @@ -0,0 +1,163 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to TextVQA. 
+""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + + +def training_data(res, *, final_split, text_len=32): + """Creates training data config. + + See (internal link) + You can add more arguments beside `res`, but give them good defaults. + + Args: + res: The requested image resolution (eg 224). + final_split: Train on all train+val data. + text_len: sequence length. + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='textvqa', + split='train+val' if final_split else 'train', + ) + c.pp = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + 'strfmt("answer en {question}", outkey="prefix")', + 'choice_no_replacement(inkey="answers", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, text_len=32, **kw): + """TextVQA evaluators.""" + pp = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + 'strfmt("answer en {question}", outkey="prefix")', + # TextVQA does not have a "answer_type", but if we use an answer_type + # that is used in VQAv2, we can just reuse that evaluator. Use "other". + 'strfmt("other", outkey="answer_type")', + combine_and_keep_eval(text_len, keep=('answers', 'answer_type', 'question_id')), + ]) + + for freq, name, split in [ + (1/8, 'minitrain', 'train[:5120]'), # To gauge memorization. + (1/8, 'eval', 'val'), # To tune hparams. Only 5k samples in val. 
+ (1.0, 'test', 'test'), + ]: + c.evals[f'textvqa/{name}'] = dict( + type='proj.paligemma.transfers.vqav2', + pred='decode', pred_kw={'max_decode_len': text_len}, + outfile=f'{{workdir}}/textvqa_{name}.json', + data={**training_data(res, final_split=True, text_len=text_len).data, + 'split': split}, + log_percent=freq, skip_first=freq == 1, tokenizer=TOKENIZER, pp_fn=pp) + c.evals[f'textvqa/{name}'].update(kw) + + +def add_eval_pplx(c, res, text_len=32): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, final_split=True, text_len=text_len) # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train[:5%]'), # To gauge memorization. + ('eval', 'val'), # To tune hparams. + ]: + c.evals[f'textvqa/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=0.05, # Eval ~20x per run; it's cheap. + data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + add(lr=3e-6, wd=0, total_epochs=5, **bvcc.arg(res=224, **c)) + add(lr=3e-6, wd=3e-7, total_epochs=10, **bvcc.arg(res=448, **c)) + add(lr=3e-6, wd=0, total_epochs=10, **bvcc.arg(res=896, **c)) + + +sweep = sweep_best # Choose which sweep to run. + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False) + + c.input = training_data(c.res, final_split=c.final_split) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 5 # For 448, 10 epochs seems to work better. + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 3e-6 + c.wd = 0.0 # wd doesn't seem to matter much, sometimes 0.1*lr is a bit better. 
+ c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, batch_size=256) + add_eval_pplx(c, c.res) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 50_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('eval', 'minitrain'): + m.append(f'textvqa/{split}/acc') + m.append(f'textvqa/{split}/pplx/avg') + return m diff --git a/big_vision/configs/proj/paligemma/transfers/vatex_cap.py b/big_vision/configs/proj/paligemma/transfers/vatex_cap.py new file mode 100644 index 0000000000000000000000000000000000000000..78ae9f66da6bb2b4db2a6d3039b3d8f174f8de62 --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/vatex_cap.py @@ -0,0 +1,210 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to VATEX Video captioning. + +IMPORTANT: This config is based on an unreleased version of DeepMind Video +Readers (DMVR). Users can either set up DMVR using the open source code from +GitHub (see below for details), or add their own data loader of choice. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + +TEXT_LEN = 64 +DATASET_NAME = 'vatex' +# Numbers might need to be updated due to wipeout. Current from 2024-04-28 +SPLIT_SIZE = {'train': 22315, 'valid': 2584, 'test': 5135} + + +def training_data(res, *, final_split, num_frames=8, stride=None): + """Creates training data config. + + Args: + res: The requested image resolution (eg 224). + final_split: Train on all train+valid data. + num_frames: number of sampled frames per video. + stride: stride at which the frames are sampled. + + Returns: + The ConfigDict for the input section. 
+ """ + pp = '|'.join([ + # prepare the frames by decoding, resizing, replicating, sampling: + f'video_decode({res})|video_replicate_img({num_frames},{num_frames})', + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + # pick one caption at random during training + 'strfmt("caption en", outkey="prefix")', + 'video_choice(inkey="caption/string", outkey="suffix")', + combine_and_keep_train(TEXT_LEN), + ]) + + c = bvcc.parse_arg('') + c.data = {} + splits = ['train', 'valid'] if final_split else ['train'] + raise NotImplementedError('Please implement a video reader of choice!') + # For example DMVR https://github.com/google-deepmind/dmvr + # The reader should support the following arguments: + # - name: Name of the reader. + # - dataset_name: Name of the data set. + # - split: Data set split. + # - num_frames: Number of frames sampled from the video. + # - stride: Stride at which the video frames are sampled. + # - deterministic_fs: Whether to sample the frames starting at the first + # frame or whether an offset should be chosen at random (if there are more + # frames than num_frames * stride) + # - first_k_shards: Whether to only use the first k shards of the data + # (optional but useful for speeding up intermediate evaluations). 
+ for split in splits: + c.data[split] = SPLIT_SIZE[split] + c[split] = {'pp': pp} + c[split].data = dict( + # PLEASE ADD YOUR READER HERE: + name='', + dataset_name=DATASET_NAME, split=split, + num_frames=num_frames, stride=stride, + deterministic_fs=False) + return c + + +def add_eval(c, res, num_frames=8, stride=None): + """Captioning evaluator.""" + c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride) + + pp = '|'.join([ + # prepare the frames by decoding, resizing, replicating, sampling: + f'video_decode({res})|video_replicate_img({num_frames},{num_frames})', + f'video_ensure_shape("image", {(num_frames, res, res, 3)})', + 'strfmt("caption en", outkey="prefix")', + 'copy("example/video_id", "image/id")', + 'copy("caption/string", "captions")', + combine_and_keep_eval(TEXT_LEN, keep=('image/id', 'captions')), + ]) + + for freq, name, split, first_k_shards, skip_first_eval in [ + (1/8, 'minitrain', 'train', 2, False), # To gauge memorization. + (1/4, 'minival', 'valid', 2, False), # To monitor val progress. + (1, 'val', 'valid', None, False), # To tune hparams. + (1, 'eval', 'test', None, False), # final metric + ]: + c.evals[f'{DATASET_NAME}/{name}'] = dict( + type='proj.paligemma.transfers.coco_caption', + pred='decode', pred_kw={'max_decode_len': TEXT_LEN}, + data={**c_train.train.data, 'split': split, + 'first_k_shards': first_k_shards, + 'deterministic_fs': True}, + log_percent=freq, tokenizer=TOKENIZER, + pp_fn=pp, skip_first=skip_first_eval) + + +def add_eval_pplx(c, res, num_frames=8, stride=None): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, final_split=True, num_frames=num_frames, stride=stride) + + for name, split, first_k_shards in [ + ('minitrain', 'train', 2), # To gauge memorization. 
+ ]: + c.evals[f'{DATASET_NAME}/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=1/8, # Not too cheap, do ~8x per run. + data={**c_train.train.data, 'split': split, + 'first_k_shards': first_k_shards, + 'deterministic_fs': True}, + pp_fn=c_train.train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + add(lr=3e-6, wd=3e-7, total_epochs=10, **bvcc.arg(res=224, **c)) + + +sweep = sweep_best + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', num_frames=16, stride=7, res=224, + freeze_vit=False, freeze_llm=False, final_split=False) + + c.input = training_data( + c.res, final_split=c.final_split, + num_frames=c.num_frames, stride=c.stride) + + c.total_epochs = 3 + c.input.batch_size = 128 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 3e-6 + c.wd = 3e-7 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + + # Learning-rate schedule. + sched = dict(decay_type='cosine', warmup_percent=0.05) + c.schedule = [ + ('img/.*', None if c.freeze_vit else sched), + ('llm/.*', None if c.freeze_llm else sched), + ] + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, c.num_frames, c.stride) + add_eval_pplx(c, c.res, c.num_frames, c.stride) + + # Model section. + c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. 
+ c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + for split in c.input.data.keys(): + c.input[split].shuffle_buffer_size = 10_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops', + 'proj.paligemma.video'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + for split in c.input.data.keys(): + c.input[split].shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.first_k_shards = 1 + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('minitrain', 'minival', 'val', 'eval'): + m.append((f'{DATASET_NAME}/{split}/cider')) + for split in ('minitrain', 'minival'): + m.append((f'{DATASET_NAME}/{split}/pplx/avg')) + return m + diff --git a/big_vision/configs/proj/paligemma/transfers/vertexai_l4.py b/big_vision/configs/proj/paligemma/transfers/vertexai_l4.py new file mode 100644 index 0000000000000000000000000000000000000000..503f2eebf5d9f0ff0e4c24b2ad343d519bb0fc95 --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/vertexai_l4.py @@ -0,0 +1,115 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pylint: disable=line-too-long +r"""PaliGemma transfer to a task stored in JSON-L, designed to fit on an L4 GPU. +""" + +import big_vision.configs.common as bvcc + + +def training_data(res, text_len): + """Creates training data config.""" + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='bv:jsonl', + fname='gs://longcap100/data_train90.jsonl', + fopen_keys={'image': 'gs://longcap100/'}, + # See docstring in datasets/jsonl.py for further details. + # download_keys=['image'], # If jsonl contains external paths. + ) + c.pp = '|'.join([ + # Read and prepare the image by just resizing it: + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + # The texts are already prepared in `prefix` and `suffix` keys. + 'strfmt("caption en", outkey="prefix")', + combine_and_keep(text_len), + ]) + # Keep the whole dataset in RAM after first pass. Useful optimization for + # small/mid-size datasets, but risks a host OOM for large datasets. + c.cache_raw = True + return c + + +def get_config(arg=None): + """Config for training.""" + # You probably do NOT want to add settings here. The `arg` way of settings is + # really only for things you'd want to sweep and which affect MULTIPLE config + # settings at once or go into the pp string. + c = bvcc.parse_arg(arg, res=224, text_len=128, batch_size=4, + freeze_vit=False, freeze_llm=False) + + c.input = training_data(c.res, c.text_len) + + # These settings are suited for fitting in a single L4. + c.total_epochs = 1 + c.input.batch_size = c.batch_size + c.optax_name = 'big_vision.sgd' # Without momentum, so really low-memory. + c.lr = 0.1 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + + # Learning-rate schedule. Probably is fine like this. + sched = dict(decay_type='cosine', warmup_percent=0.05) + c.schedule = [ + ('img/.*', None if c.freeze_vit else sched), + ('llm/.*', None if c.freeze_llm else sched), + ] + + c.evals = {} + + # Model section. 
+ c.model_name = 'proj.paligemma.paligemma' + c.model = {} + # TODO: b/lbeyer - no scan and no remat might be better on 1-GPU machines? + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + c.input.shuffle_buffer_size = 1000 + c.log_training_steps = 1 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + c.seed = 0 + return c + + +def tok(**kw): + """Creates the tokenization preprocessing string.""" + # Single entry point so that it's consistent everywhere and easier to switch. + kw.setdefault('model', 'gemma(tokensets=("loc", "seg"))') + kw = ', '.join(f'{k}={repr(v)}' for k, v in kw.items()) + return f'tok({kw})' + + +def combine_and_keep(text_len): + return '|'.join([ + tok(key='prefix', bos='yes'), + tok(key='suffix', eos='yes'), + tok(key='septok', text='\n'), + # If masks confuse you, see (internal link) + 'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_loss=[0, 0, 1])', + # For training, we +1 because the trainer removes EOS. + f'tolen({text_len+1}, pad_value=0, key="text")', # For text, value doesn't matter. + f'tolen({text_len+1}, pad_value=1, key="mask_ar")', + f'tolen({text_len+1}, pad_value=0, key="mask_loss")', + 'keep("image", "text", "mask_ar", "mask_loss")', + ]) diff --git a/big_vision/configs/proj/paligemma/transfers/vizwizvqa.py b/big_vision/configs/proj/paligemma/transfers/vizwizvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..58e361b61cf120ab329384e8cd7a1273b921e5ee --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/vizwizvqa.py @@ -0,0 +1,160 @@ +# Copyright 2024 Big Vision Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""PaliGemma transfer to VizWiz VQA. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER + + +def training_data(res, final_split, text_len=48): + """Creates training data config. + + See (internal link) + You can add more arguments beside `res`, but give them good defaults. + + Args: + res: The requested image resolution (eg 224) + final_split: Train on combined train+val + text_len: sequence length + + Returns: + The ConfigDict for the input section. + """ + c = bvcc.parse_arg('') # Just make a configdict without extra import. + c.data = dict( + name='vizwizvqa', + split='train+val' if final_split else 'train', + ) + c.pp = '|'.join([ + f'decode|resize({res}, antialias=True)|value_range(-1, 1)', + 'strfmt("answer en {question}", outkey="prefix")', + 'choice_no_replacement(inkey="answers", outkey="suffix")', + combine_and_keep_train(text_len), + ]) + return c + + +def add_eval(c, res, text_len=48, **kw): + """VizWiz VQA evaluators.""" + pp = '|'.join([ + f'decode|resize({res})|value_range(-1, 1)', + 'strfmt("answer en {question}", outkey="prefix")', + 'copy("image/filename", "question_id")', + combine_and_keep_eval(text_len, keep=('answers', 'question_id')), + ]) + + for freq, name, split in [ + (1/8, 'minitrain', 'train[:5120]'), # To gauge memorization. 
400s on 32v2 + (0.1, 'minival', 'val'), # To tune hparams. 4k samples, full eval is fine. + (1.0, 'test', 'test'), # For the test-server. SLOW. + ]: + c.evals[f'vizwizvqa/{name}'] = dict( + type='proj.paligemma.transfers.vqa', + pred='decode', pred_kw={'max_decode_len': text_len}, + outfile=f'{{workdir}}/vizwiz_{name}.json', + out_question_key='image', + data={**training_data(res, True, text_len).data, 'split': split}, + log_percent=freq, tokenizer=TOKENIZER, pp_fn=pp) + c.evals[f'vizwizvqa/{name}'].update(kw) + + +def add_eval_pplx(c, res, text_len=48): + """Perplexity evaluator to test runs before implementing the real deal.""" + c_train = training_data(res, True, text_len) # Use mostly same settings as training. + for name, split in [ + ('minitrain', 'train'), # To gauge memorization. + ('minival', 'val'), # To tune hparams + ]: + c.evals[f'vizwizvqa/{name}/pplx'] = dict( + type='proj.paligemma.perplexity', pred='logits', + key='text', shift_labels=True, + log_percent=1/8, + data={**c_train.data, 'split': split}, + pp_fn=c_train.pp, + ) + + +def sweep_best(add, arg=None): + """Train with best hyper-params.""" + c = bvcc.parse_arg(arg, final_split=False) + add(**bvcc.arg(res=224, **c)) + add(**bvcc.arg(res=448, **c)) + + +sweep = sweep_best # Choose which sweep to run. + + +def get_config(arg=None): + """Config for training.""" + c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False) + + c.input = training_data(c.res, c.final_split) + + # Instead of epochs, you can also use `total_examples` or `total_steps`. + c.total_epochs = 10 + c.input.batch_size = 256 + c.optax_name = 'scale_by_adam' + c.optax = dict(b2=0.999) + c.lr = 0.00001 + c.wd = 0.0 + c.grad_clip_norm = 1.0 + c.label_smoothing = 0.0 + c.schedule = dict(decay_type='cosine', warmup_percent=0.05) + + # Add evaluators. + c.evals = {} + add_eval(c, c.res, batch_size=1024) + add_eval_pplx(c, c.res) + + # Model section. 
+ c.model_name = 'proj.paligemma.paligemma' + c.model = {} + c.model.img = dict(variant='So400m/14', pool_type='none', scan=True) + c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0) + c.model_init = f'pt_{c.res}' + + # FSDP strategy. + c.mesh = [('data', -1)] + c.sharding_strategy = [('.*', 'fsdp(axis="data")')] + c.sharding_rules = [('act_batch', ('data',))] + + # These probably do not need any change/tuning + c.input.shuffle_buffer_size = 25_000 + c.log_training_steps = 50 + c.ckpt_steps = 1_000 + c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops'] + + # Update configs for quicker local runs and avoid swapping. + if c.mode in ('runlocal', 'mock'): + c.input.shuffle_buffer_size = None + for ev in c.evals.values(): + ev.data.split = ev.data.split.split('[')[0] + '[:16]' + + if c.mode == 'runlocal': + c.log_training_steps = 1 + c.input.batch_size = 2 + + c.seed = 0 + return c + + +def metrics(arg=None): # pylint: disable=unused-argument + m = ['training_loss'] + for split in ('minival', 'minitrain'): + m.append(f'vizwizvqa/{split}/acc') + m.append(f'vizwizvqa/{split}/pplx/avg') + return m diff --git a/big_vision/configs/proj/paligemma/transfers/vqav2.py b/big_vision/configs/proj/paligemma/transfers/vqav2.py new file mode 100644 index 0000000000000000000000000000000000000000..0c7c00c7eef5b66664933600e240e2b556b9cd0b --- /dev/null +++ b/big_vision/configs/proj/paligemma/transfers/vqav2.py @@ -0,0 +1,160 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""PaliGemma transfer to VQAv2.
"""

import big_vision.configs.common as bvcc
from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER


def training_data(res, final_split, text_len=32):
  """Creates training data config.

  See (internal link)
  You can add more arguments beside `res`, but give them good defaults.

  Args:
    res: The requested image resolution (eg 224).
    final_split: Whether to use all of the validation data. If false, the last
      10240 validation examples are held out of the training split.
    text_len: Sequence length, forwarded to `combine_and_keep_train`.

  Returns:
    The ConfigDict for the input section, with `data` and `pp` set.
  """
  c = bvcc.parse_arg('')  # Just make a configdict without extra import.
  c.data = dict(
      name='vqa',
      split='train + validation' if final_split else 'train + validation[:-10240]',
  )
  # Preprocessing pipeline: ops are joined with '|' into a single pp string.
  c.pp = '|'.join([
      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
      'strfmt("answer en {question_text}", outkey="prefix")',
      'choice_no_replacement(inkey="answers", outkey="suffix")',
      combine_and_keep_train(text_len),
  ])
  return c


def add_eval(c, res, text_len=32, **kw):
  """Adds the VQAv2 evaluators to `c.evals`.

  Registers one evaluator per (name, split) pair below under the key
  `vqav2/<name>`; every entry is then updated with `kw`, so callers can
  override or extend any field (e.g. `batch_size`).

  Args:
    c: Config whose `evals` mapping is extended in place.
    res: Image resolution used in the eval preprocessing string.
    text_len: Used as `max_decode_len` and passed to `combine_and_keep_eval`.
    **kw: Extra fields merged into every evaluator entry.
  """
  pp = '|'.join([
      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
      'strfmt("answer en {question_text}", outkey="prefix")',
      combine_and_keep_eval(text_len, keep=('answers', 'answer_type', 'question_type', 'question_id')),
  ])

  for freq, name, split in [
      (1/4, 'minitrain', 'train[:5120]'),  # To gauge memorization.
      # Same 10240 examples that `training_data` holds out when
      # `final_split` is false.
      (1/4, 'minival', 'validation[-10240:]'),  # To tune hparams.
      # To generate final predictions. Test sets combined since 2021 challenge.
      (1.0, 'test', 'test + test-dev'),
  ]:
    c.evals[f'vqav2/{name}'] = dict(
        type='proj.paligemma.transfers.vqav2',
        pred='decode', pred_kw={'max_decode_len': text_len},
        outfile=f'{{workdir}}/vqav2_{name}.json',
        # Reuse the training data spec and only swap the split.
        data={**training_data(res, True, text_len).data, 'split': split},
        # `skip_first` is only set for the evaluator with freq == 1 (test).
        log_percent=freq, skip_first=freq == 1, tokenizer=TOKENIZER, pp_fn=pp)
    c.evals[f'vqav2/{name}'].update(kw)


def add_eval_pplx(c, res, text_len=32):
  """Perplexity evaluator to test runs before implementing the real deal.

  Adds `vqav2/<name>/pplx` entries to `c.evals` that reuse the training data
  spec and the training preprocessing string.
  """
  c_train = training_data(res, True, text_len)  # Use mostly same settings as training.

  for name, split in [
      ('minitrain', 'train[:20_864]'),  # To gauge memorization.
      ('minival', 'validation[-10240:]'),  # To tune hparams.
  ]:
    c.evals[f'vqav2/{name}/pplx'] = dict(
        type='proj.paligemma.perplexity', pred='logits',
        key='text', shift_labels=True,
        log_percent=1/4,  # Not too cheap, do 4x per run.
        data={**c_train.data, 'split': split},
        pp_fn=c_train.pp,
    )


def sweep_best(add, arg=None):
  """Train with best hyper-params.

  Calls `add` once per resolution (224 and 448) with that resolution's
  hyper-parameters.
  """
  c = bvcc.parse_arg(arg, final_split=False)
  # NOTE: lr was highest in sweep.
  add(total_epochs=10, lr=1e-5, wd=1e-6, **bvcc.arg(res=224, **c))
  add(total_epochs=10, lr=1e-5, wd=0.00, **bvcc.arg(res=448, **c))


sweep = sweep_best  # Choose which sweep to run.


def get_config(arg=None):
  """Config for training."""
  c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False)

  c.input = training_data(c.res, c.final_split)

  # Instead of epochs, you can also use `total_examples` or `total_steps`.
  c.total_epochs = 10
  c.input.batch_size = 256
  c.optax_name = 'scale_by_adam'
  c.lr = 3e-6
  c.wd = 3e-7
  c.grad_clip_norm = 1.0
  c.label_smoothing = 0.0
  c.schedule = dict(decay_type='cosine', warmup_percent=0.05)

  # Add evaluators.
  c.evals = {}
  add_eval(c, c.res, batch_size=1024)
  add_eval_pplx(c, c.res)

  # Model section.
  c.model_name = 'proj.paligemma.paligemma'
  c.model = {}
  c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
  c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0)
  c.model_init = f'pt_{c.res}'

  # FSDP strategy.
  c.mesh = [('data', -1)]
  c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
  c.sharding_rules = [('act_batch', ('data',))]

  # These probably do not need any change/tuning
  c.input.shuffle_buffer_size = 50_000
  c.log_training_steps = 50
  c.ckpt_steps = 1_000
  c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops']

  # Update configs for quicker local runs and avoid swapping.
  if c.mode in ('runlocal', 'mock'):
    c.input.shuffle_buffer_size = None
    for ev in c.evals.values():
      # Drop any existing slice from the split and keep only 16 examples.
      ev.data.split = ev.data.split.split('[')[0] + '[:16]'

  if c.mode == 'runlocal':
    c.log_training_steps = 1
    c.input.batch_size = 2

  c.seed = 0
  return c


def metrics(arg=None):  # pylint: disable=unused-argument
  """Returns the list of metric names to display by default."""
  m = ['training_loss']
  for split in ('minival', 'minitrain'):
    m.append(f'vqav2/{split}/acc')
    m.append(f'vqav2/{split}/pplx/avg')
  return m
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""PaliGemma transfer to widgetcap (bbox drawn in the picture).
"""

import big_vision.configs.common as bvcc
from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER


def training_data(res, *, final_split, text_len=32):
  """Creates training data config.

  See (internal link)
  You can add more arguments beside `res`, but give them good defaults.

  Args:
    res: The requested image resolution (eg 224).
    final_split: Train on all train+dev data (keyword-only).
    text_len: The max text length, forwarded to `combine_and_keep_train`.

  Returns:
    The ConfigDict for the input section, with `data` and `pp` set.
  """
  c = bvcc.parse_arg('')  # Just make a configdict without extra import.
  c.data = dict(
      name='widgetcap',
      split='train+dev' if final_split else 'train',
  )
  # Preprocessing pipeline: ops are joined with '|' into a single pp string.
  c.pp = '|'.join([
      'decode',
      f'resize({res}, antialias=True)',
      'draw_bbox',
      'value_range(-1, 1)',
      'strfmt("caption en", outkey="prefix")',
      'choice_no_replacement(inkey="texts", outkey="suffix")',
      combine_and_keep_train(text_len),
  ])
  return c


def add_eval(c, res, text_len=32, **kw):
  """Captioning evaluator with cider/bleu-4/meteor/rouge/spice metrics.

  Registers `widgetcap/val` (split 'dev') and `widgetcap/eval` (split 'test')
  in `c.evals`; every entry is then updated with `kw`.

  Args:
    c: Config whose `evals` mapping is extended in place.
    res: Image resolution used in the eval preprocessing string.
    text_len: Used as `max_decode_len` and passed to `combine_and_keep_eval`.
    **kw: Extra fields merged into every evaluator entry.
  """
  # Input eval pp without ground truth text and random crop.
  pp_eval = '|'.join([
      'copy("texts", "captions")',  # GT for evaluator.
      'decode',
      f'resize({res}, antialias=True)',
      'draw_bbox',
      'value_range(-1, 1)',
      'strfmt("caption en", outkey="prefix")',
      combine_and_keep_eval(text_len, keep=('image/id', 'captions')),
  ])

  for name, split in [
      ('val', 'dev'),
      ('eval', 'test'),
  ]:
    c.evals[f'widgetcap/{name}'] = dict(
        type='proj.paligemma.transfers.coco_caption',
        pred='decode', pred_kw={'max_decode_len': text_len},
        data=dict(
            name='widgetcap',
            split=split,
        ),
        log_percent=0.1, tokenizer=TOKENIZER, pp_fn=pp_eval)
    c.evals[f'widgetcap/{name}'].update(kw)


def add_eval_pplx(c, res, text_len=32):
  """Perplexity evaluator to test runs before implementing the real deal.

  Adds `widgetcap/<name>/pplx` entries to `c.evals` that reuse the training
  data spec and the training preprocessing string.
  """
  c_train = training_data(res, final_split=True, text_len=text_len)  # Use mostly same settings as training.
  for name, split in [
      ('minitrain', 'train[:5%]'),  # To gauge memorization.
      ('minival', 'dev'),  # To tune hparams.
      ('eval', 'test'),  # To compute final publishable scores.
  ]:
    c.evals[f'widgetcap/{name}/pplx'] = dict(
        type='proj.paligemma.perplexity', pred='logits',
        key='text', shift_labels=True,
        log_percent=0.05,  # Eval ~20x per run; it's cheap.
        data={**c_train.data, 'split': split},
        pp_fn=c_train.pp,
    )


def sweep_best(add, arg=None):
  """Train with best hyper-params.

  Calls `add` once per resolution (224 and 448) with the same
  hyper-parameters.
  """
  c = bvcc.parse_arg(arg, final_split=False)
  # Based on sweeps (internal link) (widgetcap/val/cider).
  # NOTE: dropout always on, see get_config.
  add(lr=3e-6, wd=3e-7, total_epochs=4, **bvcc.arg(res=224, **c))
  add(lr=3e-6, wd=3e-7, total_epochs=4, **bvcc.arg(res=448, **c))
  # Not better: add(lr=3e-6, wd=3e-7, total_epochs=4, **bvcc.arg(res=896, **c))


sweep = sweep_best  # Choose which sweep to run.


def get_config(arg=None):
  """Config for training."""
  c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False)

  c.input = training_data(c.res, final_split=c.final_split)

  # Instead of epochs, you can also use `total_examples` or `total_steps`.
  c.total_epochs = 4
  c.input.batch_size = 64
  c.optax_name = 'scale_by_adam'
  c.optax = dict(b2=0.999)
  c.lr = 3e-6
  c.wd = 3e-7
  c.grad_clip_norm = 1.0
  c.label_smoothing = 0.1
  c.schedule = dict(decay_type='cosine', warmup_percent=0.05)

  # Add evaluators.
  c.evals = {}
  add_eval(c, c.res, batch_size=1024)
  add_eval_pplx(c, c.res)

  # Model section.
  c.model_name = 'proj.paligemma.paligemma'
  c.model = {}
  c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
  c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.1)
  c.model_init = f'pt_{c.res}'

  # FSDP strategy.
  c.mesh = [('data', -1)]
  c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
  c.sharding_rules = [('act_batch', ('data',))]

  # These probably do not need any change/tuning
  c.input.shuffle_buffer_size = 50_000
  c.log_training_steps = 50
  c.ckpt_steps = 1_000
  c.pp_modules = [
      'ops_general',
      'ops_image',
      'ops_text',
      'proj.paligemma.ops',
      'proj.paligemma.widgetcap',
  ]

  # Update configs for quicker local runs and avoid swapping.
  if c.mode in ('runlocal', 'mock'):
    c.input.shuffle_buffer_size = None
    for ev in c.evals.values():
      # Drop any existing slice from the split and keep only 16 examples.
      ev.data.split = ev.data.split.split('[')[0] + '[:16]'

  if c.mode == 'runlocal':
    c.log_training_steps = 1
    c.input.batch_size = 2

  c.seed = 0
  return c


def metrics(arg=None):  # pylint: disable=unused-argument
  """Returns the list of metric names to display by default."""
  # This function defines the default flatboard. If you want, it can be a lot
  # fancier too, but the simplest way is a list of metric names.
  m = ['training_loss']
  for split in ('eval', 'minival', 'minitrain'):
    m.append(f'widgetcap/{split}/pplx/avg')
  for split in ('val', 'eval'):
    m.append(f'widgetcap/{split}/cider')
  return m
"""Object detection reward from "Tuning computer vision models with task rewards" (https://arxiv.org/abs/2302.08242).

The `reward_fn` computes the reward for a batch of predictions and ground truth
annotations. When using it to optimize a model that outputs a prediction as a
sequence of tokens like [y0, x0, Y0, X0, class0, confidence0, y1, x1, Y1, ...]
the training loop may look like:

```
# Settings used in the paper.
config.max_level = 1000  # Coordinates are discretized into 1000 buckets.
config.max_conf = 2  # Two tokens are reserved to represent confidence.
config.num_cls = 80  # Number of classes in COCO.
config.nms_w = 0.3  # Weight for duplicate instances.
config.cls_smooth = 0.05  # Adjust the classes weights based on their frequency.
config.reward_thr = (0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95)
config.correct_thr = 0.5  # Learn the IoU when matching with threshold=0.5.
config.conf_w = 0.3  # Weight for the confidence loss.


# 1) Sample N outputs for each input and compute rewards, use one sample to
# optimize and others to compute a reward baseline.
sample_seqs = sample_fn(params, images, num_samples)
sample_rewards, aux = reward_fn(sample_seqs, labels, config)
labels = sample_seqs[:, 0, ...]
rewards = sample_rewards[:, 0]
match_iou = aux["match_iou"][:, 0]
baselines = (jnp.sum(sample_rewards, axis=-1) - rewards) / (num_samples - 1)

# 2) Optimize the model. By using REINFORCE to adjust the likelihood of the
# sequence based on the reward and with supervision to teach the model to
# predict the expected IoU of each box in its own samples.
def loss_fn(params):
  logits = model.apply(params, images, labels, train=True, rngs=rngs)
  logits_softmax = jax.nn.log_softmax(logits)

  # Use reinforce to optimize the expected reward for the whole sequence.
  seq_rewards = (rewards - baselines)
  # Note: consider improving this code to skip this loss for confidence tokens.
  # The paper did not do it due to a bug (and also does not seem to matter).
  target = jax.nn.one_hot(labels, logits.shape[-1]) * seq_rewards[:, None, None]
  loss_reward = -jnp.sum(target * logits_softmax, axis=-1)

  # Use supervision loss to tune the confidence tokens to predict IoU:
  #  - (1.0, 0.0, 0.0, ...) -> for padded boxes.
  #  - (0.0, 1-iou, iou, ...) -> for sampled boxes.
  conf0 = (labels[:, 5::6] == 0)
  conf1 = (labels[:, 5::6] > 0) * (1.0 - match_iou)
  conf2 = (labels[:, 5::6] > 0) * match_iou
  target_conf = jnp.stack([conf0, conf1, conf2], axis=-1)
  logits_conf = logits_softmax[:, 5::6, :3]
  loss_conf = -jnp.sum(target_conf * logits_conf, axis=-1)

  loss = jnp.mean(loss_reward) + config.conf_w * jnp.mean(loss_conf)
  return loss
```
"""
import functools

import jax
import jax.numpy as jnp


# Frequency of COCO object detection classes as observed in the training set.
# pylint: disable=bad-whitespace,bad-continuation
CLS_COUNTS = [
    262465, 7113, 43867, 8725, 5135, 6069, 4571, 9973, 10759,
    12884, 1865, 1983, 1285, 9838, 10806, 4768, 5508, 6587,
    9509, 8147, 5513, 1294, 5303, 5131, 8720, 11431, 12354,
    6496, 6192, 2682, 6646, 2685, 6347, 9076, 3276, 3747,
    5543, 6126, 4812, 24342, 7913, 20650, 5479, 7770, 6165,
    14358, 9458, 5851, 4373, 6399, 7308, 7852, 2918, 5821,
    7179, 6353, 38491, 5779, 8652, 4192, 15714, 4157, 5805,
    4970, 2262, 5703, 2855, 6434, 1673, 3334, 225, 5610,
    2637, 24715, 6334, 6613, 1481, 4793, 198, 1954
]
# pylint: enable=bad-whitespace,bad-continuation


def seq2box(seq, max_level, max_conf, num_cls):
  """Extract boxes encoded as sequences.

  Every box takes 6 consecutive tokens: 4 coordinates, a class and a
  confidence. Trailing tokens that do not fill a whole box are dropped.

  Args:
    seq: Integer token array whose last axis is the sequence.
    max_level: Largest coordinate bucket; coordinates are divided by it.
    max_conf: Largest confidence token. Coordinate tokens are offset by
      `max_conf + 1`, class tokens additionally by `max_level + 1`.
    num_cls: Number of classes; labels are clipped to `[0, num_cls - 1]`.

  Returns:
    A `(boxes, labels, confs)` tuple with shapes `[..., n, 4]`, `[..., n]` and
    `[..., n]`. Boxes are scaled to `[0, 1]`, confs clipped to `[0, max_conf]`.
  """
  # Reshape to instances of boxes
  dim_per_box = 6
  num_boxes = seq.shape[-1] // dim_per_box
  seq = seq[..., :num_boxes * dim_per_box]
  seq = seq.reshape(seq.shape[:-1] + (num_boxes, dim_per_box))

  # Unpack box fields
  boxes, labels, confs = seq[..., 0:4], seq[..., 4], seq[..., 5]
  boxes = boxes - max_conf - 1
  labels = labels - max_conf - 1 - max_level - 1
  boxes = jnp.clip(boxes, 0, max_level) / max_level
  labels = jnp.clip(labels, 0, num_cls - 1)
  confs = jnp.clip(confs, 0, max_conf)

  return boxes, labels, confs


def iou_fn(box1, box2):
  """Compute IoU of two boxes, each given as (ymin, xmin, ymax, xmax)."""
  ymin1, xmin1, ymax1, xmax1 = box1
  ymin2, xmin2, ymax2, xmax2 = box2

  a1 = jnp.abs((ymax1 - ymin1) * (xmax1 - xmin1))
  a2 = jnp.abs((ymax2 - ymin2) * (xmax2 - xmin2))

  yl = jnp.maximum(ymin1, ymin2)
  yr = jnp.minimum(ymax1, ymax2)
  yi = jnp.maximum(0, yr - yl)

  xl = jnp.maximum(xmin1, xmin2)
  xr = jnp.minimum(xmax1, xmax2)
  xi = jnp.maximum(0, xr - xl)

  inter = xi * yi
  # The epsilon keeps the division finite when both boxes have zero area.
  return inter / (a1 + a2 - inter + 1e-9)

# Pairwise IoU: maps boxes [n, 4] and [m, 4] to an [n, m] matrix.
iou_fn_batched = jax.vmap(
    jax.vmap(iou_fn, in_axes=(None, 0)), in_axes=(0, None)
)


def _reward_fn_thr(seq_pred, seq_gt,
                   thr, nms_w, max_level, max_conf, num_cls, cls_smooth):
  """Compute detection reward function for a given IoU threshold.

  Args:
    seq_pred: Token sequence of a single prediction.
    seq_gt: Token sequence of the ground truth annotation.
    thr: IoU threshold; overlaps that are not above it are ignored.
    nms_w: Weight of the duplicate-detection penalty.
    max_level: See `seq2box`.
    max_conf: See `seq2box`.
    num_cls: See `seq2box`; also rescales the class weights.
    cls_smooth: Smoothing applied to the class frequencies.

  Returns:
    A dict with scalar "reward", "num_matches" and "nms_penalty", plus
    per-predicted-box "correct" (0: padding, 1: unmatched, 2: matched) and
    "match_iou".
  """
  # Weight matches of each label inversely proportional to the percentage of
  # GT instances with such label in the whole train dataset. Additionally
  # smooth out the observed distribution.
  # NOTE(review): CLS_COUNTS has 80 entries, so labels index it correctly only
  # while num_cls <= 80 - confirm before using another label set.
  cls_counts = jnp.array(CLS_COUNTS)
  weights = 1.0 / (cls_counts + cls_smooth*jnp.sum(cls_counts))
  weights = num_cls * weights / jnp.sum(weights)

  boxes_pred, labels_pred, confs_pred = seq2box(
      seq_pred, max_level, max_conf, num_cls)
  boxes_gt, labels_gt, confs_gt = seq2box(
      seq_gt, max_level, max_conf, num_cls)

  # Compute IoU matrix: Predictions X GT
  iou = iou_fn_batched(boxes_pred, boxes_gt)

  # IoU thr
  iou = jnp.where(iou > thr, iou, 0.0)

  # EOS mask
  confs_mask = (confs_pred[:, None] > 0) * (confs_gt[None, :] > 0)
  iou = confs_mask * iou

  # Label mask
  label_mask = labels_pred[:, None] == labels_gt[None, :]
  iou = label_mask * iou

  # Each prediction is matched to a single box
  single_match_mask = jax.nn.one_hot(jnp.argmax(iou, axis=1), iou.shape[1])
  iou = iou * single_match_mask

  # Pred. boxes indicators
  correct = jnp.any(iou > 0.0, axis=1).astype("int32") + 1
  correct = jnp.where(confs_pred > 0, correct, 0)

  # For each GT box find best match
  matches_idx = jnp.argmax(iou, axis=0)
  matches_iou = jnp.take_along_axis(iou, matches_idx[None], axis=0)[0]
  matches_idx = jnp.where(matches_iou > 0.0, matches_idx, -1)

  match_reward = jnp.sum((matches_idx >= 0) * weights[labels_gt][None, :])

  # Compute duplicate penalty (aka NMS).
  # Unmatched GT boxes have index -1, which one-hot encodes to all zeros.
  matches_mask = jax.nn.one_hot(matches_idx, iou.shape[0], axis=0)
  nms_penalty = jnp.sum(
      (iou > 0.0) * (1 - matches_mask) * weights[labels_pred][:, None])

  match_iou = jnp.sum(iou, axis=1)

  return {
      "reward": (match_reward - nms_w * nms_penalty),
      "num_matches": jnp.sum(matches_idx >= 0),
      "nms_penalty": nms_penalty,
      "correct": correct,
      "match_iou": match_iou,
  }


def reward_fn(seqs_pred, seqs_gt, config):
  """Total reward, averaged over all IoU thresholds in `config.reward_thr`.

  Args:
    seqs_pred: Predicted token sequences. The two leading axes are mapped
      over: the first together with `seqs_gt`, the second (samples) on its own.
    seqs_gt: Ground truth token sequences; the leading axis is mapped over.
    config: Object with attributes `reward_thr`, `correct_thr`, `nms_w`,
      `max_level`, `max_conf`, `num_cls` and `cls_smooth`.

  Returns:
    A `(reward, aux)` tuple. `reward` is the mean reward over all thresholds.
    `aux["result"]` holds "reward", "num_matches" and "nms_penalty" both per
    threshold (keys like "reward-0.5", "reward-0.55") and averaged (plain
    key). `aux["correct"]` and `aux["match_iou"]` are taken at
    `config.correct_thr`.
  """
  thrs = tuple(config.reward_thr)
  correct_thr = config.correct_thr
  r_keys = ("reward", "num_matches", "nms_penalty")

  def rewards_at(thr):
    fn = functools.partial(
        _reward_fn_thr,
        thr=thr,
        nms_w=config.nms_w,
        max_level=config.max_level,
        max_conf=config.max_conf,
        num_cls=config.num_cls,
        cls_smooth=config.cls_smooth,
    )
    return jax.vmap(jax.vmap(fn, in_axes=(0, None)))(seqs_pred, seqs_gt)

  result = {}
  per_thr = []  # One rewards dict per threshold, in `thrs` order.
  correct = match_iou = None
  for thr in thrs:
    rewards = rewards_at(thr)
    per_thr.append(rewards)
    # `:g` keeps thresholds such as 0.55 and 0.6 apart. A one-decimal format
    # maps both to "0.6", so later thresholds would overwrite earlier ones.
    for k in r_keys:
      result[f"{k}-{thr:g}"] = rewards[k]
    if thr == correct_thr:
      correct = rewards["correct"]
      match_iou = rewards["match_iou"]

  if correct is None:
    # `correct_thr` is not one of the reward thresholds: evaluate it on its
    # own so both aux outputs are always defined.
    rewards = rewards_at(correct_thr)
    correct = rewards["correct"]
    match_iou = rewards["match_iou"]

  # Average from the list rather than by key, so every threshold counts once.
  for k in r_keys:
    result[k] = jnp.mean(jnp.array([r[k] for r in per_thr]), axis=0)

  return result["reward"], {
      "result": result,
      "correct": correct,
      "match_iou": match_iou,
  }
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""Pre-train ViT-g (1B params) on JFT-3B as in https://arxiv.org/abs/2106.04560

To train ViT-G (2B params), simply update the following single line:
  `config.model.variant = 'G/14'`

The code is released for reference purposes.
One can test the code using public ImageNet-1k or ImageNet-21k dataset.

big_vision.train \
    --config big_vision/configs/proj/scaling_laws/train_vit_g.py \
    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`

"""
from big_vision.configs.common_fewshot import get_fewshot_lsr
import ml_collections as mlc


def get_config():
  """Returns the ConfigDict for pre-training ViT-g/14 on the `jft_3b` dataset."""
  config = mlc.ConfigDict()

  config.dataset = 'jft_3b'
  config.val_split = 'val'
  config.train_split = 'train'
  config.num_classes = 29_593
  config.init_head_bias = -10.0

  # Fits 32 images per TPUv3 core with ViT-g/14.
  config.batch_size = 4096*4

  # Preprocessing suffix shared by the train and eval pipelines.
  pp_common = '|value_range(-1, 1)'
  pp_common += f'|onehot({config.num_classes})'
  pp_common += '|keep("image", "labels")'
  config.pp_train = 'inception_crop(224)|flip_lr' + pp_common
  config.pp_eval = 'resize_small(256)|central_crop(224)' + pp_common
  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

  config.log_training_steps = 50
  config.log_eval_steps = 1000
  # NOTE: eval is very fast O(seconds) so it's fine to run it often.

  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 10_000

  config.prefetch_to_device = 1
  config.trial = 0

  # Model section
  config.model_name = 'vit'
  config.model = mlc.ConfigDict()
  config.model.variant = 'g/14'
  config.model.pool_type = 'map'

  # Optimizer section
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.grad_clip_norm = 1.0
  config.lr = 8e-4
  config.wd = 0.03 * 8e-4
  # (pattern, multiplier) pairs: the head kernel gets a 100x larger multiplier
  # than the other kernels.
  config.wd_mults = [
      ('.*head/kernel', 100.0),
      ('.*/kernel', 1.0),
  ]
  config.schedule = dict(
      decay_type='rsqrt', timescale=10_000, warmup_steps=10_000,
      cooldown_steps=50_000)
  config.total_steps = 1_000_000

  # Few-shot eval section
  config.evals = {}
  config.evals.fewshot = dict(log_steps=10_000, **get_fewshot_lsr())

  return config
+ +| task | model | dataset | accuracy | download link | +| --------------------- | ------------------- | ------------------------------------------------------------------------ | ------------ | ----------------------------------------------------------------------------------------- | +| Panoptic segmentation | UViM Stage I model | [COCO(2017)](https://cocodataset.org/#home) | 75.8 PQ | [link](https://storage.googleapis.com/big_vision/uvim/panoptic_stageI_params.npz) | +| Panoptic segmentation | UViM Stage II model | [COCO(2017)](https://cocodataset.org/#home) | 43.1 PQ | [link](https://storage.googleapis.com/big_vision/uvim/panoptic_stageII_params.npz) | +| Colorization | UViM Stage I model | [ILSVRC-2012](https://www.image-net.org/) | 15.59 FID | [link](https://storage.googleapis.com/big_vision/uvim/color_stageI_params.npz) | +| Colorization | UViM Stage II model | [ILSVRC-2012](https://www.image-net.org/) | 16.99 FID | [link](https://storage.googleapis.com/big_vision/uvim/color_stageII_params.npz) | +| Depth | UViM Stage I model | [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) | 0.155 RMSE | [link](https://storage.googleapis.com/big_vision/uvim/depth_stageI_params.npz) | +| Depth | UViM Stage II model | [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) | 0.463 RMSE | [link](https://storage.googleapis.com/big_vision/uvim/depth_stageII_params.npz) | + +All of these models can be interactively explored in our [colabs](configs/proj/uvim). + +## Running on a single-host TPU machine + +Below we provide instructions on how to run UViM training (stage I and +stage II) using a single TPU host with 8 TPU accelerators. These instructions +can be easily adapted to a GPU host and multi-host TPU setup; see the main +`big_vision` [README file](README.md). + +We assume that the user has already created and `ssh`-ed to the TPU host +machine.
The next step is to clone the `big_vision` repository: +`git clone https://github.com/google-research/big_vision.git`. + +The next steps are to create a python virtual environment and install python +dependencies: +``` +virtualenv bv +source bv/bin/activate +cd big_vision/ +pip3 install --upgrade pip +pip3 install -r big_vision/requirements.txt +pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +``` + +After this invoke the helper tool to download and prepare data: +`python3 -m big_vision.tools.download_tfds_datasets coco/2017_panoptic nyu_depth_v2`. +For preparing the ImageNet dataset consult the main codebase README. + +> :warning: TPU machines have 100 GB of disk space. It may not be enough to +> store all training data (though only panoptic or only depth data may fit). +> Consider preparing the data on a separate machine and then copying it to +> the TPU machine's extra persistent disk or to a Google Cloud Bucket. See +> instructions for [creating an extra persistent disk](https://cloud.google.com/tpu/docs/users-guide-tpu-vm). +> Remember to set the correct data home directory, e.g. `export DISK=/mnt/disk/persist; export TFDS_DATA_DIR=$DISK/tensorflow_datasets`. + +Our panoptic evaluator uses the raw variant of the COCO data, so we move it into a +separate folder. Note, `tfds` has already pre-downloaded the panoptic data, +except for one small json file that we fetch manually: +``` +mkdir $DISK/coco_data +cd $DISK/coco_data +mv $TFDS_DATA_DIR/downloads/extracted/ZIP.image.cocod.org_annot_panop_annot_train.zip/annotations/* . +wget https://raw.githubusercontent.com/cocodataset/panopticapi/master/panoptic_coco_categories.json +export COCO_DATA_DIR=$DISK/coco_data +``` + +For the FID evaluator, which is used for the colorization model, set the path to the +directory with image id files, e.g. +`export FID_DATA_DIR=/big_vision/evaluators/proj/uvim/coltran_fid_data`.
+ +As an example, stage I panoptic training can be invoked as (note the `:singlehost` config parameter which will use lightweight configuration suitable for a single host): +``` +python3 -m big_vision.trainers.proj.uvim.vqvae --config big_vision/configs/proj/uvim/vqvae_coco_panoptic.py:singlehost --workdir workdirs/`date '+%m-%d_%H%M'` +``` +or stage II training +``` +python3 -m big_vision.trainers.proj.uvim.train --config big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py:singlehost --workdir workdirs/`date '+%m-%d_%H%M'` +``` + +## Acknowledgments +The sampling code in `models/proj/uvim/decode.py` module is based on contributions +from Anselm Levskaya, Ilya Tolstikhin and Maxim Neumann. + diff --git a/big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py b/big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..aa68eb7994115d06cf7b148fba62b89f20e6860a --- /dev/null +++ b/big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py @@ -0,0 +1,164 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""A config for training a UViM stage II model for the panoptic task. + +This config is expected to reproduce the paper's result and achieve +approximately 43.7 PQ points on the COCO holdout data. 
+ +We also provide a low-resource variant of this config, which can be enabled +by adding `:singlehost` postfix to the config name. This one is expected to +achieve 39.4 PQ points on the COCO holdout data. +""" + +import big_vision.configs.common as bvcc +from ml_collections import ConfigDict + +VTT_MODELS = { + 'base': dict(num_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768), + 'large': dict(num_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024), +} + +VQVAE_MODELS = { + 'base': dict(enc_depth=6, dec_depth=12, num_heads=12, mlp_dim=3072, width=768), +} + +RES = 512 +PATCH_SIZE = 16 +LABEL_RES = 512 +LABEL_PATCH_SIZE = 16 + + +def get_config(arg=''): + """Config for training.""" + arg = bvcc.parse_arg(arg, runlocal=False, singlehost=False) + config = ConfigDict() + + config.input = {} + config.input.pp = ( + f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|' + f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|' + f'inception_box|crop_box(key="image")|crop_box(key="labels")|' + f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|' + f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|' + f'value_range(-1, 1, key="image_ctx")|' + f'value_range(-1, 1)|make_canonical|keep("image","image_ctx","labels")' + ) + pp_eval = ( + f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|' + f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|' + f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|' + f'value_range(-1, 1, key="image_ctx")|' + f'value_range(-1, 1)|make_canonical|keep("image","image_ctx","labels")' + ) + pp_predict = ( + f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|resize({RES})|' + f'value_range(-1, 1, key="image_ctx")|value_range(-1, 1)|' + f'keep("image","image_ctx","image/id")' # image/id used for rng seeds. 
+ ) + + config.input.data = dict(name='coco/2017_panoptic', split='train[4096:]') + config.input.batch_size = 512 + config.input.shuffle_buffer_size = 50_000 + + config.total_epochs = 200 + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = 5000 + config.prefetch_to_device = 2 + config.seed = 0 + + # Optimizer section + config.optax_name = 'big_vision.scale_by_adafactor' + config.optax = dict(beta2_cap=0.95) + + config.lr = 0.001 + config.wd = 0.000001 + config.lr_mults = [ + ('pos_embedding_encoder.*', 0.1), + ('EmbedPatches.*', 0.1), + ('encoder.*', 0.1), + ('decoder.*', 1.0) + ] + config.schedule = dict(decay_type='cosine', warmup_steps=4_000) + + # Oracle section + config.oracle = ConfigDict() + config.oracle.task = 'proj.uvim.panoptic_task' + config.oracle.model_init = 'gs://big_vision/uvim/panoptic_stageI_params.npz' + config.oracle.model_name = 'proj.uvim.vit' + config.oracle.model = ConfigDict(VQVAE_MODELS['base']) + config.oracle.model.input_size = (LABEL_RES, LABEL_RES) + config.oracle.model.patch_size = (LABEL_PATCH_SIZE, LABEL_PATCH_SIZE) + config.oracle.model.code_len = 256 + config.oracle.model.dict_size = 4096 + config.oracle.model.codeword_dim = 768 + config.oracle.model.with_encoder_ctx = True + config.oracle.model.with_decoder_ctx = True + config.oracle.model.code_dropout = 'random' + config.oracle.model.bottleneck_resize = True + config.oracle.model.inputs = { + 'semantics': (133 + 1, LABEL_PATCH_SIZE**2), # +1 for void label + 'instances': (100, LABEL_PATCH_SIZE**2), # COCO: actually 98 train/78 validation. 
+ } + config.oracle.model.outputs = config.oracle.model.inputs + + # Model section + config.model_name = 'proj.uvim.vtt' + # config.model_init = {'encoder': 'howto-i21k-B/8'} + config.model_init = {'encoder': 'howto-i21k-L/16'} + config.model = ConfigDict(VTT_MODELS['large']) + config.model.patches = ConfigDict({'size': (PATCH_SIZE, PATCH_SIZE)}) + config.model.vocab_size = config.oracle.model.get_ref('dict_size') + 1 + config.model.posemb_type = 'learn' + config.model.input_size = (RES, RES) + config.model.seq_len = config.oracle.model.get_ref('code_len') + + # Evaluation section + config.evals = {} + config.evals.val = ConfigDict() + config.evals.val.type = 'proj.uvim.compute_mean' + config.evals.val.pred = 'validation' + config.evals.val.data = dict(name=config.input.data.name, split='train[:4096]') + config.evals.val.pp_fn = pp_eval + config.evals.val.log_steps = 1000 + + base = { + 'type': 'proj.uvim.coco_panoptic', + 'pp_fn': pp_predict, + 'log_steps': 10_000, + # Filters objects that occupy less than 0.03^2 fraction of all pixels. 
+ # 'predict_kwargs': {'min_fraction': 0.03 ** 2}, + } + config.evals.coco_panoptic_train = dict(**base, split='train[4096:8192]') + config.evals.coco_panoptic_holdout = dict(**base, split='train[:4096]') + config.evals.coco_panoptic = dict(**base, split='validation') + + # config.evals.save_pred = dict(type='proj.uvim.save_predictions') + # config.evals.save_pred.pp = pp_eval.replace('decode|', '') + # config.evals.save_pred.log_steps = 100_000 + # config.evals.save_pred.dataset = config.dataset + # config.evals.save_pred.split = 'validation[:1024]' + # config.evals.save_pred.outfile = 'inference.npz' + + if arg.singlehost: + config.input.batch_size = 32 + config.total_epochs = 50 + elif arg.runlocal: + config.input.batch_size = 4 + config.input.shuffle_buffer_size = 10 + config.evals.val.data.split = 'train[:16]' + return config \ No newline at end of file diff --git a/big_vision/configs/proj/uvim/train_imagenet2012_colorization_pretrained.py b/big_vision/configs/proj/uvim/train_imagenet2012_colorization_pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..0245db559e2454584be3dc617401be4cf327df78 --- /dev/null +++ b/big_vision/configs/proj/uvim/train_imagenet2012_colorization_pretrained.py @@ -0,0 +1,161 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""A config for training a UViM stage II model for the colorization task. 
+""" + +import big_vision.configs.common as bvcc +from ml_collections import ConfigDict + +VTT_MODELS = { + 'base': dict(num_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768), + 'large': dict(num_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024), +} + +VQVAE_MODELS = { + 'base': dict(enc_depth=6, dec_depth=12, num_heads=12, mlp_dim=3072, width=768), +} + +RES = 512 +PATCH_SIZE = 16 +LABEL_RES = 512 +LABEL_PATCH_SIZE = 16 + + +def get_config(arg=''): + """Config for training.""" + arg = bvcc.parse_arg(arg, runlocal=False, singlehost=False) + config = ConfigDict() + + config.input = {} + config.input.pp = ( + f'decode_jpeg_and_inception_crop({RES})' + f'|flip_lr' + f'|copy(inkey="image", outkey="labels")' + f'|resize({LABEL_RES},inkey="labels",outkey="labels",method="nearest")' + f'|value_range(-1,1,key="labels")' + f'|rgb_to_grayscale_to_rgb(inkey="image",outkey="image")' + f'|value_range(-1,1,key="image")' + f'|copy(inkey="image", outkey="image_ctx")' + f'|resize({LABEL_RES},inkey="image_ctx",outkey="image_ctx")' + f'|keep("image","image_ctx","labels")') + pp_eval = ( + f'decode' + f'|resize({RES})' + f'|copy(inkey="image", outkey="labels")' + f'|resize({LABEL_RES},inkey="labels",outkey="labels",method="nearest")' + f'|value_range(-1,1,key="labels")' + f'|rgb_to_grayscale_to_rgb(inkey="image",outkey="image")' + f'|value_range(-1,1,key="image")' + f'|copy(inkey="image", outkey="image_ctx")' + f'|resize({LABEL_RES},inkey="image_ctx",outkey="image_ctx")' + f'|strong_hash(inkey="tfds_id", outkey="image/id")' + f'|keep("image","image_ctx","labels","image/id")') + + config.input.data = dict(name='imagenet2012', split='train[4096:]') + config.input.batch_size = 512 + config.input.shuffle_buffer_size = 50_000 + + config.total_epochs = 50 + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = 5000 + config.prefetch_to_device = 2 + config.seed = 0 + + # Optimizer section + config.optax_name = 'big_vision.scale_by_adafactor' + 
config.optax = dict(beta2_cap=0.95) + + config.lr = 0.001 + config.wd = 0.000001 + config.lr_mults = [ + ('pos_embedding_encoder.*', 0.1), + ('EmbedPatches.*', 0.1), + ('encoder.*', 0.1), + ('decoder.*', 1.0) + ] + config.schedule = dict(decay_type='cosine', warmup_steps=4_000) + + # Oracle section + config.oracle = ConfigDict() + config.oracle.task = 'proj.uvim.colorization_task' + config.oracle.model_init = 'gs://big_vision/uvim/color_stageI_params.npz' + config.oracle.model_name = 'proj.uvim.vit' + config.oracle.model = ConfigDict(VQVAE_MODELS['base']) + config.oracle.model.input_size = (LABEL_RES, LABEL_RES) + config.oracle.model.patch_size = (LABEL_PATCH_SIZE, LABEL_PATCH_SIZE) + config.oracle.model.code_len = 256 + config.oracle.model.dict_size = 4096 + config.oracle.model.codeword_dim = 768 + config.oracle.model.with_encoder_ctx = True + config.oracle.model.with_decoder_ctx = True + config.oracle.model.code_dropout = 'random' + config.oracle.model.bottleneck_resize = True + config.oracle.model.inputs = { + 'color': (3, LABEL_PATCH_SIZE**2), + } + config.oracle.model.outputs = config.oracle.model.inputs + + # Model section + config.model_name = 'proj.uvim.vtt' + # config.model_init = {'encoder': 'howto-i21k-B/8'} + config.model_init = {'encoder': 'howto-i21k-L/16'} + config.model = ConfigDict(VTT_MODELS['large']) + config.model.patches = ConfigDict({'size': (PATCH_SIZE, PATCH_SIZE)}) + config.model.vocab_size = config.oracle.model.get_ref('dict_size') + 1 + config.model.posemb_type = 'learn' + config.model.input_size = (RES, RES) + config.model.seq_len = config.oracle.model.get_ref('code_len') + + # Evaluation section + config.evals = {} + config.evals.val = ConfigDict() + config.evals.val.type = 'proj.uvim.compute_mean' + config.evals.val.pred = 'validation' + config.evals.val.data = dict(name=config.input.data.name, split='train[:4096]') + config.evals.val.pp_fn = pp_eval + config.evals.val.log_steps = 1000 + + base = { + 'type': 'proj.uvim.psnr', + 
'pp_fn': pp_eval.replace('decode|', ''), + 'log_steps': 10_000, + } + config.evals.psnr_train = dict(**base, split='train[4096:8192]') + config.evals.psnr_holdout = dict(**base, split='train[:4096]') + config.evals.psnr_val = dict(**base, split='validation') + + config.evals.colorization_val_coltran_fid = { + 'type': 'proj.uvim.coltran_fid', + 'log_steps': 100_000, + } + + # config.evals.save_pred = dict(type='proj.uvim.save_predictions') + # config.evals.save_pred.pp_fn = pp_eval.replace('decode|', '') + # config.evals.save_pred.log_steps = 100_000 + # config.evals.save_pred.dataset = config.dataset + # config.evals.save_pred.split = 'validation[:1024]' + # config.evals.save_pred.outfile = 'inference.npz' + + if arg.singlehost: + config.input.batch_size = 32 + config.total_epochs = 20 + elif arg.runlocal: + config.input.batch_size = 8 + config.input.shuffle_buffer_size = 10 + config.evals.val.data.split = 'validation[:256]' + return config \ No newline at end of file diff --git a/big_vision/configs/proj/uvim/train_nyu_depth_pretrained.py b/big_vision/configs/proj/uvim/train_nyu_depth_pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..d2e05810708e03ad9e198025a8641eec53811830 --- /dev/null +++ b/big_vision/configs/proj/uvim/train_nyu_depth_pretrained.py @@ -0,0 +1,170 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pylint: disable=line-too-long +r"""A config for training a UViM stage II model for the depth task. +""" + +import big_vision.configs.common as bvcc +from ml_collections import ConfigDict + + +VTT_MODELS = { + 'base': dict(num_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768), + 'large': dict(num_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024), +} + +VQVAE_MODELS = { + 'base': dict(enc_depth=6, dec_depth=12, num_heads=12, mlp_dim=3072, width=768), +} + + +RES = 512 +PATCH_SIZE = 16 +LABEL_RES = 512 +LABEL_PATCH_SIZE = 16 +QUANTIZATION_BINS = 256 +# Same as values used in eval, see evaluators/nyu_depth.py. +MIN_DEPTH = 1e-3 +MAX_DEPTH = 10 + + +def get_config(arg='split=final'): + """Config for training.""" + arg = bvcc.parse_arg(arg, split='final', runlocal=False, singlehost=False) + config = ConfigDict() + + config.input = {} + config.input.pp = ( + f'decode|nyu_depth|' + f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|' + f'inception_box|crop_box(key="image")|crop_box(key="labels")|' + f'resize({RES})|' + f'resize({LABEL_RES},inkey="image",outkey="image_ctx")|' + f'resize({LABEL_RES},key="labels",method="nearest")|' + f'value_range(-1,1)|' + f'value_range(-1,1,inkey="image_ctx",outkey="image_ctx")|' + f'keep("image","image_ctx","labels")' + ) + pp_eval = ( + f'decode|nyu_depth|' + f'nyu_eval_crop|' + f'resize({RES})|' + f'resize({LABEL_RES},inkey="image",outkey="image_ctx")|' + f'resize({LABEL_RES},key="labels",method="nearest")|' + f'value_range(-1,1)|' + f'value_range(-1,1,inkey="image_ctx",outkey="image_ctx")|' + f'keep("image","image_ctx","labels")' + ) + pp_predict = ( + f'nyu_depth|' + f'nyu_eval_crop|copy("labels","ground_truth")|' + f'resize({RES})|' + f'resize({LABEL_RES},inkey="image",outkey="image_ctx")|' + f'value_range(-1,1)|' + f'value_range(-1,1,inkey="image_ctx",outkey="image_ctx")|' + f'keep("image","image_ctx","ground_truth")' + ) + + config.input.data = dict(name='nyu_depth_v2', split='train') + 
config.input.batch_size = 512 + config.input.shuffle_buffer_size = 50_000 + + config.total_epochs = 50 + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = 5000 + config.prefetch_to_device = 2 + config.seed = 0 + + # Optimizer section + config.optax_name = 'big_vision.scale_by_adafactor' + config.optax = dict(beta2_cap=0.95) + config.optax.clipping_threshold = None + + config.lr = 0.001 + config.wd = 0.000001 + config.lr_mults = ( + ('pos_embedding_encoder.*', 0.1), + ('EmbedPatches.*', 0.1), + ('encoder.*', 0.1), + ('decoder.*', 1.0) + ) + config.schedule = dict(decay_type='cosine', warmup_steps=4_000) + + # Oracle section + config.oracle = ConfigDict() + config.oracle.min_depth = MIN_DEPTH + config.oracle.max_depth = MAX_DEPTH + config.oracle.task = 'proj.uvim.depth_task' + config.oracle.model_init = 'gs://big_vision/uvim/depth_stageI_params.npz' + config.oracle.model_name = 'proj.uvim.vit' + config.oracle.model = ConfigDict(VQVAE_MODELS['base']) + config.oracle.model.input_size = (LABEL_RES, LABEL_RES) + config.oracle.model.patch_size = (LABEL_PATCH_SIZE, LABEL_PATCH_SIZE) + config.oracle.model.code_len = 256 + config.oracle.model.dict_size = 4096 + config.oracle.model.codeword_dim = 768 + config.oracle.model.with_encoder_ctx = True + config.oracle.model.with_decoder_ctx = True + config.oracle.model.code_dropout = 'random' + config.oracle.model.bottleneck_resize = True + config.oracle.model.inputs = { + 'depth': (QUANTIZATION_BINS, LABEL_PATCH_SIZE**2,), + } + config.oracle.model.outputs = config.oracle.model.inputs + + # Model section + config.model_name = 'proj.uvim.vtt' + # config.model_init = {'encoder': 'howto-i21k-B/8''} # B/8 I21K + config.model_init = {'encoder': 'howto-i21k-L/16'} # L/16 I21K + config.model = ConfigDict(VTT_MODELS['large']) + config.model.patches = ConfigDict({'size': (PATCH_SIZE, PATCH_SIZE)}) + config.model.vocab_size = config.oracle.model.dict_size + 1 + config.model.posemb_type = 'learn' + 
config.model.input_size = (RES, RES) + config.model.seq_len = config.oracle.model.get_ref('code_len') + config.model.zero_decoder_seq = False + + # Evaluation section + config.evals = {} + config.evals.val = ConfigDict() + config.evals.val.type = 'proj.uvim.compute_mean' + config.evals.val.pred = 'validation' + config.evals.val.data = {**config.input.data} + config.evals.val.data.split = 'validation' + config.evals.val.pp_fn = pp_eval + config.evals.val.log_steps = 1000 + + base = { + 'type': 'proj.uvim.nyu_depth', + 'dataset': config.input.data.name, + 'pp_fn': pp_predict, + 'log_steps': 2000, + 'min_depth': MIN_DEPTH, + 'max_depth': MAX_DEPTH, + } + config.evals.nyu_depth_val = dict(**base, split='validation') + + if arg.singlehost: + config.input.batch_size = 32 + config.total_epochs = 20 + elif arg.runlocal: + config.oracle.model_init = '/tmp/checkpoint.npz' + config.model_init = {'encoder': '/tmp/enc_checkpoint.npz'} + config.evals = {} + config.input.batch_size = 1 + config.input.shuffle_buffer_size = 10 + return config \ No newline at end of file diff --git a/big_vision/configs/proj/uvim/uvim_color_task.ipynb b/big_vision/configs/proj/uvim/uvim_color_task.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a1ded392586375100e807f0bcdfd7b07be88bd5e --- /dev/null +++ b/big_vision/configs/proj/uvim/uvim_color_task.ipynb @@ -0,0 +1,167 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "UViM color task", + "provenance": [], + "collapsed_sections": [], + "private_outputs": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "# Fetch big_vision repository and move it into the current workdir (import path).\n", + "!git clone --depth=1 https://github.com/google-research/big_vision big_vision_repo\n", + "!cp -R 
big_vision_repo/big_vision big_vision\n", + "!pip install -qr big_vision/requirements.txt" + ], + "metadata": { + "id": "sKZK6_QpVI_O" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "\n", + "from big_vision.models.proj.uvim import vtt # stage-II model\n", + "from big_vision.models.proj.uvim import vit # stage-I model\n", + "\n", + "from big_vision.models.proj.uvim import decode\n", + "from big_vision.trainers.proj.uvim import colorization_task as task\n", + "from big_vision.configs.proj.uvim import train_imagenet2012_colorization_pretrained as config_module\n", + "\n", + "import big_vision.pp.ops_image\n", + "import big_vision.pp.ops_general\n", + "import big_vision.pp.proj.uvim.pp_ops\n", + "from big_vision.pp import builder as pp_builder\n", + "\n", + "config = config_module.get_config()\n", + "res = 512\n", + "seq_len = config.model.seq_len\n", + "\n", + "lm_model = vtt.Model(**config.model)\n", + "oracle_model = vit.Model(**config.oracle.model)\n", + "\n", + "preprocess_fn = pp_builder.get_preprocess_fn(\n", + " 'decode|resize(512)|'\n", + " 'rgb_to_grayscale_to_rgb|value_range(-1,1)|'\n", + " 'copy(inkey=\"image\",outkey=\"image_ctx\")')\n", + "\n", + "@jax.jit\n", + "def predict_code(params, x, rng, temperature):\n", + " prompts = jnp.zeros((x[\"image\"].shape[0], seq_len), dtype=jnp.int32)\n", + " seqs, _, _ = decode.temperature_sampling(\n", + " params=params, model=lm_model, seed=rng,\n", + " inputs=x[\"image\"],\n", + " prompts=prompts,\n", + " temperature=temperature,\n", + " num_samples=1, eos_token=-1, prefill=False)\n", + " seqs = jnp.squeeze(seqs, axis=1) # drop num_samples axis \n", + " return seqs - 1\n", + " \n", + "@jax.jit\n", + "def labels2code(params, x, ctx):\n", + " y, aux = oracle_model.apply(params, x, ctx=ctx, train=False, method=oracle_model.encode)\n", + " return aux[\"code\"]\n", + "\n", + "@jax.jit\n", + "def 
code2labels(params, code, ctx):\n", + " logits, aux = oracle_model.apply(params, code, ctx=ctx, train=False, discrete_input=True, method=oracle_model.decode)\n", + " return task.predict_outputs(logits, config.oracle)" + ], + "metadata": { + "id": "QzThueWDzc7I" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Load checkpoints\n", + "!gsutil cp -n gs://big_vision/uvim/color_stageI_params.npz gs://big_vision/uvim/color_stageII_params.npz .\n", + "\n", + "oracle_params, oracle_state = vit.load(None, \"color_stageI_params.npz\")\n", + "oracle_params = jax.device_put({\"params\": oracle_params, \"state\": oracle_state})\n", + "\n", + "lm_params = vtt.load(None, \"color_stageII_params.npz\")\n", + "lm_params = jax.device_put({\"params\": lm_params})" + ], + "metadata": { + "id": "AEjRgshLa6Fp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Prepare set of images from coco/val2017:\n", + "# - https://cocodataset.org/\n", + "import os\n", + "import tensorflow as tf\n", + "\n", + "if not os.path.exists(\"val2017/\"):\n", + " !wget --no-clobber http://images.cocodataset.org/zips/val2017.zip\n", + " !unzip -uq val2017.zip\n", + "\n", + "dataset = tf.data.Dataset.list_files(\"val2017/*.jpg\", shuffle=True)\n", + "dataset = dataset.map(lambda filename: {\"image\": tf.io.read_file(filename)})\n", + "dataset = dataset.map(preprocess_fn)" + ], + "metadata": { + "id": "BKifDDRnH_Ll" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Run the model in a few examples:\n", + "from matplotlib import pyplot as plt\n", + "\n", + "num_examples = 4\n", + "data = dataset.batch(1).take(num_examples).as_numpy_iterator()\n", + "key = jax.random.PRNGKey(0)\n", + "temperature = jnp.array(1.0)\n", + "\n", + "def render_example(image, prediction):\n", + " f, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + " ax[0].imshow(image*0.5 + 0.5)\n", + " 
ax[0].axis(\"off\")\n", + " ax[1].imshow(prediction*0.5 + 0.5)\n", + " ax[1].axis(\"off\")\n", + "\n", + "for idx, batch in enumerate(data):\n", + " subkey = jax.random.fold_in(key, idx)\n", + " code = predict_code(lm_params, batch, subkey, temperature)\n", + " aux_inputs = task.input_pp(batch, config.oracle)\n", + " prediction = code2labels(oracle_params, code, aux_inputs[\"ctx\"])\n", + " render_example(batch[\"image\"][0], prediction[\"color\"][0])" + ], + "metadata": { + "id": "TuevCy33nuv3" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/big_vision/configs/proj/uvim/uvim_depth_task.ipynb b/big_vision/configs/proj/uvim/uvim_depth_task.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a4bfb3f0381238b4a0c3048006ff9fbc0ec848c7 --- /dev/null +++ b/big_vision/configs/proj/uvim/uvim_depth_task.ipynb @@ -0,0 +1,181 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "UViM depth task", + "provenance": [], + "collapsed_sections": [], + "private_outputs": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "# Fetch big_vision repository and move it into the current workdir (import path).\n", + "!git clone --depth=1 https://github.com/google-research/big_vision big_vision_repo\n", + "!cp -R big_vision_repo/big_vision big_vision\n", + "!pip install -qr big_vision/requirements.txt" + ], + "metadata": { + "id": "sKZK6_QpVI_O" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "\n", + "from big_vision.models.proj.uvim import vtt # stage-II model\n", + "from big_vision.models.proj.uvim import vit # stage-I model\n", + "\n", + "from big_vision.models.proj.uvim import 
decode\n", + "from big_vision.trainers.proj.uvim import depth_task as task\n", + "from big_vision.configs.proj.uvim import train_nyu_depth_pretrained as config_module\n", + "\n", + "import big_vision.pp.ops_image\n", + "import big_vision.pp.ops_general\n", + "import big_vision.pp.proj.uvim.pp_ops\n", + "from big_vision.pp import builder as pp_builder\n", + "\n", + "config = config_module.get_config()\n", + "res = 512\n", + "seq_len = config.model.seq_len\n", + "\n", + "lm_model = vtt.Model(**config.model)\n", + "oracle_model = vit.Model(**config.oracle.model)\n", + "\n", + "preprocess_fn = pp_builder.get_preprocess_fn(\n", + " 'resize(512)|value_range(-1,1)|'\n", + " 'copy(inkey=\"image\",outkey=\"image_ctx\")')\n", + "\n", + "@jax.jit\n", + "def predict_code(params, x, rng, temperature):\n", + " prompts = jnp.zeros((x[\"image\"].shape[0], seq_len), dtype=jnp.int32)\n", + " seqs, _, _ = decode.temperature_sampling(\n", + " params=params, model=lm_model, seed=rng,\n", + " inputs=x[\"image\"],\n", + " prompts=prompts,\n", + " temperature=temperature,\n", + " num_samples=1, eos_token=-1, prefill=False)\n", + " seqs = jnp.squeeze(seqs, axis=1) # drop num_samples axis \n", + " return seqs - 1\n", + " \n", + "@jax.jit\n", + "def labels2code(params, x, ctx):\n", + " y, aux = oracle_model.apply(params, x, ctx=ctx, train=False, method=oracle_model.encode)\n", + " return aux[\"code\"]\n", + "\n", + "@jax.jit\n", + "def code2labels(params, code, ctx):\n", + " logits, aux = oracle_model.apply(params, code, ctx=ctx, train=False, discrete_input=True, method=oracle_model.decode)\n", + " return task.predict_outputs(logits, config.oracle)" + ], + "metadata": { + "id": "QzThueWDzc7I" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Load checkpoints\n", + "!gsutil cp -n gs://big_vision/uvim/depth_stageI_params.npz gs://big_vision/uvim/depth_stageII_params.npz .\n", + "\n", + "oracle_params, oracle_state = vit.load(None, 
\"depth_stageI_params.npz\")\n", + "oracle_params = jax.device_put({\"params\": oracle_params, \"state\": oracle_state})\n", + "\n", + "lm_params = vtt.load(None, \"depth_stageII_params.npz\")\n", + "lm_params = jax.device_put({\"params\": lm_params})" + ], + "metadata": { + "id": "AEjRgshLa6Fp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Prepare dataset of images from NYU Depth V2:\n", + "# - https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html\n", + "import os\n", + "import h5py\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "\n", + "if not os.path.exists(\"nyu_depth_v2_labeled.mat\"):\n", + " !wget --no-clobber http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat\n", + "\n", + "dataset_file = h5py.File(\"nyu_depth_v2_labeled.mat\", \"r\")\n", + "\n", + "def nyu_depth_examples():\n", + " for idx in range(dataset_file[\"images\"].shape[0]):\n", + " image = np.transpose(dataset_file[\"images\"][idx], (2, 1, 0))\n", + " yield {\"image\": image}\n", + "\n", + "dataset = tf.data.Dataset.from_generator(\n", + " nyu_depth_examples,\n", + " output_signature={\n", + " \"image\": tf.TensorSpec((480,640,3), tf.uint8),\n", + " }).map(preprocess_fn)" + ], + "metadata": { + "id": "BKifDDRnH_Ll" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Run the model in a few examples:\n", + "from matplotlib import pyplot as plt\n", + "from matplotlib import patches\n", + "\n", + "num_examples = 4\n", + "data = dataset.batch(1).take(num_examples).as_numpy_iterator()\n", + "key = jax.random.PRNGKey(0)\n", + "temperature = jnp.array(1e-7)\n", + "\n", + "def to_depth(x, nbins=256, mind=1e-3, maxd=10):\n", + " depth = x.astype(np.float32) + 0.5 # Undoes floor in expectation.\n", + " return depth/nbins * (maxd - mind) + mind\n", + "\n", + "def render_example(image, prediction, with_legend=True):\n", + " f, ax = plt.subplots(1, 2, figsize=(10, 
10))\n", + " ax[0].imshow(image*0.5 + 0.5)\n", + " ax[0].axis(\"off\")\n", + " ax[1].imshow(to_depth(prediction))\n", + " ax[1].axis(\"off\")\n", + "\n", + "for idx, batch in enumerate(data):\n", + " subkey = jax.random.fold_in(key, idx)\n", + " code = predict_code(lm_params, batch, subkey, temperature)\n", + " aux_inputs = task.input_pp(batch, config.oracle)\n", + " prediction = code2labels(oracle_params, code, aux_inputs[\"ctx\"])\n", + " render_example(batch[\"image\"][0], prediction[\"depth\"][0])" + ], + "metadata": { + "id": "TuevCy33nuv3" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/big_vision/configs/proj/uvim/uvim_panoptic_task.ipynb b/big_vision/configs/proj/uvim/uvim_panoptic_task.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..c03324e5a6735f96c72f1eaecbba3cbf2b6dc843 --- /dev/null +++ b/big_vision/configs/proj/uvim/uvim_panoptic_task.ipynb @@ -0,0 +1,180 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "UViM panoptic task", + "provenance": [], + "collapsed_sections": [], + "private_outputs": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "# Fetch big_vision repository and move it into the current workdir (import path).\n", + "!git clone --depth=1 https://github.com/google-research/big_vision big_vision_repo\n", + "!cp -R big_vision_repo/big_vision big_vision\n", + "!pip install -qr big_vision/requirements.txt" + ], + "metadata": { + "id": "sKZK6_QpVI_O" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "\n", + "from big_vision.models.proj.uvim import vtt # stage-II model\n", + "from big_vision.models.proj.uvim import vit # stage-I 
model\n", + "\n", + "from big_vision.models.proj.uvim import decode\n", + "from big_vision.trainers.proj.uvim import panoptic_task as task\n", + "from big_vision.configs.proj.uvim import train_coco_panoptic_pretrained as config_module\n", + "\n", + "import big_vision.pp.ops_image\n", + "import big_vision.pp.ops_general\n", + "import big_vision.pp.proj.uvim.pp_ops\n", + "from big_vision.pp import builder as pp_builder\n", + "\n", + "config = config_module.get_config()\n", + "res = 512\n", + "seq_len = config.model.seq_len\n", + "\n", + "lm_model = vtt.Model(**config.model)\n", + "oracle_model = vit.Model(**config.oracle.model)\n", + "\n", + "preprocess_fn = pp_builder.get_preprocess_fn(\n", + " 'decode|resize(512)|value_range(-1,1)|'\n", + " 'copy(inkey=\"image\",outkey=\"image_ctx\")')\n", + "\n", + "@jax.jit\n", + "def predict_code(params, x, rng, temperature):\n", + " prompts = jnp.zeros((x[\"image\"].shape[0], seq_len), dtype=jnp.int32)\n", + " seqs, _, _ = decode.temperature_sampling(\n", + " params=params, model=lm_model, seed=rng,\n", + " inputs=x[\"image\"],\n", + " prompts=prompts,\n", + " temperature=temperature,\n", + " num_samples=1, eos_token=-1, prefill=False)\n", + " seqs = jnp.squeeze(seqs, axis=1) # drop num_samples axis \n", + " return seqs - 1\n", + " \n", + "@jax.jit\n", + "def labels2code(params, x, ctx):\n", + " y, aux = oracle_model.apply(params, x, ctx=ctx, train=False, method=oracle_model.encode)\n", + " return aux[\"code\"]\n", + "\n", + "@jax.jit\n", + "def code2labels(params, code, ctx):\n", + " logits, aux = oracle_model.apply(params, code, ctx=ctx, train=False, discrete_input=True, method=oracle_model.decode)\n", + " return task.predict_outputs(logits, config.oracle)" + ], + "metadata": { + "id": "QzThueWDzc7I" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Load checkpoints\n", + "!gsutil cp -n gs://big_vision/uvim/panoptic_stageI_params.npz 
gs://big_vision/uvim/panoptic_stageII_params.npz .\n", + "\n", + "oracle_params, oracle_state = vit.load(None, \"panoptic_stageI_params.npz\")\n", + "oracle_params = jax.device_put({\"params\": oracle_params, \"state\": oracle_state})\n", + "\n", + "lm_params = vtt.load(None, \"panoptic_stageII_params.npz\")\n", + "lm_params = jax.device_put({\"params\": lm_params})" + ], + "metadata": { + "id": "AEjRgshLa6Fp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Prepare set of images from coco/val2017:\n", + "# - https://cocodataset.org/\n", + "import os\n", + "import tensorflow as tf\n", + "\n", + "if not os.path.exists(\"val2017/\"):\n", + " !wget --no-clobber http://images.cocodataset.org/zips/val2017.zip\n", + " !unzip -uq val2017.zip\n", + " !wget -c https://raw.githubusercontent.com/cocodataset/panopticapi/master/panoptic_coco_categories.json\n", + "\n", + "dataset = tf.data.Dataset.list_files(\"val2017/*.jpg\", shuffle=True)\n", + "dataset = dataset.map(lambda filename: {\"image\": tf.io.read_file(filename)})\n", + "dataset = dataset.map(preprocess_fn)" + ], + "metadata": { + "id": "k2ArKPlFQVcz" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Run the model in a few examples:\n", + "from matplotlib import pyplot as plt\n", + "from matplotlib import patches\n", + "from big_vision.trainers.proj.uvim import coco_utils\n", + "\n", + "num_examples = 4\n", + "data = dataset.batch(1).take(num_examples).as_numpy_iterator()\n", + "key = jax.random.PRNGKey(0)\n", + "temperature = jnp.array(1e-7)\n", + "\n", + "def render_example(image, prediction, with_legend=True):\n", + " f, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + " ax[0].imshow(image*0.5 + 0.5)\n", + " ax[0].axis(\"off\")\n", + "\n", + " rgb, info = coco_utils.rgb_panoptic_from_twochannels(prediction, boundaries=True)\n", + " ax[1].matshow(rgb)\n", + " ax[1].axis(\"off\")\n", + "\n", + " if with_legend:\n", + " 
handles = []\n", + " for instance in info.values():\n", + " handles.append(patches.Patch(\n", + " facecolor=np.array(instance[\"color\"])/255.0,\n", + " edgecolor='black', label=instance[\"name\"]))\n", + " ax[1].legend(handles=handles, loc=(1.04, 0.0));\n", + "\n", + "\n", + "for idx, batch in enumerate(data):\n", + " subkey = jax.random.fold_in(key, idx)\n", + " code = predict_code(lm_params, batch, key, temperature)\n", + " aux_inputs = task.input_pp(batch, config.oracle)\n", + " prediction = code2labels(oracle_params, code, aux_inputs[\"ctx\"])\n", + " render_example(batch[\"image\"][0], prediction[0])" + ], + "metadata": { + "id": "TuevCy33nuv3" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/big_vision/configs/proj/uvim/vqvae_coco_panoptic.py b/big_vision/configs/proj/uvim/vqvae_coco_panoptic.py new file mode 100644 index 0000000000000000000000000000000000000000..5a620bae982e59187a02e54a0b735e346ad7d7ac --- /dev/null +++ b/big_vision/configs/proj/uvim/vqvae_coco_panoptic.py @@ -0,0 +1,143 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""A config for training a UViM stage I model for the panoptic task. + +This config is expected to reproduce the paper's result and achieve +approximately 75.7 PQ points on the COCO holdout data. 
def get_config(arg='res=512,patch_size=16'):
  """Config for training label compression on COCO-panoptic.

  Args:
    arg: Config argument string handed to `bvcc.parse_arg`. The keys declared
      here (with their defaults) are `res` (512), `patch_size` (16),
      `runlocal` (False) and `singlehost` (False).

  Returns:
    An `mlc.ConfigDict` holding the input pipeline, model, optimizer and
    evaluator settings for the UViM stage I panoptic task.
  """
  arg = bvcc.parse_arg(arg, res=512, patch_size=16,
                       runlocal=False, singlehost=False)
  config = mlc.ConfigDict()

  config.task = 'proj.uvim.panoptic_task'

  config.input = {}
  # The first 4096 training examples are held out; they are used below as the
  # `val` and `coco_panoptic_holdout` evaluation splits.
  config.input.data = dict(name='coco/2017_panoptic', split='train[4096:]')

  config.input.batch_size = 1024
  config.input.shuffle_buffer_size = 25_000

  config.total_epochs = 1000

  # Training preprocessing: the same flip and the same crop box are applied to
  # both the image and the labels; labels are resized with nearest-neighbour.
  config.input.pp = (
      f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
      f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|'
      f'inception_box|crop_box(key="image")|crop_box(key="labels")|'
      f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|'
      f'value_range(-1, 1)|make_canonical|keep("image","labels")'
  )
  # Evaluation preprocessing: no flip and no crop, only the resize.
  pp_eval = (
      f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
      f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|'
      f'value_range(-1, 1)|make_canonical|keep("image","labels")'
  )

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 20_000

  # Model section
  config.model_name = 'proj.uvim.vit'
  config.model = mlc.ConfigDict()
  config.model.input_size = (arg.res, arg.res)
  config.model.patch_size = (arg.patch_size, arg.patch_size)
  config.model.code_len = 256
  config.model.width = 768
  config.model.enc_depth = 6
  config.model.dec_depth = 12
  config.model.mlp_dim = 3072
  config.model.num_heads = 12
  config.model.dict_size = 4096  # Number of words in dict.
  config.model.codeword_dim = 768
  config.model.dict_momentum = 0.995  # Momentum for dict. learning.
  config.model.with_encoder_ctx = True
  config.model.with_decoder_ctx = True
  config.model.code_dropout = 'random'
  config.model.bottleneck_resize = True
  config.model.inputs = {
      'semantics': (133 + 1, arg.patch_size**2),  # +1 for void label
      'instances': (100, arg.patch_size**2),  # COCO: actually 98 train/78 validation.
  }
  config.model.outputs = config.model.inputs

  # VQVAE-specific params.
  config.freeze_dict = False  # Will freeze a dict. inside VQ-VAE model.
  config.w_commitment = 0.0

  # Optimizer section
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.optax = dict(beta2_cap=0.95)

  config.lr = 4e-4
  config.wd = 4e-5
  config.schedule = dict(decay_type='cosine', warmup_steps=4_000)
  config.grad_clip_norm = 1.0

  # Evaluation section
  config.evals = {}
  config.evals.val = mlc.ConfigDict()
  config.evals.val.type = 'proj.uvim.compute_mean'
  config.evals.val.pred = 'validation'
  # Copy the training data spec, then point the copy at the held-out slice.
  config.evals.val.data = {**config.input.data}
  config.evals.val.data.split = 'train[:4096]'
  config.evals.val.pp_fn = pp_eval
  config.evals.val.log_steps = 250

  # Settings shared by the three panoptic evaluators below. The leading
  # 'decode|' op is stripped from the eval preprocessing string for them.
  base = {
      'type': 'proj.uvim.coco_panoptic',
      'pp_fn': pp_eval.replace('decode|', ''),
      'log_steps': 10_000,
      # Filters objects that occupy less than 0.03^2 fraction of all pixels.
      # 'predict_kwargs': {'min_fraction': 0.03 ** 2},
  }
  config.evals.coco_panoptic_train = dict(**base, split='train[4096:8192]')
  config.evals.coco_panoptic_holdout = dict(**base, split='train[:4096]')
  config.evals.coco_panoptic = dict(**base, split='validation')

  # config.evals.save_pred = dict(type='proj.uvim.save_predictions')
  # config.evals.save_pred.pp = pp_eval.replace('decode|', '')
  # config.evals.save_pred.log_steps = 100_000
  # config.evals.save_pred.dataset = config.dataset
  # config.evals.save_pred.split = 'validation[:1024]'
  # config.evals.save_pred.outfile = 'inference.npz'

  config.seed = 0

  # `singlehost` takes precedence over `runlocal` when both are set.
  if arg.singlehost:
    # Low-resource variant: smaller batch and a shorter schedule. The schedule
    # length lives in `total_epochs` (set above), so that is the key that must
    # be overridden; writing to any other key leaves the 1000 epochs in place.
    config.input.batch_size = 128
    config.total_epochs = 100
  elif arg.runlocal:
    # Tiny settings for quick local debugging runs.
    config.input.batch_size = 16
    config.input.shuffle_buffer_size = 10
    config.log_training_steps = 5
    config.model.enc_depth = 1
    config.model.dec_depth = 1
    config.evals.val.data.split = 'validation[:16]'
    config.evals.val.log_steps = 20

  return config
def get_config(arg='res=512,patch_size=16'):
  """A config for training a UViM stage I model for the colorization task.

  Args:
    arg: Config argument string handed to `bvcc.parse_arg`. The keys declared
      here (with their defaults) are `res` (512), `patch_size` (16),
      `runlocal` (False) and `singlehost` (False).

  Returns:
    An `mlc.ConfigDict` holding the input pipeline, model, optimizer and
    evaluator settings.
  """
  arg = bvcc.parse_arg(arg, res=512, patch_size=16,
                       runlocal=False, singlehost=False)
  config = mlc.ConfigDict()

  config.task = 'proj.uvim.colorization_task'

  config.input = {}
  # The first 4096 training examples are held out; they are used below as the
  # `val` and `psnr_holdout` evaluation splits.
  config.input.data = dict(name='imagenet2012', split='train[4096:]')

  config.input.batch_size = 1024
  config.input.shuffle_buffer_size = 25_000

  config.total_epochs = 100

  # Training preprocessing: the colour image is copied to "labels" before the
  # "image" key is replaced by its grayscale version, so the target is the
  # original colour image.
  config.input.pp = (
      f'decode_jpeg_and_inception_crop({arg.res})'
      f'|flip_lr'
      f'|copy(inkey="image", outkey="labels")'
      f'|rgb_to_grayscale_to_rgb(inkey="image",outkey="image")'
      f'|value_range(-1,1,key="image")'
      f'|value_range(-1,1,key="labels")'
      f'|keep("image","labels")')

  # Evaluation preprocessing: plain decode + resize instead of crop + flip.
  pp_eval = (
      f'decode'
      f'|resize({arg.res})'
      f'|copy(inkey="image", outkey="labels")'
      f'|rgb_to_grayscale_to_rgb(inkey="image",outkey="image")'
      f'|value_range(-1,1,key="image")'
      f'|value_range(-1,1,key="labels")'
      f'|keep("image","labels")')

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 20_000

  # Model section
  config.model_name = 'proj.uvim.vit'
  config.model = mlc.ConfigDict()
  config.model.input_size = (arg.res, arg.res)
  config.model.patch_size = (arg.patch_size, arg.patch_size)
  config.model.code_len = 256
  config.model.width = 768
  config.model.enc_depth = 6
  config.model.dec_depth = 12
  config.model.mlp_dim = 3072
  config.model.num_heads = 12
  config.model.dict_size = 4096  # Number of words in dict.
  config.model.codeword_dim = 768
  config.model.dict_momentum = 0.995  # Momentum for dict. learning.
  config.model.with_encoder_ctx = True
  config.model.with_decoder_ctx = True
  config.model.code_dropout = 'random'
  config.model.bottleneck_resize = True
  config.model.inputs = {
      'color': (3, arg.patch_size**2),
  }
  config.model.outputs = config.model.inputs

  # VQVAE-specific params.
  config.freeze_dict = False  # Will freeze a dict. inside VQ-VAE model.
  config.w_commitment = 0.0

  # Optimizer section
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.optax = dict(beta2_cap=0.95)

  config.lr = 4e-4
  config.wd = 4e-5
  config.schedule = dict(decay_type='cosine', warmup_steps=4_000)
  config.grad_clip_norm = 1.0

  # Evaluation section
  config.evals = {}
  config.evals.val = mlc.ConfigDict()
  config.evals.val.type = 'proj.uvim.compute_mean'
  config.evals.val.pred = 'validation'
  # Copy the training data spec, then point the copy at the held-out slice.
  config.evals.val.data = {**config.input.data}
  config.evals.val.data.split = 'train[:4096]'
  config.evals.val.pp_fn = pp_eval
  config.evals.val.log_steps = 250

  # Settings shared by the three PSNR evaluators below. The leading 'decode|'
  # op is stripped from the eval preprocessing string for them.
  base = {
      'type': 'proj.uvim.psnr',
      'pp_fn': pp_eval.replace('decode|', ''),
      'log_steps': 10_000,
  }
  config.evals.psnr_train = dict(**base, split='train[4096:8192]')
  config.evals.psnr_holdout = dict(**base, split='train[:4096]')
  config.evals.psnr_val = dict(**base, split='validation')

  # No split is configured here; one is only assigned in the runlocal branch.
  config.evals.colorization_val_coltran_fid = {
      'type': 'proj.uvim.coltran_fid',
      'log_steps': 100_000,
  }

  # config.evals.save_pred = dict(type='proj.uvim.save_predictions')
  # config.evals.save_pred.pp = pp_eval.replace('decode|', '')
  # config.evals.save_pred.log_steps = 100_000
  # config.evals.save_pred.dataset = config.dataset
  # config.evals.save_pred.split = 'validation[:1024]'
  # config.evals.save_pred.outfile = 'inference.npz'

  config.seed = 0

  # `singlehost` takes precedence over `runlocal` when both are set.
  if arg.singlehost:
    # Low-resource variant: smaller batch and a shorter schedule.
    config.input.batch_size = 128
    config.total_epochs = 20
  elif arg.runlocal:
    # Tiny settings for quick local debugging runs.
    config.input.batch_size = 16
    config.input.shuffle_buffer_size = 10
    config.log_training_steps = 5
    config.model.enc_depth = 1
    config.model.dec_depth = 1
    config.evals.val.data.split = 'validation[:16]'
    config.evals.val.log_steps = 20
    config.evals.psnr_train.split = 'train[:256]'
    config.evals.psnr_train.log_steps = 20
    config.evals.psnr_holdout.split = 'train[256:512]'
    config.evals.psnr_holdout.log_steps = 20
    # NOTE(review): this is the same slice as psnr_train above, although the
    # non-local psnr_val reads the 'validation' split. Possibly meant to be
    # 'validation[:256]' -- confirm before relying on the local numbers.
    config.evals.psnr_val.split = 'train[:256]'
    config.evals.psnr_val.log_steps = 20
    config.evals.colorization_val_coltran_fid.split = 'validation[:256]'
    config.evals.colorization_val_coltran_fid.log_steps = 20

  return config
# Depth range bounds; they are stored on the config and handed to the
# `nyu_depth_val` evaluator in `get_config` below.
MIN_DEPTH = 1e-3
MAX_DEPTH = 10


def get_config(arg='res=512,patch_size=16'):
  """Config for training label compression on NYU depth v2.

  Args:
    arg: Config argument string handed to `bvcc.parse_arg`. The keys declared
      here (with their defaults) are `res` (512), `patch_size` (16),
      `runlocal` (False) and `singlehost` (False).

  Returns:
    An `mlc.ConfigDict` holding the input pipeline, model, optimizer and
    evaluator settings for the UViM stage I depth task.
  """
  arg = bvcc.parse_arg(arg, res=512, patch_size=16,
                       runlocal=False, singlehost=False)
  config = mlc.ConfigDict()

  config.task = 'proj.uvim.depth_task'

  config.input = {}
  # The split name must be a properly terminated string literal; a missing
  # closing quote here makes the whole module unparseable.
  config.input.data = dict(name='nyu_depth_v2', split='train')

  config.input.batch_size = 1024
  config.input.shuffle_buffer_size = 25_000

  config.total_epochs = 200

  # Training preprocessing: the same flip and the same crop box are applied to
  # both the image and the labels; labels are resized with nearest-neighbour.
  config.input.pp = (
      f'decode|nyu_depth|'
      f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|'
      f'inception_box|crop_box(key="image")|crop_box(key="labels")|'
      f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|'
      f'value_range(-1, 1)|keep("image","labels")'
  )

  # Evaluation preprocessing: fixed eval crop instead of random flip/crop.
  pp_eval = (
      f'decode|nyu_depth|nyu_eval_crop|'
      f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|'
      f'value_range(-1, 1)|keep("image","labels")'
  )

  # There are no image IDs in TFDS, so hand through the ground truth for eval.
  pp_pred = (
      f'nyu_depth|nyu_eval_crop|copy("labels","ground_truth")|'
      f'resize({arg.res})|resize({arg.res},key="labels",method="nearest")|'
      f'value_range(-1, 1)|'
      f'keep("image","labels","ground_truth")'
  )

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 20_000

  # Model section
  config.min_depth = MIN_DEPTH
  config.max_depth = MAX_DEPTH
  config.model_name = 'proj.uvim.vit'
  config.model = mlc.ConfigDict()
  config.model.input_size = (arg.res, arg.res)
  config.model.patch_size = (arg.patch_size, arg.patch_size)
  config.model.code_len = 256
  config.model.width = 768
  config.model.enc_depth = 6
  config.model.dec_depth = 12
  config.model.mlp_dim = 3072
  config.model.num_heads = 12
  config.model.dict_size = 4096  # Number of words in dict.
  config.model.codeword_dim = 768
  config.model.dict_momentum = 0.995  # Momentum for dict. learning.
  config.model.with_encoder_ctx = True
  config.model.with_decoder_ctx = True
  config.model.code_dropout = 'random'
  config.model.bottleneck_resize = True
  config.model.inputs = {
      'depth': (QUANTIZATION_BINS, arg.patch_size**2),
  }
  config.model.outputs = config.model.inputs

  # VQVAE-specific params.
  config.freeze_dict = False  # Will freeze a dict. inside VQ-VAE model.
  config.w_commitment = 0.0

  # Optimizer section
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.optax = dict(beta2_cap=0.95)

  config.lr = 1e-3
  config.wd = 1e-5
  config.schedule = dict(decay_type='cosine', warmup_steps=4_000)
  config.grad_clip_norm = 1.0

  # Evaluation section
  config.evals = {}
  config.evals.val = mlc.ConfigDict()
  config.evals.val.type = 'proj.uvim.compute_mean'
  config.evals.val.pred = 'validation'
  # Copy the training data spec, then point the copy at the validation split.
  config.evals.val.data = {**config.input.data}
  config.evals.val.data.split = 'validation'
  config.evals.val.pp_fn = pp_eval
  config.evals.val.log_steps = 250

  # Settings for the depth evaluator; it receives the same depth bounds that
  # are stored on the config above.
  base = {
      'type': 'proj.uvim.nyu_depth',
      'dataset': config.input.data.name,
      'pp_fn': pp_pred,
      'log_steps': 2000,
      'min_depth': MIN_DEPTH,
      'max_depth': MAX_DEPTH,
  }
  config.evals.nyu_depth_val = dict(**base, split='validation')

  config.seed = 0

  # `singlehost` takes precedence over `runlocal` when both are set.
  if arg.singlehost:
    # Low-resource variant: smaller batch and a shorter schedule.
    config.input.batch_size = 128
    config.total_epochs = 50
  elif arg.runlocal:
    # Tiny settings for quick local debugging runs.
    config.input.batch_size = 16
    config.input.shuffle_buffer_size = 10
    config.log_training_steps = 5
    config.model.enc_depth = 1
    config.model.dec_depth = 1
    config.evals.val.data.split = 'validation[:16]'
    config.evals.val.log_steps = 20

  return config
def _set_model(config, model):
  """Load pre-trained models: vit or bit.

  Sets `model_load`, `model_name`, `model_init` and `model` on `config`.

  Args:
    config: The config object to update in place.
    model: One of 'vit-i21k-augreg-b/32', 'vit-i21k-augreg-l/16', 'vit-s16'
      or 'bit-m-r50x1'.

  Raises:
    ValueError: If `model` is not one of the names listed above.
  """
  # Reset the head to init (of zeros) when transferring.
  config.model_load = dict(dont_load=['head/kernel', 'head/bias'])

  if model == 'vit-i21k-augreg-b/32':
    # Load "recommended" upstream B/32 from https://arxiv.org/abs/2106.10270
    config.model_name = 'vit'
    config.model_init = 'howto-i21k-B/32'
    config.model = dict(variant='B/32', pool_type='tok')
  elif model == 'vit-i21k-augreg-l/16':
    config.model_name = 'vit'
    config.model_init = 'howto-i21k-L/16'
    config.model = dict(variant='L/16', pool_type='tok')
  elif model == 'vit-s16':
    config.model_name = 'vit'
    config.model_init = 'i1k-s16-300ep'
    config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
                        rep_size=True)
  elif model == 'bit-m-r50x1':
    config.model_name = 'bit_paper'
    config.model_init = 'M'
    config.model = dict(depth=50, width=1)
  else:
    raise ValueError(f'Unknown model: {model}, please define customized model.')


def _set_dataset(config, dataset, crop='inception_crop', h_res=448, l_res=384):
  """Configure the train/val/test task for one of the known datasets.

  Args:
    config: The config object to update in place.
    dataset: One of 'cifar10', 'cifar100', 'imagenet2012', 'oxford_iiit_pet'
      or 'oxford_flowers102'.
    crop: Training crop mode, forwarded to `_set_task`.
    h_res: The larger resolution used in the preprocessing strings (resize
      before cropping).
    l_res: The crop resolution used in the preprocessing strings.

  Raises:
    ValueError: If `dataset` is not one of the names listed above.
  """
  if dataset == 'cifar10':
    _set_task(config, 'cifar10', 'train[:98%]', 'train[98%:]', 'test', 10, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
  elif dataset == 'cifar100':
    _set_task(config, 'cifar100', 'train[:98%]', 'train[98%:]', 'test', 100, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
  elif dataset == 'imagenet2012':
    _set_task(config, 'imagenet2012', 'train[:99%]', 'train[99%:]', 'validation', 1000, steps=20_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
    # Forward the resolutions so the ReaL/v2 evaluators preprocess images the
    # same way as the val/test evaluators configured by `_set_task` above.
    _set_imagenet_variants(config, h_res=h_res, l_res=l_res)
  elif dataset == 'oxford_iiit_pet':
    _set_task(config, 'oxford_iiit_pet', 'train[:90%]', 'train[90%:]', 'test', 37, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
  elif dataset == 'oxford_flowers102':
    _set_task(config, 'oxford_flowers102', 'train[:90%]', 'train[90%:]', 'test', 102, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
  else:
    raise ValueError(
        f'Unknown dataset: {dataset}, please define customized dataset.')


def _set_task(config, dataset, train, val, test, n_cls,
              steps=20_000, warmup=500, lbl='label', crop='resmall_crop',
              flip=True, h_res=448, l_res=384):
  """Vision task with val and test splits.

  Sets the schedule, training input, class count and the `val`/`test`
  classification evaluators on `config`.

  Args:
    config: The config object to update in place; `config.input` must exist.
    dataset: Dataset name used for training and both evaluators.
    train: Split string for training.
    val: Split string for the `val` evaluator.
    test: Split string for the `test` evaluator.
    n_cls: Number of classes (size of the one-hot labels).
    steps: Total number of training steps.
    warmup: Number of warmup steps of the cosine schedule.
    lbl: Key of the label field that is one-hot encoded.
    crop: One of 'inception_crop', 'resmall_crop' or 'resize_crop'.
    flip: Whether to append a left-right flip to the training preprocessing.
    h_res: Resolution of the resize that precedes the crop.
    l_res: Resolution of the crop.

  Raises:
    ValueError: If `crop` is not one of the three supported modes.
  """
  config.total_steps = steps
  config.schedule = dict(
      warmup_steps=warmup,
      decay_type='cosine',
  )

  config.input.data = dict(name=dataset, split=train)
  # Suffix shared by the training and evaluation preprocessing strings.
  pp_common = (
      '|value_range(-1, 1)|'
      f'onehot({n_cls}, key="{lbl}", key_result="labels")|'
      'keep("image", "labels")'
  )

  if crop == 'inception_crop':
    pp_train = f'decode|inception_crop({l_res})'
  elif crop == 'resmall_crop':
    pp_train = f'decode|resize_small({h_res})|random_crop({l_res})'
  elif crop == 'resize_crop':
    pp_train = f'decode|resize({h_res})|random_crop({l_res})'
  else:
    raise ValueError(f'Unknown crop: {crop}. Must be one of: '
                     'inception_crop, resmall_crop, resize_crop')
  if flip:
    pp_train += '|flip_lr'
  config.input.pp = pp_train + pp_common

  # Evaluation always uses resize_small + central crop, whatever `crop` is.
  pp = f'decode|resize_small({h_res})|central_crop({l_res})' + pp_common
  config.num_classes = n_cls

  def get_eval(split):
    return dict(
        type='classification',
        data=dict(name=dataset, split=split),
        loss_name='softmax_xent',
        log_steps=100,
        pp_fn=pp,
    )
  config.evals = dict(val=get_eval(val), test=get_eval(test))


def _set_imagenet_variants(config, h_res=448, l_res=384):
  """Evaluation tasks on ImageNet variants: v2 and real.

  Expects `config.evals.val`, `config.evals.test` and `config.loss` to be set
  already.

  Args:
    config: The config object to update in place.
    h_res: Resolution of the resize that precedes the central crop.
    l_res: Resolution of the central crop.
  """
  # `{lbl}` is left as a placeholder and filled per evaluator via .format().
  pp = (f'decode|resize_small({h_res})|central_crop({l_res})'
        '|value_range(-1, 1)|onehot(1000, key="{lbl}", key_result="labels")|'
        'keep("image", "labels")'
       )

  # Special-case rename for i1k (val+test -> minival+val)
  config.evals.minival = config.evals.val
  config.evals.val = config.evals.test
  # NOTE: keep test == val for convenience in subsequent analysis.

  config.evals.real = dict(type='classification')
  config.evals.real.data = dict(name='imagenet2012_real', split='validation')
  config.evals.real.pp_fn = pp.format(lbl='real_label')
  config.evals.real.loss_name = config.loss
  config.evals.real.log_steps = 100

  config.evals.v2 = dict(type='classification')
  config.evals.v2.data = dict(name='imagenet_v2', split='test')
  config.evals.v2.pp_fn = pp.format(lbl='label')
  config.evals.v2.loss_name = config.loss
  config.evals.v2.log_steps = 100


def get_config(arg=None):
  """Config for adaptation.

  Args:
    arg: Config argument string handed to `bvcc.parse_arg`. The keys declared
      here (with their defaults) are `model` ('vit'), `dataset` ('cifar10'),
      `crop` ('resmall_crop'), `h_res` (448), `l_res` (384), `batch_size`
      (512), `fsdp` (False) and `runlocal` (False). The default `model` value
      is not one of the names `_set_model` accepts, so a model has to be
      passed explicitly.

  Returns:
    An `mlc.ConfigDict` with input, optimizer, dataset and model settings.

  Raises:
    ValueError: If the model, dataset or crop name is not supported.
  """
  arg = bvcc.parse_arg(arg, model='vit', dataset='cifar10', crop='resmall_crop',
                       h_res=448, l_res=384, batch_size=512, fsdp=False,
                       runlocal=False)
  config = mlc.ConfigDict()

  config.input = {}
  config.input.batch_size = arg.batch_size if not arg.runlocal else 8
  config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 100

  config.log_training_steps = 10
  config.ckpt_steps = 1000
  config.ckpt_timeout = 600

  # Optimizer section
  config.optax_name = 'big_vision.momentum_hp'
  config.grad_clip_norm = 1.0
  config.wd = None  # That's our default, but just being explicit here!
  config.loss = 'softmax_xent'
  config.lr = 0.01
  config.mixup = dict(p=0.0)

  config.seed = 0

  # `config.loss` must be set before this call: the ImageNet variant
  # evaluators read it.
  _set_dataset(config, arg.dataset, arg.crop, arg.h_res, arg.l_res)

  _set_model(config, arg.model)
  if arg.fsdp:
    config.mesh = [('data', -1)]
    config.sharding_strategy = [('.*', 'fsdp(axis="data")')]
    config.sharding_rules = [('act_batch', ('data',))]
    config.model.scan = True

  return config
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Pre-training ViT on ILSVRC-2012 as in https://arxiv.org/abs/2106.10270 + +This config does NOT include regularization (dropout, stochastic depth), which +was shown to help with B/32, B/16, L/16 models in the paper (Figure 4). + +This configuration makes use of the "arg" to get_config to select which model +to run, so a few examples are given below: + +Run training of a B/16 model: + +big_vision.train \ + --config big_vision/configs/vit_i1k.py:variant=B/16 \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` + +Run training of a B/32 model with custom aug-strenght and 300ep: + +big_vision.train \ + --config big_vision/configs/vit_i1k.py:variant=B/32,aug=light1 \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ + --config.total_epochs 300 +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.common_fewshot import get_fewshot_lsr +import ml_collections as mlc + +MIXUP_DEF = { + 'none': dict(p=0.0, fold_in=None), + 'light1': dict(p=0.0, fold_in=None), + 'light2': dict(p=0.2, fold_in=None), + 'medium1': dict(p=0.2, fold_in=None), + 'medium2': dict(p=0.5, fold_in=None), + 'strong1': dict(p=0.5, fold_in=None), + 'strong2': dict(p=0.8, fold_in=None), +} + +RANDAUG_DEF = { + 'none': '', + 'light1': 'randaug(2,0)', # Actually not nothing! 
def get_config(arg=None):
  """Config for training.

  Args:
    arg: Config argument string handed to `bvcc.parse_arg`. The keys declared
      here (with their defaults) are `variant` ('B/16'), `runlocal` (False)
      and `aug` (''). A non-empty `aug` overrides the per-variant default
      augmentation setting.

  Returns:
    An `mlc.ConfigDict` with input, model, optimizer and evaluator settings.

  Raises:
    KeyError: If `aug` is empty and `variant` has no default augmentation
      entry, or if the augmentation name is not a key of `RANDAUG_DEF` /
      `MIXUP_DEF`.
  """
  arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug='')
  config = mlc.ConfigDict()

  config.seed = 0
  config.total_epochs = 300
  config.num_classes = 1000
  config.loss = 'sigmoid_xent'
  config.init_head_bias = -6.9

  # If this gives a KeyError, lookup Fig4 of the paper and add an entry.
  # Note, this here is a good average between 30ep and 300ep, sometimes you could
  # find a slightly better setting for either of them.
  aug_setting = arg.aug or {
      'Ti/16': 'light1',
      'S/32': 'medium1',
      'S/16': 'medium2',
      'B/32': 'medium2',
      'B/16': 'medium2',
      'L/16': 'medium2',
  }[arg.variant]

  config.input = dict()
  config.input.data = dict(
      name='imagenet2012',
      split='train[:99%]',
  )
  config.input.batch_size = 4096
  # NOTE(review): as written, 'raw_data' caching is selected only when
  # runlocal is set, while the runlocal branch at the end of this function
  # writes a different key (`cache_raw`). Confirm which key the input
  # pipeline reads and whether the condition is inverted.
  config.input.cache = 'raw_data' if arg.runlocal else 'none'  # Needs up to 120GB of RAM!
  config.input.shuffle_buffer_size = 250_000

  # `{lbl}` is left as a placeholder and filled in via .format() below.
  pp_common = (
      '|value_range(-1, 1)'
      '|onehot(1000, key="{lbl}", key_result="labels")'
      '|keep("image", "labels")'
  )
  config.input.pp = (
      'decode_jpeg_and_inception_crop(224)|flip_lr|' +
      RANDAUG_DEF[aug_setting] +
      pp_common.format(lbl='label')
  )
  pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common

  # To continue using the near-defunct randaug op.
  config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']

  # Aggressive pre-fetching because our models here are small, so we not only
  # can afford it, but we also need it for the smallest models to not be
  # bottle-necked by the input pipeline. Play around with it for -L models tho.
  config.input.prefetch = 8
  config.prefetch_to_device = 4

  config.log_training_steps = 50
  config.ckpt_steps = 1000

  # Model section
  config.model_name = 'vit'
  config.model = dict(
      variant=arg.variant,
      rep_size=True,
      pool_type='tok',
  )

  # Optimizer section
  config.grad_clip_norm = 1.0
  config.optax_name = 'scale_by_adam'
  config.optax = dict(mu_dtype='bfloat16')
  # The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560
  # almost always behaves exactly like adam, but at a fraction of the memory
  # cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a
  # good idea to try it when you are memory-bound!
  # config.optax_name = 'big_vision.scale_by_adafactor'
  # A good flag to play with when hitting instabilities, is the following:
  # config.optax = dict(beta2_cap=0.95)

  config.lr = 0.001
  config.wd = 0.0001
  config.schedule = dict(warmup_steps=10_000, decay_type='cosine')

  config.mixup = MIXUP_DEF[aug_setting]

  # Eval section
  def get_eval(split, dataset='imagenet2012'):
    # Builds one classification evaluator spec for the given dataset split.
    return dict(
        type='classification',
        data=dict(name=dataset, split=split),
        pp_fn=pp_eval.format(lbl='label'),
        loss_name=config.loss,
        log_steps=2500,  # Very fast O(seconds) so it's fine to run it often.
        cache='final' if arg.runlocal else 'none',
    )
  config.evals = {}
  config.evals.train = get_eval('train[:2%]')
  config.evals.minival = get_eval('train[99%:]')
  config.evals.val = get_eval('validation')
  config.evals.v2 = get_eval('test', dataset='imagenet_v2')
  config.evals.real = get_eval('validation', dataset='imagenet2012_real')
  # The ReaL dataset one-hot encodes a different label field.
  config.evals.real.pp_fn = pp_eval.format(lbl='real_label')

  config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
  config.fewshot.log_steps = 10_000

  # Make a few things much smaller for quick local debugging testruns.
  if arg.runlocal:
    config.input.shuffle_buffer_size = 10
    config.input.batch_size = 8
    # NOTE(review): this writes `cache_raw`, whereas the input section above
    # configures `cache`; presumably a leftover of an older key -- confirm.
    config.input.cache_raw = False
    config.evals.train.data.split = 'train[:16]'
    config.evals.minival.data.split = 'train[:16]'
    config.evals.val.data.split = 'validation[:16]'
    config.evals.v2.data.split = 'test[:16]'
    config.evals.real.data.split = 'validation[:16]'

  return config
import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
import ml_collections as mlc

# Mixup settings for every named augmentation strength.
MIXUP_DEF = {
    'none': dict(p=0.0, fold_in=None),
    'light1': dict(p=0.0, fold_in=None),
    'light2': dict(p=0.2, fold_in=None),
    'medium1': dict(p=0.2, fold_in=None),
    'medium2': dict(p=0.5, fold_in=None),
    'strong1': dict(p=0.5, fold_in=None),
    'strong2': dict(p=0.8, fold_in=None),
}

# RandAugment preprocessing op for every named augmentation strength.
RANDAUG_DEF = {
    'none': '',
    'light1': 'randaug(2,0)',  # Actually not nothing!
    'light2': 'randaug(2,10)',
    'medium1': 'randaug(2,15)',
    'medium2': 'randaug(2,15)',
    'strong1': 'randaug(2,20)',
    'strong2': 'randaug(2,20)',
}


def get_config(arg=None):
  """Builds the config for pre-training a ViT on ImageNet-21k.

  Args:
    arg: Raw config argument, forwarded to `bvcc.parse_arg` together with the
      defaults `variant='B/16'`, `runlocal=False` and `aug=None`.

  Returns:
    An `mlc.ConfigDict` with the input, model, optimizer and evaluator
    settings.

  Raises:
    KeyError: If `aug` is unset and `variant` has no default augmentation
      entry below, or if the selected augmentation name is not a key of
      `RANDAUG_DEF` / `MIXUP_DEF`.
  """
  arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug=None)
  config = mlc.ConfigDict()

  config.seed = 0
  config.total_epochs = 300
  config.num_classes = 21843
  config.init_head_bias = -10.0
  config.loss = 'sigmoid_xent'

  # An explicitly requested `aug` takes precedence; otherwise use the
  # per-variant default.
  # If this gives a KeyError, lookup Fig4 of the paper and add an entry.
  # Note, this here is a good average between 30ep and 300ep, sometimes you
  # could find a slightly better setting for either of them.
  aug_setting = arg.aug or {
      'Ti/16': 'none',
      'S/32': 'none',
      'S/16': 'light1',
      'B/32': 'light2',
      'B/16': 'light2',
      'L/16': 'medium2',
  }[arg.variant]

  config.input = dict()
  config.input.data = dict(
      name='imagenet21k',
      split='full[51200:]',
  )
  config.input.batch_size = 4096
  config.input.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

  pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
  pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
  config.input.pp = f'decode_jpeg_and_inception_crop(224)|flip_lr|{RANDAUG_DEF[aug_setting]}' + pp_common_i21k
  pp_eval = 'decode|resize_small(256)|central_crop(224)'

  # To continue using the near-defunct randaug op.
  config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']

  # Aggressive pre-fetching because our models here are small, so we not only
  # can afford it, but we also need it for the smallest models to not be
  # bottle-necked by the input pipeline. Play around with it for -L models tho.
  config.input.prefetch = 8
  config.prefetch_to_device = 4

  config.log_training_steps = 50
  config.ckpt_steps = 1000

  # Model section
  config.model_name = 'vit'
  config.model = dict(variant=arg.variant, pool_type='gap', posemb='learn')

  # Optimizer section
  config.optax_name = 'scale_by_adam'
  config.optax = dict(mu_dtype='bfloat16')
  config.grad_clip_norm = 1.0

  config.lr = 0.001
  config.wd = 0.0001
  config.schedule = dict(warmup_steps=10_000, decay_type='cosine')

  config.mixup = MIXUP_DEF[aug_setting]

  # Evaluations on i21k itself.
  def eval_i21k(split):
    return dict(
        type='classification',
        data={**config.input.data, 'split': split},
        pp_fn=pp_eval + pp_common_i21k,
        loss_name=config.loss,
        log_steps=1000,  # Very fast O(seconds) so it's fine to run it often.
    )
  config.evals = {}
  config.evals.test = eval_i21k('full[:25_600]')
  config.evals.val = eval_i21k('full[25_600:51_200]')
  config.evals.train = eval_i21k('full[51_200:76_800]')

  # Few-shot evaluators
  config.evals.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
  config.evals.fewshot.log_steps = 25_000

  # Make a few things much smaller for quick local debugging testruns.
  # Only evaluators that were created above are shrunk here: accessing an
  # entry of `config.evals` that does not exist raises instead of creating it.
  if arg.runlocal:
    config.input.shuffle_buffer_size = 10
    config.input.batch_size = 8
    config.evals.test.data.split = 'full[:16]'
    config.evals.train.data.split = 'full[:16]'
    config.evals.val.data.split = 'full[:16]'

  return config
import ml_collections as mlc


def get_config():
  """Builds the ViT-S/16 ImageNet-1k pre-training configuration.

  Returns:
    An `mlc.ConfigDict` with the input, model, optimizer and evaluator
    settings.
  """
  cfg = mlc.ConfigDict()

  cfg.seed = 0
  cfg.total_epochs = 90
  cfg.num_classes = 1000
  cfg.loss = 'softmax_xent'

  # Input section
  cfg.input = {}
  cfg.input.data = {'name': 'imagenet2012', 'split': 'train[:99%]'}
  cfg.input.batch_size = 1024
  cfg.input.cache_raw = True  # Needs up to 120GB of RAM!
  cfg.input.shuffle_buffer_size = 250_000

  # Preprocessing tail shared by training and evaluation; `{lbl}` is replaced
  # by the name of the label key.
  pp_tail = ''.join([
      '|value_range(-1, 1)',
      '|onehot(1000, key="{lbl}", key_result="labels")',
      '|keep("image", "labels")',
  ])
  pp_train_head = 'decode_jpeg_and_inception_crop(224)|flip_lr|randaug(2,10)'
  cfg.input.pp = pp_train_head + pp_tail.format(lbl='label')
  pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_tail

  # To continue using the near-defunct randaug op.
  cfg.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']

  cfg.log_training_steps = 50
  cfg.ckpt_steps = 1000

  # Model section
  cfg.model_name = 'vit'
  cfg.model = {
      'variant': 'S/16',
      'rep_size': True,
      'pool_type': 'gap',
      'posemb': 'sincos2d',
  }

  # Optimizer section
  cfg.grad_clip_norm = 1.0
  cfg.optax_name = 'scale_by_adam'
  cfg.optax = {'mu_dtype': 'bfloat16'}

  cfg.lr = 0.001
  cfg.wd = 0.0001
  cfg.schedule = {'warmup_steps': 10_000, 'decay_type': 'cosine'}

  cfg.mixup = {'p': 0.2, 'fold_in': None}

  # Eval section
  def classification_eval(split, dataset='imagenet2012', lbl='label'):
    """Returns one classification evaluator reading labels from `lbl`."""
    return {
        'type': 'classification',
        'data': {'name': dataset, 'split': split},
        'pp_fn': pp_eval.format(lbl=lbl),
        'loss_name': cfg.loss,
        'log_steps': 2500,  # Very fast O(seconds) so it's fine to run it often.
    }

  cfg.evals = {}
  cfg.evals.train = classification_eval('train[:2%]')
  cfg.evals.minival = classification_eval('train[99%:]')
  cfg.evals.val = classification_eval('validation')
  cfg.evals.v2 = classification_eval('test', dataset='imagenet_v2')
  cfg.evals.real = classification_eval(
      'validation', dataset='imagenet2012_real', lbl='real_label')

  return cfg
Copy the data to local disk: + + mkdir -p /tmp/data/ai2d && cd /tmp/data/ai2d + wget https://ai2-public-datasets.s3.amazonaws.com/diagrams/ai2d-all.zip + wget https://s3-us-east-2.amazonaws.com/prior-datasets/ai2d_test_ids.csv + wget https://github.com/googlefonts/dm-fonts/raw/main/Sans/fonts/ttf/DMSans-Regular.ttf + unzip ai2d-all.zip + +Also download a font for rendering, set the location in the flag font_path. + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd third_party/py/big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=ai2d + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load('ai2d', split='train', data_dir='/tmp/tfds') +""" + +import functools +import glob +import io +import json +import os +from typing import Any, Dict + +from absl import flags +import numpy as np +from PIL import Image +from PIL import ImageDraw +from PIL import ImageFont +import tensorflow_datasets as tfds + + +_DESCRIPTION = """AI2D dataset.""" + +# pylint: disable=line-too-long +_CITATION = """ +@inproceedings{kembhavi2016eccv, + author = {Aniruddha Kembhavi, Mike Salvato, Eric Kolve, Minjoon Seo, Hannaneh Hajishirzi, Ali Farhadi}, + title = {A Diagram Is Worth A Dozen Images}, + booktitle = {European Conference on Computer Vision (ECCV)}, + year = {2016} + url={https://api.semanticscholar.org/CorpusID:2682274} +} +""" +# pylint: enable=line-too-long + + +_INPUT_PATH = flags.DEFINE_string( + 'input_path', '/tmp/data/ai2d/', 'Downloaded AI2D data.' +) +_FONT_PATH = flags.DEFINE_string( + 'font_path', '/tmp/data/ai2d/DMSans-Regular.ttf', + 'Font for rendering annotations.'
class Ai2d(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for AI2D dataset."""

  VERSION = tfds.core.Version('1.1.0')
  RELEASE_NOTES = {'1.1.0': 'Re-create from scratch + more fields.'}

  def _info(self):
    """Returns the metadata."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'id': tfds.features.Text(),
            'question': tfds.features.Text(),
            'label': tfds.features.Scalar(np.int32),
            'answer': tfds.features.Text(),
            'possible_answers': tfds.features.Sequence(tfds.features.Text()),
            'abc_label': tfds.features.Scalar(np.bool_),
            'image_name': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='png'),
        }),
        homepage='https://allenai.org/data/diagrams',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {split: self._generate_examples(split)
            for split in ('test', 'train')}

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples for one split.

    Question files whose basename (up to the first '.') is listed in
    `ai2d_test_ids.csv` belong to 'test'; all other files belong to 'train'.

    Args:
      split: Name of the split to generate, 'train' or 'test'.

    Yields:
      `(question id, feature dict)` pairs, one per question.
    """
    with open(
        os.path.join(_INPUT_PATH.value, 'ai2d_test_ids.csv'), 'r'
    ) as f:
      # A set keeps the per-file membership test below O(1).
      all_test_ids = {line.strip() for line in f}

    all_annotation_paths = glob.glob(
        os.path.join(_INPUT_PATH.value, 'ai2d/questions', '*.json'))
    for annotation_path in all_annotation_paths:
      basename = os.path.basename(annotation_path)
      image_id = basename.split('.')[0]
      if image_id in all_test_ids and split == 'train':
        continue
      elif image_id not in all_test_ids and split == 'test':
        continue

      text_annotation_path = os.path.join(
          _INPUT_PATH.value, 'ai2d/annotations', basename
      )
      # Load both files up front so that no file handle stays open while the
      # generator is suspended at `yield`.
      with open(annotation_path, 'r', encoding='utf-8') as f:
        question_json = json.load(f)
      with open(text_annotation_path, 'r', encoding='utf-8') as g:
        text_annotation_json = json.load(g)

      for question, details in question_json['questions'].items():
        label_id = int(details['correctAnswer'])
        choices = details['answerTexts']
        annotation = {
            'id': details['questionId'],
            'question': question,
            'label': label_id,
            'answer': choices[label_id],
            'possible_answers': tuple(choices),
            'abc_label': details['abcLabel'],
            'image_name': question_json['imageName'],
        }
        annotation['image'] = _create_image(
            annotation, text_annotation_json['text']
        )
        yield annotation['id'], annotation


@functools.cache
def Font(  # pylint: disable=invalid-name
    size: int,
) -> ImageFont.FreeTypeFont:
  """Loads the font from the `font_path` flag at the given size.

  Results are cached per size, so each size is read from disk at most once.

  Args:
    size: The size of the returned font.

  Returns:
    The loaded font.
  """
  return ImageFont.truetype(_FONT_PATH.value, size=size)


def _create_image(
    annotation: Dict[str, Any], text_annotation: Dict[str, Any]
) -> bytes:
  """Returns the encoded image bytes for one annotation.

  Args:
    annotation: Example dict; `image_name` and `abc_label` are read.
    text_annotation: The `text` entry of the image's annotation file.

  Returns:
    The raw file content, or a re-rendered PNG with the text regions replaced
    when `abc_label` is set.
  """
  img_path = os.path.join(_INPUT_PATH.value, 'ai2d/images',
                          annotation['image_name'])
  with open(img_path, 'rb') as f:
    if annotation['abc_label']:
      return _draw_text(f, text_annotation)
    return f.read()


def _draw_text(image, text_annotations) -> bytes:
  """Replaces text in image by the correct replacement letter from AI2D."""
  image = Image.open(image)
  draw = ImageDraw.Draw(image)
  for current_annotation in text_annotations.values():
    rectangle = current_annotation['rectangle']
    box = [tuple(rectangle[0]), tuple(rectangle[1])]
    text = current_annotation['replacementText']
    position = box[0]
    # Blank out the original text before drawing the replacement.
    draw.rectangle(box, fill='white')
    x_diff = box[1][0] - box[0][0]
    y_diff = box[1][1] - box[0][1]
    # Start at 100pt and shrink one point at a time until the text fits the
    # box; never go below 1pt.
    font_size = 100
    font = Font(font_size)
    size = font.getbbox(text)
    while (size[2] > x_diff or size[3] > y_diff) and font_size > 1:
      font_size -= 1
      font = Font(font_size)
      size = font.getbbox(text)
    # Center the text horizontally inside the box.
    delta = (x_diff - size[2]) // 2
    position = (position[0] + delta, position[1])
    draw.text(position, text, fill='black', font=font)
  new_image_bytes = io.BytesIO()
  image.save(new_image_bytes, format='PNG')
  return new_image_bytes.getvalue()
+ +Download the required files from https://aokvqa.allenai.org/download.html: + +mkdir -p /tmp/tfds +cd /tmp/tfds/ +wget http://images.cocodataset.org/zips/train2017.zip +wget http://images.cocodataset.org/zips/val2017.zip +wget http://images.cocodataset.org/zips/test2017.zip +wget https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz +unzip val2017.zip +unzip train2017.zip +unzip test2017.zip +tar xzf aokvqa_v1p0.tar.gz + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=aokvqa + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load('aokvqa', split='val', data_dir='/tmp/tfds') +""" + +import json +import os +from typing import Any +import numpy as np +import tensorflow_datasets as tfds + +_DESCRIPTION = """ +A-OKVQA addresses the task of VQA with outside knowledge. +It is a follow-up dataset of OKVQA. + +This version of the dataset contains: +- Questions + Answers + Multiple Choice Answers + Rationales from A-OKVQA. +- Images from COCO. 
_CITATION = """
@article{AOKVQA,
  title={A-OKVQA: A Benchmark for Visual Question Answering using World Knowledge},
  author={Dustin Schwenk and Apoorv Khandelwal and Christopher Clark and Kenneth Marino and Roozbeh Mottaghi},
  journal={arXiv},
  year={2022},
}
"""

ANNOTATION_FILES = {
    'train': 'aokvqa_v1p0_train.json',
    'val': 'aokvqa_v1p0_val.json',
    'test': 'aokvqa_v1p0_test.json',
}


# When running locally (recommended), copy files as above and use these:
_AOKVQA_PATH = '/tmp/tfds'


class AOkVqa(tfds.core.GeneratorBasedBuilder):
  """AOKVQA dataset for TFDS."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'ArrayRecord version.'}
  MANUAL_DOWNLOAD_INSTRUCTIONS = """
  In manual_dir/ you should have a directory a_ok_vqa which contains the
  following files and directories:
  From the A-OKVQA dataset:
  - aokvqa_v1p0_train.json
  - aokvqa_v1p0_val.json
  - aokvqa_v1p0_test.json
  It also requires the COCO data files.
  """

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata."""
    features = tfds.features.FeaturesDict({
        'image': tfds.features.Image(shape=(None, None, 3)),
        'image_id': tfds.features.Scalar(dtype=np.int64),
        'direct_answers': tfds.features.Sequence(tfds.features.Text()),
        'direct_answer_is_difficult': tfds.features.Scalar(dtype=np.bool_),
        'multiple_choice_possible_answers':  # List of 4 possible answers.
            tfds.features.Sequence(tfds.features.Text()),
        'multiple_choice_correct_idx':  # Integer from 0-3.
            tfds.features.Scalar(dtype=np.int32),
        'answer_rationales': tfds.features.Sequence(tfds.features.Text()),
        'question': tfds.features.Text(),
        'question_id': tfds.features.Text(),
    })

    return tfds.core.DatasetInfo(
        builder=self,
        features=features,
        description=_DESCRIPTION,
        supervised_keys=None,
        homepage='https://okvqa.allenai.org/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager) -> ...:
    """Call the function which defines the splits."""
    # data_dir = dl_manager.manual_dir
    data_dir = _AOKVQA_PATH
    return {
        'train': self._generate_examples(data_dir, 'train'),
        'val': self._generate_examples(data_dir, 'val'),
        'test': self._generate_examples(data_dir, 'test'),
    }

  def _generate_examples(self, data_dir: str, split: str) -> ...:
    """Yields `(question_id, feature dict)` pairs for one split.

    Args:
      data_dir: Directory holding the annotation files and COCO image folders.
      split: One of 'train', 'val' or 'test'.
    """
    annotations = get_annotations(data_dir, split)

    for question_id, feature_dict in annotations.items():
      image_id = feature_dict['image_id']

      # Add the path of the COCO image this question refers to.
      feature_dict['image'] = self.get_image_path(data_dir, split, image_id)

      # Add dummy features for several features in the test set.
      if split not in ['train', 'val']:
        assert split == 'test', f'Unknown split: {split}'
        feature_dict['multiple_choice_correct_idx'] = -1
        feature_dict['direct_answers'] = []
        feature_dict['answer_rationales'] = []
      yield f'{question_id}', feature_dict

  def get_image_path(self, data_dir: str, split: str, image_id: int) -> str:
    """Returns the path of image `image_id` inside `<data_dir>/<split>2017`."""
    return f'{data_dir}/{split}2017/{image_id:012d}.jpg'


def get_annotations(
    data_dir: str, split: str) -> dict[str, dict[str, Any]]:
  """Return A-OKVQA annotations (questions and answers) as dictionary.

  Args:
    data_dir: Directory containing the files named in `ANNOTATION_FILES`.
    split: One of 'train', 'val' or 'test'.

  Returns:
    A dict keyed by `question_id` (declared as a text feature, so presumably a
    string) mapping to the example features. Answers, rationales and the
    correct choice index are only present for 'train' and 'val'.

  Raises:
    KeyError: If `split` is not a key of `ANNOTATION_FILES`.
    AssertionError: If an entry does not have the expected number of choices,
      direct answers or rationales.
  """
  path = os.path.join(data_dir, ANNOTATION_FILES[split])
  # JSON is UTF-8; do not depend on the locale's default encoding.
  with open(path, encoding='utf-8') as f:
    annotations = json.load(f)

  aokvqa_annotations = {}
  for annotation in annotations:
    # Sanity checks
    assert len(annotation['choices']) == 4

    question_id = annotation['question_id']

    aokvqa_annotations[question_id] = {
        'image_id': annotation['image_id'],
        'direct_answer_is_difficult': annotation['difficult_direct_answer'],
        'multiple_choice_possible_answers': annotation['choices'],
        'question': annotation['question'],
        'question_id': annotation['question_id'],
    }

    # Get answers and rationales for train and val only, not for test.
    if split in ['train', 'val']:
      assert len(annotation['direct_answers']) == 10
      assert len(annotation['rationales']) == 3

      aokvqa_annotations[question_id]['direct_answers'] = annotation[
          'direct_answers']
      aokvqa_annotations[question_id]['answer_rationales'] = annotation[
          'rationales']
      aokvqa_annotations[question_id]['multiple_choice_correct_idx'] = (
          annotation['correct_choice_idx'])

  return aokvqa_annotations
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Implements ChartQA in TFDS structure. + +It's small data, so simple to run locally. First, copy the data to local disk: + + mkdir -p /tmp/data + wget -O /tmp/data/chartqa.zip https://huggingface.co/datasets/ahmed-masry/ChartQA/resolve/main/ChartQA%20Dataset.zip?download=true + unzip /tmp/data/chartqa.zip + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=chartqa + +Example to load: + + import tensorflow_datasets as tfds + dataset_augmented = tfds.load('chartqa/augmented', split='train', data_dir='/tmp/tfds') +""" +import json +import os + +import numpy as np +import tensorflow_datasets as tfds + + +_DESCRIPTION = """ChartQA dataset.""" + +# pylint: disable=line-too-long +_CITATION = """ +@inproceedings{masry-etal-2022-chartqa, + title = "{C}hart{QA}: A Benchmark for Question Answering about Charts with Visual and Logical Reasoning", + author = "Masry, Ahmed and + Do, Xuan Long and + Tan, Jia Qing and + Joty, Shafiq and + Hoque, Enamul", + editor = "Muresan, Smaranda and + Nakov, Preslav and + Villavicencio, Aline", + booktitle = "Findings of the Association for Computational Linguistics: ACL 2022", + month = may, + year = "2022", + address = "Dublin, Ireland", + publisher = "Association for Computational Linguistics", + url =
"https://aclanthology.org/2022.findings-acl.177", + doi = "10.18653/v1/2022.findings-acl.177", + pages = "2263--2279", + abstract = "Charts are very popular for analyzing data. When exploring charts, people often ask a variety of complex reasoning questions that involve several logical and arithmetic operations. They also commonly refer to visual features of a chart in their questions. However, most existing datasets do not focus on such complex reasoning questions as their questions are template-based and answers come from a fixed-vocabulary. In this work, we present a large-scale benchmark covering 9.6K human-written questions as well as 23.1K questions generated from human-written chart summaries. To address the unique challenges in our benchmark involving visual and logical reasoning over charts, we present two transformer-based models that combine visual features and the data table of the chart in a unified way to answer questions. While our models achieve the state-of-the-art results on the previous datasets as well as on our benchmark, the evaluation also reveals several challenges in answering complex reasoning questions.", +} +""" +# pylint: enable=line-too-long + +# When running locally (recommended), copy files as above an use these: +_CHARTQA_PATH = '/tmp/data/ChartQA Dataset/' + + +class ChartQAConfig(tfds.core.BuilderConfig): + """Configuration to build the dataset.""" + pass + + +class ChartQA(tfds.core.GeneratorBasedBuilder): + """DatasetBuilder for ChartQA dataset.""" + + VERSION = tfds.core.Version('1.0.0') + RELEASE_NOTES = {'1.0.0': 'First release.'} + BUILDER_CONFIGS = [ + ChartQAConfig(name='human', description='Human set'), + ChartQAConfig(name='augmented', description='Augmented set'), + ] + + def _info(self): + """Returns the metadata.""" + return tfds.core.DatasetInfo( + builder=self, + description=_DESCRIPTION, + features=tfds.features.FeaturesDict({ + 'question_id': tfds.features.Scalar(np.int32), + 'image/filename': 
tfds.features.Text(), + 'image': tfds.features.Image(encoding_format='png'), + 'question': tfds.features.Text(), + 'answer': tfds.features.Text(), + }), + homepage='https://github.com/vis-nlp/ChartQA', + citation=_CITATION, + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + """Returns SplitGenerators.""" + return {split: self._generate_examples(split, self.builder_config.name) + for split in ('val', 'train', 'test')} + + def _generate_examples(self, split: str, source: str): + """Yields (key, example) tuples from test set.""" + annot_fname = os.path.join(_CHARTQA_PATH, split, f'{split}_{source}.json') + + with open(annot_fname, 'r') as f: + data = json.loads(f.read()) + + for idx, v in enumerate(data): + yield idx, { + 'question_id': idx, + 'image/filename': v['imgname'], + 'image': os.path.join(_CHARTQA_PATH, split, 'png', v['imgname']), + 'question': v['query'], + 'answer': v['label'], + } diff --git a/big_vision/datasets/coco35l/coco35l.py b/big_vision/datasets/coco35l/coco35l.py new file mode 100644 index 0000000000000000000000000000000000000000..0f390222506c8a16a0bcbe5eab495a2e7efa30d0 --- /dev/null +++ b/big_vision/datasets/coco35l/coco35l.py @@ -0,0 +1,154 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Generates COCO-35L in a TFDS-ready structure. 
+ +First, download the captions from https://google.github.io/crossmodal-3600/ and the images from https://cocodataset.org/#download. +The coco Karpathy split is available at http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip: + mkdir -p /tmp/data/coco35l/images + wget https://storage.googleapis.com/crossmodal-3600/coco_mt_train.jsonl.bz2 -P /tmp/data/coco35l + wget https://storage.googleapis.com/crossmodal-3600/coco_mt_dev.jsonl.bz2 -P /tmp/data/coco35l + bzip2 -dk /tmp/data/coco35l/coco_mt_train.jsonl.bz2 /tmp/data/coco35l/coco_mt_dev.jsonl.bz2 + wget http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip -P /tmp/data/coco35l + unzip /tmp/data/coco35l/caption_datasets.zip -d /tmp/data/coco35l/ + wget http://images.cocodataset.org/zips/train2014.zip -P /tmp/data/coco35l/images + wget http://images.cocodataset.org/zips/val2014.zip -P /tmp/data/coco35l/images + unzip /tmp/data/coco35l/images/train2014.zip -d /tmp/data/coco35l/images/ + unzip /tmp/data/coco35l/images/val2014.zip -d /tmp/data/coco35l/images/ + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=coco35l + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load( + 'coco35l', split='dev_en', + data_dir='/tmp/tfds') +""" + +import json +import os.path + +import tensorflow_datasets as tfds + +_DESCRIPTION = """ +COCO image + captions, translated from English to 35 languages (English incl.). +""" + +# pylint: disable=line-too-long +_CITATION = """ +@inproceedings{thapliyal-etal-2022-crossmodal, + title = "Crossmodal-3600: A Massively Multilingual Multimodal Evaluation Dataset", + author = "Thapliyal, Ashish V. 
_CAPTIONS_PATH = '/tmp/data/coco35l'
# NOTE(review): the download instructions in the module docstring place the
# images and dataset_coco.json under /tmp/data/coco35l, not /tmp/data/mscoco.
# Confirm which location is intended.
_IMAGES_PATH = '/tmp/data/mscoco/images'
_COCOCAPS_PATH = '/tmp/data/mscoco/dataset_coco.json'

LANGUAGES = [
    'ar', 'bn', 'cs', 'da', 'de', 'el', 'en', 'es', 'fa', 'fi', 'fil', 'fr',
    'he', 'hi', 'hr', 'hu', 'id', 'it', 'ja', 'ko', 'mi', 'nl', 'no', 'pl',
    'pt', 'ro', 'ru', 'sv', 'sw', 'te', 'th', 'tr', 'uk', 'vi', 'zh',
]


class Coco35l(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for COCO-35L dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'First release.'}

  def _info(self):
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'image/id': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'captions': tfds.features.Sequence(tfds.features.Text()),
            'language': tfds.features.Text(),
        }),
        supervised_keys=None,
        homepage='https://google.github.io/crossmodal-3600/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators, one `train_<lang>` and `dev_<lang>` per language."""
    splits = []
    for lang in LANGUAGES:
      splits.extend([f'train_{lang}', f'dev_{lang}'])
    return {split: self._generate_examples(split) for split in splits}

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples from dataset.

    Args:
      split: Split name of the form `<train|dev>_<language>`.

    Yields:
      `('<image id>_<language>', feature dict)` pairs, one per image.
    """
    subset, language = split.split('_')

    # Map COCO ids to image paths using the file at `_COCOCAPS_PATH`.
    id_to_path = {}
    with open(_COCOCAPS_PATH, 'r', encoding='utf-8') as f:
      for d in json.load(f)['images']:
        id_to_path[d['cocoid']] = os.path.join(
            _IMAGES_PATH, d['filepath'], d['filename']
        )

    annot_fname = os.path.join(_CAPTIONS_PATH, f'coco_mt_{subset}.jsonl')
    data = {}
    # The captions cover many scripts; read them as UTF-8 regardless of the
    # locale's default encoding.
    with open(annot_fname, 'r', encoding='utf-8') as f:
      for line in f:
        j = json.loads(line)
        image_id = f'{j["image_id"].split("_")[0]}_{language}'
        # Every image seen gets an entry, even when no caption matches below.
        captions = data.setdefault(image_id, [])
        if language == 'en':
          # COCO-35L was constructed from English into 35 other languages.
          # To add English in our TFDS, we just select a language (eg. "de") to
          # have each unique example, and add the corresponding source caption.
          if j['trg_lang'] == 'de':
            captions.append(j['caption_tokenized'])
        elif j['trg_lang'] == language:
          captions.append(j['translation_tokenized'])

    for image_id, captions in data.items():
      yield image_id, {
          'image/id': image_id,
          'image': id_to_path[int(image_id.split('_')[0])],
          'captions': captions,
          'language': language,
      }
"""Core data functions, dispatch calls to the requested dataset."""
import importlib


# Note: intentionally not using ABC to avoid forcing implementation of every
# method, since one can imagine train-only datasets for example.
class DataSource:
  """The API that any data source should implement."""

  def get_tfdata(self, ordered, *, process_split=True, allow_cache=True):
    """Creates this data object as a tf.data.Dataset.

    This will be called separately in each process, and it is up to the dataset
    implementation to shard it accordingly if desired!

    Args:
      ordered: if True, the dataset should use deterministic ordering, if False
        it may have undefined ordering. Think of True == val, False == train.
      process_split: if False then every process receives the entire dataset
        (e.g. for evaluators running in a single process).
      allow_cache: whether to allow caching the opened data or not.

    Returns:
      A tf.data.Dataset object.

    Raises:
      RuntimeError: if not implemented by the dataset, but called.
    """
    raise RuntimeError(f"not implemented for {self.__class__.__name__}")

  @property
  def total_examples(self):
    """Returns number of examples in the dataset, regardless of sharding.

    Raises:
      RuntimeError: if not implemented by the dataset, but called.
    """
    raise RuntimeError(f"not implemented for {self.__class__.__name__}")

  def num_examples_per_process(self):
    """Returns a list of the number of examples for each process.

    This is only needed for datasets that should go through make_for_inference.

    Returns:
      Returns a list of the number of examples for each process.

      Ideally, this would always be `[total() / nprocess] * nprocess`, but in
      reality we can almost never perfectly shard a dataset across arbitrary
      number of processes.

      One alternative option that can work in some cases is to not even shard
      the dataset and thus return `[num_examples()] * nprocess`.

    Raises:
      RuntimeError: if not implemented by the dataset, but called.
    """
    raise RuntimeError(f"not implemented for {self.__class__.__name__}")


def get(name, **kw):
  """Instantiates the `DataSource` selected by `name`.

  Args:
    name: Either `bv:<module>`, which loads `big_vision.datasets.<module>`, or
      any other string, which is passed as the first argument to the
      `DataSource` of `big_vision.datasets.tfds`.
    **kw: Keyword arguments forwarded to the `DataSource` constructor.

  Returns:
    The constructed `DataSource` of the selected module.
  """
  if name.startswith("bv:"):
    mod = importlib.import_module(f"big_vision.datasets.{name[3:]}")
    return mod.DataSource(**kw)
  else:
    mod = importlib.import_module("big_vision.datasets.tfds")
    return mod.DataSource(name, **kw)
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
"""Import CountBenchQA dataset (CountBench dataset with added QA annotations).

It's small data, so simple to run locally. First, download all the data:

  mkdir /tmp/data/ ; cd /tmp/data
  wget https://huggingface.co/datasets/nielsr/countbench/resolve/main/data/train-00000-of-00001-cf54c241ba947306.parquet
  wget https://raw.githubusercontent.com/teaching-clip-to-count/teaching-clip-to-count.github.io/main/CountBench.json

Then, update the PATHs below and run conversion locally like so:

  cd big_vision/datasets
  env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=countbenchqa

The dataset contains 540 images so the dataset creation is very quick.

There is a single split called huggingface to denote that the images come from
the Huggingface parquet file.
"""

import io
import json

import numpy as np
import pandas as pd
# `import PIL` alone does not load the `PIL.Image` submodule; import it
# explicitly so `PIL.Image.open` below does not depend on another package
# having imported it first.
import PIL.Image
import tensorflow_datasets as tfds


# Huggingface dataset path; this is missing about 10% of the images.
_COUNTBENCH_PARQUET_PATH = '/tmp/data/train-00000-of-00001-cf54c241ba947306.parquet'
# Public path to the original CountBench JSON file.
_COUNTBENCH_JSON_PATH = '/tmp/data/CountBench.json'
# VQA annotations
_QA_JSON_PATH = 'countbenchqa/data/countbench_paired_questions.json'

_DESCRIPTION = """
CountBench: We introduce a new object counting benchmark called CountBench,
  automatically curated (and manually verified) from the publicly available
  LAION-400M image-text dataset. CountBench contains a total of 540 images
  containing between two and ten instances of a particular object, where their
  corresponding captions reflect this number.

CountBenchQA: Each image is paired with a manually generated question about the
  number of objects in the image to turn CountBench into a VQA task.
"""

_CITATION = """
@article{beyer2024paligemma,
  title={{PaliGemma: A versatile 3B VLM for transfer}},
  author={Lucas Beyer and Andreas Steiner and André Susano Pinto and Alexander Kolesnikov and Xiao Wang and Daniel Salz and Maxim Neumann and Ibrahim Alabdulmohsin and Michael Tschannen and Emanuele Bugliarello and Thomas Unterthiner and Daniel Keysers and Skanda Koppula and Fangyu Liu and Adam Grycner and Alexey Gritsenko and Neil Houlsby and Manoj Kumar and Keran Rong and Julian Eisenschlos and Rishabh Kabra and Matthias Bauer and Matko Bošnjak and Xi Chen and Matthias Minderer and Paul Voigtlaender and Ioana Bica and Ivana Balazevic and Joan Puigcerver and Pinelopi Papalampidi and Olivier Henaff and Xi Xiong and Radu Soricut and Jeremiah Harmsen and Xiaohua Zhai},
  year={2024},
  journal={arXiv preprint arXiv:2407.07726}
}

@article{paiss2023countclip,
  title={{Teaching CLIP to Count to Ten}},
  author={Paiss, Roni and Ephrat, Ariel and Tov, Omer and Zada, Shiran and Mosseri, Inbar and Irani, Michal and Dekel, Tali},
  year={2023},
  journal={arXiv preprint arXiv:2302.12066}
}
"""

_HOMEPAGE = 'https://teaching-clip-to-count.github.io/'


class CountbenchQA(tfds.core.GeneratorBasedBuilder):
  """Create CountbenchQA dataset."""

  VERSION = tfds.core.Version('1.2.0')
  RELEASE_NOTES = {'1.1.0': 'Add `huggingface` split.',
                   '1.2.0': 'Fix image loading for `huggingface` split.'}
  MANUAL_DOWNLOAD_INSTRUCTIONS = """
  There are two parts which should be downloaded:
  * Countbench from Huggingface
  * Questions found in `data/countbench_paired_questions.json`
  """

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata."""
    features = tfds.features.FeaturesDict({
        'image': tfds.features.Image(shape=(None, None, 3)),
        'image_id': tfds.features.Scalar(dtype=np.int32),
        'question': tfds.features.Text(),
        'text': tfds.features.Text(),
        'image_url': tfds.features.Text(),
        'number': tfds.features.Scalar(dtype=np.int32),
    })

    return tfds.core.DatasetInfo(
        builder=self,
        features=features,
        description=_DESCRIPTION,
        supervised_keys=None,
        homepage=_HOMEPAGE,
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Call the function which defines the splits.

    Args:
      dl_manager: unused, all inputs are read from the module-level paths.

    Returns:
      A dict mapping the single split name `huggingface` to its example
      generator.
    """
    del dl_manager
    return {
        'huggingface': self._generate_examples(split='huggingface'),
    }

  def _generate_examples_hf(self):
    """Generate examples from Huggingface parquet file.

    Note that the parquet file provided on Huggingface is missing about 10%
    of the images as can be verified by running
    ```
    import pyarrow.parquet as pq
    with open(_COUNTBENCH_PARQUET_PATH, 'rb') as f:
      x = pq.read_table(f)
    sum([x['image'][i].is_valid for i in range(len(x['image']))])  # result: 491
    ```

    Yields:
      An index and a dictionary with features.
    """
    with open(_COUNTBENCH_PARQUET_PATH, 'rb') as f:
      df = pd.read_parquet(f)

    with open(_QA_JSON_PATH, 'r') as fq:
      df_question = pd.read_json(fq)

    # NOTE(review): this pairs questions with parquet rows by index, so it
    # presumably relies on both files listing the examples in the same order;
    # verify when updating either file.
    df['question'] = df_question

    for idx, row in df.iterrows():
      # Some entries have no image.
      if row['image'] is None:
        continue
      image = np.array(PIL.Image.open(io.BytesIO(row['image']['bytes'])))
      if len(image.shape) != 3:
        continue  # Filter out one bad image.
      countbenchqa_dict = {
          'image': image,
          'image_id': idx,
          'question': row['question'],
          'text': row['text'],
          'image_url': row['image_url'],
          'number': row['number'],
      }
      yield idx, countbenchqa_dict

  def _generate_examples(self, split: str):
    """Yields the examples of the requested split.

    Args:
      split: name of the split; only `huggingface` is supported.

    Yields:
      An index and a dictionary with features.

    Raises:
      ValueError: if `split` is not a known split name. Since this is a
        generator, the error surfaces when iteration starts.
    """
    if split == 'huggingface':
      yield from self._generate_examples_hf()
    else:
      raise ValueError(f'Unknown split: {split}')
{"question": "How many chairs are there in the image?"}, {"question": "How many buttons are there in the image?"}, {"question": "How many rackets are there in the image?"}, {"question": "How many pots are there in the image?"}, {"question": "How many stars are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many kids are there in the image?"}, {"question": "How many moth silhouettes are there in the image?"}, {"question": "How many pumpkin candles are there in the image?"}, {"question": "How many essential oils are there in the image?"}, {"question": "How many stencils are there in the image?"}, {"question": "How many text boxes are there in the image?"}, {"question": "How many basketball players are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many forks are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many bags are there in the image?"}, {"question": "How many couples are there in the image?"}, {"question": "How many weights are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many fish are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many clocks are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many chests are there in the image?"}, {"question": "How many stars are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many globe icons are there in the image?"}, {"question": "How many posters are there in the image?"}, {"question": "How many socks are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many snails are there in the image?"}, {"question": "How many 
crochet potholders are there in the image?"}, {"question": "How many christmas cards are there in the image?"}, {"question": "How many double beds are there in the image?"}, {"question": "How many baseball players are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many baseball players are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many trees are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many tree trunk cuts are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many individual earrings are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many tomatoes are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many ice creams are there in the image?"}, {"question": "How many plates are there in the image?"}, {"question": "How many sumo wrestlers are there in the image?"}, {"question": "How many compositions are there in the image?"}, {"question": "How many PVC vinyls are there in the image?"}, {"question": "How many photos of fruit are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many mirrors are there in the image?"}, {"question": "How many groomsmen are there in the image?"}, {"question": "How many posters are there in the image?"}, {"question": "How many pumpkins are there 
in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many kittens are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many boots are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many moais are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many cats are there in the image?"}, {"question": "How many wallpaper variants are there in the image?"}, {"question": "How many nail polishes are there in the image?"}, {"question": "How many bumble bees are there in the image?"}, {"question": "How many tickets are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many wine bottles are there in the image?"}, {"question": "How many silhouettes of couples are there in the image?"}, {"question": "How many owls are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many flower pots are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many placemats are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many blinds are there in the image?"}, {"question": "How many floral patterns are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many bicycles are there in the image?"}, {"question": "How many dwarfs are there in the image?"}, {"question": "How many stickers are 
there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many stamps are there in the image?"}, {"question": "How many pumpkins are there in the image?"}, {"question": "How many game cartridges are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many gears are there in the image?"}, {"question": "How many flowers are there in the image?"}, {"question": "How many crip packages are there in the image?"}, {"question": "How many bridesmaids are there in the image?"}, {"question": "How many apples are there in the image?"}, {"question": "How many bowls are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many photographs are there in the image?"}, {"question": "How many warning signs are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many cyclists are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many giraffes are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many male nurses are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many couples are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many apples are there in the image?"}, {"question": "How many smartphones are there in the image?"}, {"question": "How many roses are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many greeting cards are there in the image?"}, {"question": "How many guitars are there in the image?"}, {"question": "How many ironman suits are 
there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many CDs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many pot holders are there in the image?"}, {"question": "How many stamps are there in the image?"}, {"question": "How many bookmarks are there in the image?"}, {"question": "How many portraits are there in the image?"}, {"question": "How many girls are there in the image?"}, {"question": "How many labels are there in the image?"}, {"question": "How many mandalas are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many peacocks are there in the image?"}, {"question": "How many roses are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many canvases are there in the image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many bell pepper halves are there in the image?"}, {"question": "How many pigs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many newspapers are there in the image?"}, {"question": "How many leggings are there in the image?"}, {"question": "How many medals are there in the image?"}, {"question": "How many patterns are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many doors are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many backgrounds are there in the image?"}, {"question": "How many images of dogs are there 
in the image?"}, {"question": "How many broadheads are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many petals does each flower have in this image?"}, {"question": "How many crayons are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many caps are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many animals are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many cats are there in the image?"}, {"question": "How many sconces are there in the image?"}, {"question": "How many spoons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many planes are there in the image?"}, {"question": "How many cats are there in the image?"}, {"question": "How many sketches are there in the image?"}, {"question": "How many trees are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many napkins are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many kids are there in the image?"}, {"question": "How many hearts are there in the image?"}, {"question": "How many apples are there in the image?"}, {"question": "How many post-its are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many books are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many sofas are there in the image?"}, {"question": "How many dice are 
there in the image?"}, {"question": "How many girls are there in the image?"}, {"question": "How many sketches are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many moons are there in the image?"}, {"question": "How many labels are there in the image?"}, {"question": "How many cylinders are there in the image?"}, {"question": "How many silhouettes of couples are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many people on stage are there in the image?"}, {"question": "How many lambs are there in the image?"}, {"question": "How many violins are there in the image?"}, {"question": "How many armchairs are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many moais are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many sinks are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many buildings are there in the image?"}, {"question": "How many flyers are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many pizza slices are there in the image?"}, {"question": "How many stamps are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many gifts are there in the image?"}, {"question": "How many bowls are there in the image?"}, {"question": "How many onesies are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many posters are there in the 
image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many portraits are there in the image?"}, {"question": "How many poinsettias are there in the image?"}, {"question": "How many chicken thighs are there in the image?"}, {"question": "How many glass windows are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many baskets are there in the image?"}, {"question": "How many tulips are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many stuffed animals are there in the image?"}, {"question": "How many keychains are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many adults are there in the image?"}, {"question": "How many tulips are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many samosas are there in the image?"}, {"question": "How many strawberries are there in the image?"}, {"question": "How many cocktails are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many wineglasses are there in the image?"}, {"question": "How many goblets are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many flowers are there in the image?"}, {"question": "How many zebras are there in the image?"}, {"question": "How many paint brushes are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many sunglasses are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many bottle caps are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many people are 
there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many books are there in the image?"}, {"question": "How many animals are there in the image?"}, {"question": "How many eyeshadows are there in the image?"}, {"question": "How many keychains are there in the image?"}, {"question": "How many pairs of earrings are there in the image?"}, {"question": "How many canisters are there in the image?"}, {"question": "How many bags are there in the image?"}, {"question": "How many baking trays are there in the image?"}, {"question": "How many diamonds are there in the image?"}, {"question": "How many portraits are there in the image?"}, {"question": "How many framed images are there in the image?"}, {"question": "How many flags are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many framed pictures are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many croissants are there in the image?"}, {"question": "How many Manikins are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many labels are there in the image?"}, {"question": "How many people in the foreground are there in the image?"}, {"question": "How many armchairs are there in the image?"}, {"question": "How many cups are there in the image?"}, {"question": "How many helicopters are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many buttons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many Cookies are there in the image?"}, {"question": "How many sculptures are there in the image?"}, {"question": "How many school uniforms are there in the image?"}, 
{"question": "How many sculptures are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many packages are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many stories does this cottage have?"}, {"question": "How many gift cards are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many beers are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many planes are there in the image?"}, {"question": "How many cactus pots are there in the image?"}, {"question": "How many smartphones are there in the image?"}, {"question": "How many picture frames are there in the image?"}, {"question": "How many elephants are there in the image?"}, {"question": "How many guitars are there in the image?"}, {"question": "How many samurai are there in the image?"}, {"question": "How many ghosts are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many vases are there in the image?"}, {"question": "How many sets of headphones are there in the image?"}, {"question": "How many pandas are there in the image?"}, {"question": "How many books are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many rings are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many dragon balls are there in the image?"}, {"question": "How many tanks are there in the image?"}, {"question": "How many students are there in the image?"}, {"question": "How many cups are there in the image?"}, {"question": "How many cubs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many cushions are there in the image?"}, {"question": "How many shoes are there in the image?"}, {"question": 
"How many beers are there in the image?"}, {"question": "How many wine glasses are there in the image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many boots are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many butterflies are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many butterflies are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many hexagons are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many football players are there in the image?"}, {"question": "How many gifts are there in the image?"}, {"question": "How many light bulbs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many pots are there in the image?"}, {"question": "How many pumpkins are there in the image?"}, {"question": "How many owls are there in the image?"}, {"question": "How many doctors are there in the image?"}, {"question": "How many pigs are there in the image?"}, {"question": "How many pillars are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many pumpkins are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many roses are there in 
the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many illustrations are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many butterflies are there in the image?"}, {"question": "How many spoons are there in the image?"}, {"question": "How many potato spreads are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many wall arts are there in the image?"}, {"question": "How many covers are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many soldiers are there in the image?"}, {"question": "How many posters are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many patterns are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many pennies are there in the image?"}, {"question": "How many windows are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many geese are there in the image?"}, {"question": "How many tickets are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many cases are there in the image?"}, {"question": "How many people are in the foreground of this image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many cups are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many kittens are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many scones are there in the image?"}, {"question": "How many schoolgirls are there in the image?"}, 
{"question": "How many padlocks are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many windows are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many groomsmen are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many trees are there in the image?"}, {"question": "How many turtles are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many firescreens are there in the image?"}, {"question": "How many bowls are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many archaic mirrors are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many paint brushes are there in the image?"}, {"question": "How many forks are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many owls are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many fried eggs are there in the image?"}, {"question": "How many guitars are there in the image?"}, {"question": "How many signs are there in the image?"}, {"question": "How many watches are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many 
kittens are there in the image?"}, {"question": "How many kids are there in the image?"}, {"question": "How many kittens are there in the image?"}, {"question": "How many toys are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many stars are there in the image?"}, {"question": "How many quarters are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many diamonds are there in the image?"}, {"question": "How many moais are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many flags are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many contestants are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many globes are there in the image?"}, {"question": "How many animals are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many buttons are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many piles of candy are there in the image?"}, {"question": "How many pictures are there in the image?"}, {"question": "How many plates are there in the image?"}, {"question": "How many calendars are there in the image?"}, {"question": "How many oranges are there in the image?"}, {"question": "How many puppies are there in the image?"}, {"question": "How many buffalos are there in 
the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many christmas balls are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many armchairs are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many soldiers are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many colored tiles are there in the image?"}, {"question": "How many trapezoids are there in the image?"}, {"question": "How many pastries are there in the image?"}, {"question": "How many plants are there in the image?"}, {"question": "How many place mats are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many mazes are there in the image?"}, {"question": "How many tea bags are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many starfish are there in the image?"}, {"question": "How many mugs are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many balls are there in the image?"}, {"question": "How many paper bags are there in the image?"}, {"question": "How many garage doors are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many couples are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many women are there in the image?"}, 
{"question": "How many icons are there in the image?"}, {"question": "How many wooden spoons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many beds are there in the image?"}, {"question": "How many chairs are there in the image?"}] \ No newline at end of file diff --git a/big_vision/datasets/docvqa/docvqa.py b/big_vision/datasets/docvqa/docvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..f1a496ca24c4b2c5714c10d155585cb6194bd783 --- /dev/null +++ b/big_vision/datasets/docvqa/docvqa.py @@ -0,0 +1,110 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Implements DocVQA in TFDS structure. + +It's small data, so simple to run locally. First, copy the data to local disk. +An account will be needed in https://rrc.cvc.uab.es/?ch=17&com=downloads and +from there the task annotations and images can be fetched separatedly. 
+ + mkdir -p /tmp/data/docvqa + + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=docvqa + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load('docvqa', split='val', data_dir='/tmp/tfds') +""" +import json +import os + +import numpy as np +import tensorflow_datasets as tfds + + +_DESCRIPTION = """DocVQA dataset.""" + +# pylint: disable=line-too-long +_CITATION = """ +@article{DBLP:journals/corr/abs-2007-00398, + author = {Minesh Mathew and + Dimosthenis Karatzas and + R. Manmatha and + C. V. Jawahar}, + title = {DocVQA: {A} Dataset for {VQA} on Document Images}, + journal = {CoRR}, + volume = {abs/2007.00398}, + year = {2020}, + url = {https://arxiv.org/abs/2007.00398}, + eprinttype = {arXiv}, + eprint = {2007.00398}, + timestamp = {Mon, 06 Jul 2020 15:26:01 +0200}, + biburl = {https://dblp.org/rec/journals/corr/abs-2007-00398.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +""" +# pylint: enable=line-too-long + +# When running locally (recommended), copy files as above an use these: +_DOCVQA_PATH = '/tmp/data/docvqa/' + + +class DocVQA(tfds.core.GeneratorBasedBuilder): + """DatasetBuilder for DocVQA dataset.""" + + VERSION = tfds.core.Version('1.0.0') + RELEASE_NOTES = {'1.0.0': 'First release.'} + + def _info(self): + """Returns the metadata.""" + return tfds.core.DatasetInfo( + builder=self, + description=_DESCRIPTION, + features=tfds.features.FeaturesDict({ + 'question_id': tfds.features.Scalar(np.int32), + 'image/filename': tfds.features.Text(), + 'image': tfds.features.Image(encoding_format='png'), + 'question': tfds.features.Text(), + 'answers': tfds.features.Sequence(tfds.features.Text()), + }), + supervised_keys=None, + homepage='https://www.docvqa.org/', + citation=_CITATION, + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + """Returns 
SplitGenerators.""" + return {split: self._generate_examples(split) + for split in ('val', 'train', 'test')} + + def _generate_examples(self, split: str): + """Yields (key, example) tuples from split.""" + suffix = '' if split == 'test' else '_withQT' + with open(os.path.join(_DOCVQA_PATH, f'{split}_v1.0{suffix}.json')) as f: + data = json.load(f) + for v in data['data']: + question_id = v['questionId'] + yield question_id, { + 'question_id': question_id, + 'image/filename': v['image'], + 'image': os.path.join(_DOCVQA_PATH, split, v['image']), + 'question': v['question'], + 'answers': v.get('answers', []), + } diff --git a/big_vision/datasets/gqa/gqa.py b/big_vision/datasets/gqa/gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..ea137cbfa41923fca712ee40977f75b855a7966d --- /dev/null +++ b/big_vision/datasets/gqa/gqa.py @@ -0,0 +1,167 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Generates GQA in a TFDS-ready structure, using Beam. + +Instructions below are to generate the dataset with a *local* Beam pipeline. +It's advisable to run the Beam job on Google Cloud Dataflow, see + https://www.tensorflow.org/datasets/beam_datasets. +for more details, which would significantly speed up generation. 
This would +involve uploading the locally downloaded data to a GCS bucket, and then +adding in the Beam pipeline options and your GCP/GCS bucket details +to the `tfds build` command below (as detailed in the link). + +First, copy the data to local disk: + + mkdir -p /tmp/data/gqa + wget -O /tmp/data/gqa/question1.2.zip https://downloads.cs.stanford.edu/nlp/data/gqa/questions1.2.zip?download=true + unzip /tmp/data/gqa/question1.2.zip + mv /tmp/data/gqa/question1.2/* /tmp/data/gqa/ + wget -O /tmp/data/gqa/images.zip https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip?download=true + unzip /tmp/data/gqa/images.zip + +Then, run conversion (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=gqa + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load('gqa', split='testdev_balanced', data_dir='/tmp/tfds') + +Some statistics: + train_all: 14305356 examples + train_balanced: 943000 examples + val_all: 2011853 examples + val_balanced: 132062 examples + testdev_all: 172174 examples + testdev_balanced: 12578 examples +""" +import glob +import json +import os + +import numpy as np +import tensorflow_datasets as tfds + + +_DESCRIPTION = """GQA: Visual Reasoning in the Real World.""" + +# pylint: disable=line-too-long +_CITATION = """ +@article{DBLP:journals/corr/abs-2306-14610, + author = {Drew Hudson and + Christopher Manning}, + title = {GQA: A New Dataset for Real-World Visual Reasoning and Compositional Question Answering}, + journal = {CVPR}, + volume = {abs/1902.09506}, + year = {2019}, + url = {https://doi.org/10.48550/arXiv.1902.09506}, + doi = {10.48550/arXiv.1902.09506}, + eprinttype = {arXiv}, + eprint = {1902.09506}, + timestamp = {Tue, 25 Jun 2019 00:00:00 +0100}, + biburl = {https://dblp.org/rec/journals/corr/abs-1902-09506}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +""" +# pylint: enable=line-too-long + + 
+_DATA_PATH = '/tmp/data/gqa/' + + +class GQA(tfds.core.GeneratorBasedBuilder): + """DatasetBuilder for GQA dataset.""" + + VERSION = tfds.core.Version('1.0.0') + RELEASE_NOTES = {'1.0.0': 'First release.'} + + def _info(self): + """Returns the metadata.""" + + return tfds.core.DatasetInfo( + builder=self, + description=_DESCRIPTION, + features=tfds.features.FeaturesDict({ + 'example_id': tfds.features.Scalar(np.int64), + 'image/id': tfds.features.Text(), + 'image': tfds.features.Image(encoding_format='jpeg'), + 'question': tfds.features.Text(), + 'answer': tfds.features.Text(), + 'full_answer': tfds.features.Text(), + 'is_balanced': tfds.features.Scalar(np.bool_), + }), + homepage='https://cs.stanford.edu/people/dorarad/gqa/', + citation=_CITATION, + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + """Returns SplitGenerators.""" + splits = [ + # 'debug', + 'train_all', + 'train_balanced', + 'testdev_all', + 'testdev_balanced', + 'val_all', + 'val_balanced', + 'challenge_all', + 'challenge_balanced', + ] + return {split: self._generate_examples(split) for split in splits} + + def _generate_examples(self, split: str): + """Yields (key, example) tuples from dataset.""" + if split == 'train_all': + train_json_dir = os.path.join(_DATA_PATH, 'train_all_questions', '*.json') + json_files = glob.glob(train_json_dir) + else: + json_files = [os.path.join(_DATA_PATH, f'{split}_questions.json')] + + def _prepare_data(json_path): + with open(os.path.join(json_path)) as f: + annotations = json.load(f) + return [(k, v) for k, v in annotations.items()] + + def _process_example(entry): + question_id, question_data = entry + image_id = question_data['imageId'] + image_path = os.path.join(_DATA_PATH, 'images', f'{image_id}.jpg') + answer = question_data['answer'] if 'answer' in question_data else '' + if 'fullAnswer' in question_data: + full_answer = question_data['fullAnswer'] + else: + full_answer = '' + + example = { + 'example_id': question_id, + 
'image/id': image_id, + 'image': image_path, + 'question': question_data['question'], + 'answer': answer, + 'full_answer': full_answer, + 'is_balanced': question_data['isBalanced'], + } + return question_id, example + + beam = tfds.core.lazy_imports.apache_beam + return ( + beam.Create(json_files) + | beam.FlatMap(_prepare_data) + | beam.Reshuffle() + | beam.Map(_process_example) + ) diff --git a/big_vision/datasets/imagenet/class_names.py b/big_vision/datasets/imagenet/class_names.py new file mode 100644 index 0000000000000000000000000000000000000000..490594ebc2a23d508d21102cb7d0596c88396fb3 --- /dev/null +++ b/big_vision/datasets/imagenet/class_names.py @@ -0,0 +1,270 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Imagenet class names.""" + +# Copied from +# https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +CLIP_IMAGENET_CLASS_NAMES = [ + 'tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead shark', + 'electric ray', 'stingray', 'rooster', 'hen', 'ostrich', 'brambling', + 'goldfinch', 'house finch', 'junco', 'indigo bunting', 'American robin', + 'bulbul', 'jay', 'magpie', 'chickadee', 'American dipper', + 'kite (bird of prey)', 'bald eagle', 'vulture', 'great grey owl', + 'fire salamander', 'smooth newt', 'newt', 'spotted salamander', 'axolotl', + 'American bullfrog', 'tree frog', 'tailed frog', 'loggerhead sea turtle', + 'leatherback sea turtle', 'mud turtle', 'terrapin', 'box turtle', + 'banded gecko', 'green iguana', 'Carolina anole', + 'desert grassland whiptail lizard', 'agama', 'frilled-necked lizard', + 'alligator lizard', 'Gila monster', 'European green lizard', 'chameleon', + 'Komodo dragon', 'Nile crocodile', 'American alligator', 'triceratops', + 'worm snake', 'ring-necked snake', 'eastern hog-nosed snake', + 'smooth green snake', 'kingsnake', 'garter snake', 'water snake', + 'vine snake', 'night snake', 'boa constrictor', 'African rock python', + 'Indian cobra', 'green mamba', 'sea snake', 'Saharan horned viper', + 'eastern diamondback rattlesnake', 'sidewinder rattlesnake', 'trilobite', + 'harvestman', 'scorpion', 'yellow garden spider', 'barn spider', + 'European garden spider', 'southern black widow', 'tarantula', + 'wolf spider', 'tick', 'centipede', 'black grouse', 'ptarmigan', + 'ruffed grouse', 'prairie grouse', 'peafowl', 'quail', 'partridge', + 'african grey parrot', 'macaw', 'sulphur-crested cockatoo', 'lorikeet', + 'coucal', 'bee eater', 'hornbill', 'hummingbird', 'jacamar', 'toucan', + 'duck', 'red-breasted merganser', 'goose', 'black swan', 'tusker', + 'echidna', 'platypus', 'wallaby', 'koala', 'wombat', 'jellyfish', + 'sea anemone', 'brain coral', 'flatworm', 'nematode', 'conch', 'snail', 
+ 'slug', 'sea slug', 'chiton', 'chambered nautilus', 'Dungeness crab', + 'rock crab', 'fiddler crab', 'red king crab', 'American lobster', + 'spiny lobster', 'crayfish', 'hermit crab', 'isopod', 'white stork', + 'black stork', 'spoonbill', 'flamingo', 'little blue heron', 'great egret', + 'bittern bird', 'crane bird', 'limpkin', 'common gallinule', + 'American coot', 'bustard', 'ruddy turnstone', 'dunlin', 'common redshank', + 'dowitcher', 'oystercatcher', 'pelican', 'king penguin', 'albatross', + 'grey whale', 'killer whale', 'dugong', 'sea lion', 'Chihuahua', + 'Japanese Chin', 'Maltese', 'Pekingese', 'Shih Tzu', 'King Charles Spaniel', + 'Papillon', 'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', + 'Basset Hound', 'Beagle', 'Bloodhound', 'Bluetick Coonhound', + 'Black and Tan Coonhound', 'Treeing Walker Coonhound', 'English foxhound', + 'Redbone Coonhound', 'borzoi', 'Irish Wolfhound', 'Italian Greyhound', + 'Whippet', 'Ibizan Hound', 'Norwegian Elkhound', 'Otterhound', 'Saluki', + 'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier', + 'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', + 'Kerry Blue Terrier', 'Irish Terrier', 'Norfolk Terrier', 'Norwich Terrier', + 'Yorkshire Terrier', 'Wire Fox Terrier', 'Lakeland Terrier', + 'Sealyham Terrier', 'Airedale Terrier', 'Cairn Terrier', + 'Australian Terrier', 'Dandie Dinmont Terrier', 'Boston Terrier', + 'Miniature Schnauzer', 'Giant Schnauzer', 'Standard Schnauzer', + 'Scottish Terrier', 'Tibetan Terrier', 'Australian Silky Terrier', + 'Soft-coated Wheaten Terrier', 'West Highland White Terrier', 'Lhasa Apso', + 'Flat-Coated Retriever', 'Curly-coated Retriever', 'Golden Retriever', + 'Labrador Retriever', 'Chesapeake Bay Retriever', + 'German Shorthaired Pointer', 'Vizsla', 'English Setter', 'Irish Setter', + 'Gordon Setter', 'Brittany dog', 'Clumber Spaniel', + 'English Springer Spaniel', 'Welsh Springer Spaniel', 'Cocker Spaniel', + 'Sussex Spaniel', 'Irish Water 
Spaniel', 'Kuvasz', 'Schipperke', + 'Groenendael dog', 'Malinois', 'Briard', 'Australian Kelpie', 'Komondor', + 'Old English Sheepdog', 'Shetland Sheepdog', 'collie', 'Border Collie', + 'Bouvier des Flandres dog', 'Rottweiler', 'German Shepherd Dog', + 'Dobermann', 'Miniature Pinscher', 'Greater Swiss Mountain Dog', + 'Bernese Mountain Dog', 'Appenzeller Sennenhund', 'Entlebucher Sennenhund', + 'Boxer', 'Bullmastiff', 'Tibetan Mastiff', 'French Bulldog', 'Great Dane', + 'St. Bernard', 'husky', 'Alaskan Malamute', 'Siberian Husky', 'Dalmatian', + 'Affenpinscher', 'Basenji', 'pug', 'Leonberger', 'Newfoundland dog', + 'Great Pyrenees dog', 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', + 'brussels griffon', 'Pembroke Welsh Corgi', 'Cardigan Welsh Corgi', + 'Toy Poodle', 'Miniature Poodle', 'Standard Poodle', + 'Mexican hairless dog (xoloitzcuintli)', 'grey wolf', 'Alaskan tundra wolf', + 'red wolf or maned wolf', 'coyote', 'dingo', 'dhole', 'African wild dog', + 'hyena', 'red fox', 'kit fox', 'Arctic fox', 'grey fox', 'tabby cat', + 'tiger cat', 'Persian cat', 'Siamese cat', 'Egyptian Mau', 'cougar', 'lynx', + 'leopard', 'snow leopard', 'jaguar', 'lion', 'tiger', 'cheetah', + 'brown bear', 'American black bear', 'polar bear', 'sloth bear', 'mongoose', + 'meerkat', 'tiger beetle', 'ladybug', 'ground beetle', 'longhorn beetle', + 'leaf beetle', 'dung beetle', 'rhinoceros beetle', 'weevil', 'fly', 'bee', + 'ant', 'grasshopper', 'cricket insect', 'stick insect', 'cockroach', + 'praying mantis', 'cicada', 'leafhopper', 'lacewing', 'dragonfly', + 'damselfly', 'red admiral butterfly', 'ringlet butterfly', + 'monarch butterfly', 'small white butterfly', 'sulphur butterfly', + 'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber', + 'cottontail rabbit', 'hare', 'Angora rabbit', 'hamster', 'porcupine', + 'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'common sorrel horse', + 'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus', 'ox', + 'water buffalo', 
'bison', 'ram (adult male sheep)', 'bighorn sheep', + 'Alpine ibex', 'hartebeest', 'impala (antelope)', 'gazelle', + 'arabian camel', 'llama', 'weasel', 'mink', 'European polecat', + 'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', + 'three-toed sloth', 'orangutan', 'gorilla', 'chimpanzee', 'gibbon', + 'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur', + 'black-and-white colobus', 'proboscis monkey', 'marmoset', + 'white-headed capuchin', 'howler monkey', 'titi monkey', + 'Geoffroy\'s spider monkey', 'common squirrel monkey', 'ring-tailed lemur', + 'indri', 'Asian elephant', 'African bush elephant', 'red panda', + 'giant panda', 'snoek fish', 'eel', 'silver salmon', 'rock beauty fish', + 'clownfish', 'sturgeon', 'gar fish', 'lionfish', 'pufferfish', 'abacus', + 'abaya', 'academic gown', 'accordion', 'acoustic guitar', + 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', + 'amphibious vehicle', 'analog clock', 'apiary', 'apron', 'trash can', + 'assault rifle', 'backpack', 'bakery', 'balance beam', 'balloon', + 'ballpoint pen', 'Band-Aid', 'banjo', 'baluster / handrail', 'barbell', + 'barber chair', 'barbershop', 'barn', 'barometer', 'barrel', 'wheelbarrow', + 'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap', + 'bath towel', 'bathtub', 'station wagon', 'lighthouse', 'beaker', + 'military hat (bearskin or shako)', 'beer bottle', 'beer glass', + 'bell tower', 'baby bib', 'tandem bicycle', 'bikini', 'ring binder', + 'binoculars', 'birdhouse', 'boathouse', 'bobsleigh', 'bolo tie', + 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'hunting bow', + 'bow tie', 'brass memorial plaque', 'bra', 'breakwater', 'breastplate', + 'broom', 'bucket', 'buckle', 'bulletproof vest', 'high-speed train', + 'butcher shop', 'taxicab', 'cauldron', 'candle', 'cannon', 'canoe', + 'can opener', 'cardigan', 'car mirror', 'carousel', 'tool kit', + 'cardboard box / carton', 'car wheel', 'automated teller machine', + 'cassette', 
'cassette player', 'castle', 'catamaran', 'CD player', 'cello', + 'mobile phone', 'chain', 'chain-link fence', 'chain mail', 'chainsaw', + 'storage chest', 'chiffonier', 'bell or wind chime', 'china cabinet', + 'Christmas stocking', 'church', 'movie theater', 'cleaver', + 'cliff dwelling', 'cloak', 'clogs', 'cocktail shaker', 'coffee mug', + 'coffeemaker', 'spiral or coil', 'combination lock', 'computer keyboard', + 'candy store', 'container ship', 'convertible', 'corkscrew', 'cornet', + 'cowboy boot', 'cowboy hat', 'cradle', 'construction crane', 'crash helmet', + 'crate', 'infant bed', 'Crock Pot', 'croquet ball', 'crutch', 'cuirass', + 'dam', 'desk', 'desktop computer', 'rotary dial telephone', 'diaper', + 'digital clock', 'digital watch', 'dining table', 'dishcloth', 'dishwasher', + 'disc brake', 'dock', 'dog sled', 'dome', 'doormat', 'drilling rig', 'drum', + 'drumstick', 'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', + 'electric locomotive', 'entertainment center', 'envelope', + 'espresso machine', 'face powder', 'feather boa', 'filing cabinet', + 'fireboat', 'fire truck', 'fire screen', 'flagpole', 'flute', + 'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen', + 'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat', + 'garbage truck', 'gas mask or respirator', 'gas pump', 'goblet', 'go-kart', + 'golf ball', 'golf cart', 'gondola', 'gong', 'gown', 'grand piano', + 'greenhouse', 'radiator grille', 'grocery store', 'guillotine', 'hair clip', + 'hair spray', 'half-track', 'hammer', 'hamper', 'hair dryer', + 'hand-held computer', 'handkerchief', 'hard disk drive', 'harmonica', + 'harp', 'combine harvester', 'hatchet', 'holster', 'home theater', + 'honeycomb', 'hook', 'hoop skirt', 'gymnastic horizontal bar', + 'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron', + 'carved pumpkin', 'jeans', 'jeep', 'T-shirt', 'jigsaw puzzle', 'rickshaw', + 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 
'ladle', 'lampshade', + 'laptop computer', 'lawn mower', 'lens cap', 'letter opener', 'library', + 'lifeboat', 'lighter', 'limousine', 'ocean liner', 'lipstick', + 'slip-on shoe', 'lotion', 'music speaker', 'loupe magnifying glass', + 'sawmill', 'magnetic compass', 'messenger bag', 'mailbox', 'tights', + 'one-piece bathing suit', 'manhole cover', 'maraca', 'marimba', 'mask', + 'matchstick', 'maypole', 'maze', 'measuring cup', 'medicine cabinet', + 'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can', + 'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl', + 'mobile home', 'ford model t', 'modem', 'monastery', 'monitor', 'moped', + 'mortar and pestle', 'graduation cap', 'mosque', 'mosquito net', 'vespa', + 'mountain bike', 'tent', 'computer mouse', 'mousetrap', 'moving van', + 'muzzle', 'metal nail', 'neck brace', 'necklace', 'baby pacifier', + 'notebook computer', 'obelisk', 'oboe', 'ocarina', 'odometer', 'oil filter', + 'pipe organ', 'oscilloscope', 'overskirt', 'bullock cart', 'oxygen mask', + 'product packet / packaging', 'paddle', 'paddle wheel', 'padlock', + 'paintbrush', 'pajamas', 'palace', 'pan flute', 'paper towel', 'parachute', + 'parallel bars', 'park bench', 'parking meter', 'railroad car', 'patio', + 'payphone', 'pedestal', 'pencil case', 'pencil sharpener', 'perfume', + 'Petri dish', 'photocopier', 'plectrum', 'Pickelhaube', 'picket fence', + 'pickup truck', 'pier', 'piggy bank', 'pill bottle', 'pillow', + 'ping-pong ball', 'pinwheel', 'pirate ship', 'drink pitcher', 'block plane', + 'planetarium', 'plastic bag', 'plate rack', 'farm plow', 'plunger', + 'Polaroid camera', 'pole', 'police van', 'poncho', 'pool table', + 'soda bottle', 'plant pot', 'potter\'s wheel', 'power drill', 'prayer rug', + 'printer', 'prison', 'missile', 'projector', 'hockey puck', 'punching bag', + 'purse', 'quill', 'quilt', 'race car', 'racket', 'radiator', 'radio', + 'radio telescope', 'rain barrel', 'recreational vehicle', + 'fishing 
casting reel', 'reflex camera', 'refrigerator', 'remote control', + 'restaurant', 'revolver', 'rifle', 'rocking chair', 'rotisserie', 'eraser', + 'rugby ball', 'ruler measuring stick', 'sneaker', 'safe', 'safety pin', + 'salt shaker', 'sandal', 'sarong', 'saxophone', 'scabbard', + 'weighing scale', 'school bus', 'schooner', 'scoreboard', 'CRT monitor', + 'screw', 'screwdriver', 'seat belt', 'sewing machine', 'shield', + 'shoe store', 'shoji screen / room divider', 'shopping basket', + 'shopping cart', 'shovel', 'shower cap', 'shower curtain', 'ski', + 'balaclava ski mask', 'sleeping bag', 'slide rule', 'sliding door', + 'slot machine', 'snorkel', 'snowmobile', 'snowplow', 'soap dispenser', + 'soccer ball', 'sock', 'solar thermal collector', 'sombrero', 'soup bowl', + 'keyboard space bar', 'space heater', 'space shuttle', 'spatula', + 'motorboat', 'spider web', 'spindle', 'sports car', 'spotlight', 'stage', + 'steam locomotive', 'through arch bridge', 'steel drum', 'stethoscope', + 'scarf', 'stone wall', 'stopwatch', 'stove', 'strainer', 'tram', + 'stretcher', 'couch', 'stupa', 'submarine', 'suit', 'sundial', 'sunglasses', + 'sunglasses', 'sunscreen', 'suspension bridge', 'mop', 'sweatshirt', + 'swim trunks / shorts', 'swing', 'electrical switch', 'syringe', + 'table lamp', 'tank', 'tape player', 'teapot', 'teddy bear', 'television', + 'tennis ball', 'thatched roof', 'front curtain', 'thimble', + 'threshing machine', 'throne', 'tile roof', 'toaster', 'tobacco shop', + 'toilet seat', 'torch', 'totem pole', 'tow truck', 'toy store', 'tractor', + 'semi-trailer truck', 'tray', 'trench coat', 'tricycle', 'trimaran', + 'tripod', 'triumphal arch', 'trolleybus', 'trombone', 'hot tub', + 'turnstile', 'typewriter keyboard', 'umbrella', 'unicycle', 'upright piano', + 'vacuum cleaner', 'vase', 'vaulted or arched ceiling', 'velvet fabric', + 'vending machine', 'vestment', 'viaduct', 'violin', 'volleyball', + 'waffle iron', 'wall clock', 'wallet', 'wardrobe', 'military aircraft', 
+ 'sink', 'washing machine', 'water bottle', 'water jug', 'water tower', + 'whiskey jug', 'whistle', 'hair wig', 'window screen', 'window shade', + 'Windsor tie', 'wine bottle', 'airplane wing', 'wok', 'wooden spoon', + 'wool', 'split-rail fence', 'shipwreck', 'sailboat', 'yurt', 'website', + 'comic book', 'crossword', 'traffic or street sign', 'traffic light', + 'dust jacket', 'menu', 'plate', 'guacamole', 'consomme', 'hot pot', + 'trifle', 'ice cream', 'popsicle', 'baguette', 'bagel', 'pretzel', + 'cheeseburger', 'hot dog', 'mashed potatoes', 'cabbage', 'broccoli', + 'cauliflower', 'zucchini', 'spaghetti squash', 'acorn squash', + 'butternut squash', 'cucumber', 'artichoke', 'bell pepper', 'cardoon', + 'mushroom', 'Granny Smith apple', 'strawberry', 'orange', 'lemon', 'fig', + 'pineapple', 'banana', 'jackfruit', 'cherimoya (custard apple)', + 'pomegranate', 'hay', 'carbonara', 'chocolate syrup', 'dough', 'meatloaf', + 'pizza', 'pot pie', 'burrito', 'red wine', 'espresso', 'tea cup', 'eggnog', + 'mountain', 'bubble', 'cliff', 'coral reef', 'geyser', 'lakeshore', + 'promontory', 'sandbar', 'beach', 'valley', 'volcano', 'baseball player', + 'bridegroom', 'scuba diver', 'rapeseed', 'daisy', 'yellow lady\'s slipper', + 'corn', 'acorn', 'rose hip', 'horse chestnut seed', 'coral fungus', + 'agaric', 'gyromitra', 'stinkhorn mushroom', 'earth star fungus', + 'hen of the woods mushroom', 'bolete', 'corn cob', 'toilet paper' +] + +# ImageNet-A and ImageNet-R do not use the full label space of ImageNet. +# These were copied from third_party/py/robustness_metrics/datasets/tfds.py +# Kudos to mjlm@ who helped us notice this. 
+IMAGENET_A_LABELSET = [ + 6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, + 89, 90, 94, 96, 97, 99, 105, 107, 108, 110, 113, 124, 125, 130, 132, 143, + 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301, 306, + 307, 308, 309, 310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, + 334, 335, 336, 347, 361, 363, 372, 378, 386, 397, 400, 401, 402, 404, 407, + 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461, 462, 470, + 472, 483, 486, 488, 492, 496, 514, 516, 528, 530, 539, 542, 543, 549, 552, + 557, 561, 562, 569, 572, 573, 575, 579, 589, 606, 607, 609, 614, 626, 627, + 640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704, 719, 736, 746, + 749, 752, 758, 763, 765, 768, 773, 774, 776, 779, 780, 786, 792, 797, 802, + 803, 804, 813, 815, 820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, + 870, 879, 880, 888, 890, 897, 900, 907, 913, 924, 932, 933, 934, 937, 943, + 945, 947, 951, 954, 956, 957, 959, 971, 972, 980, 981, 984, 986, 987, 988, +] + +# Also check out https://github.com/hendrycks/imagenet-r/blob/master/eval.py +IMAGENET_R_LABELSET = [ + 1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, + 90, 94, 96, 97, 99, 100, 105, 107, 113, 122, 125, 130, 132, 144, 145, 147, + 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199, 203, + 207, 208, 219, 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, + 263, 265, 267, 269, 276, 277, 281, 288, 289, 291, 292, 293, 296, 299, 301, + 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337, 338, 340, + 341, 344, 347, 353, 355, 361, 362, 365, 366, 367, 368, 372, 388, 390, 393, + 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447, 448, 457, 462, + 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, + 587, 593, 594, 596, 609, 613, 617, 621, 629, 637, 657, 658, 701, 717, 724, + 763, 768, 774, 776, 779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852, + 866, 875, 883, 
889, 895, 907, 928, 931, 932, 933, 934, 936, 937, 943, 945, + 947, 948, 949, 951, 953, 954, 957, 963, 965, 967, 980, 981, 983, 988, +] diff --git a/big_vision/datasets/infovqa/infovqa.py b/big_vision/datasets/infovqa/infovqa.py new file mode 100644 index 0000000000000000000000000000000000000000..4e6500548954b706ab0ebd203589871cca898b24 --- /dev/null +++ b/big_vision/datasets/infovqa/infovqa.py @@ -0,0 +1,141 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Implements InfoVqa in TFDS structure. + +First, download and unzip the dataset from https://rrc.cvc.uab.es/?ch=17 +and place it in /tmp/data/infovqa. 
+ +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd third_party/py/big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=infovqa + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load('infovqa', split='train', data_dir='/tmp/tfds') + +Dataset splits: + train: 23946 examples/questions (4406 images) + val: 2801 examples/questions (500 images) + test: 3288 examples/questions (579 images) (no answers) + +Recommended training splits: + train: train[:95%] (22749 examples/questions) + minitrain: train[:5%] (1197 examples/questions) + minival: train[95%:] (1197 examples/questions) + eval: val (2801 examples/questions) + +Note that according to task description in +https://rrc.cvc.uab.es/?ch=17&com=tasks: + - Order of items in a multi span answer does not matter. Therefore, we include + all permutations of the answer in the val split. + - Answers are not case sensitive. We leave it to the user to lower case + answers if they want to. +""" +import itertools +import json +import os + +import numpy as np +import tensorflow_datasets as tfds + + +_DESCRIPTION = """InfographicVQA dataset.""" + +# pylint: disable=line-too-long +_CITATION = """ +@inproceedings{Mathew_2022, + title={InfographicVQA}, + url={http://dx.doi.org/10.1109/WACV51458.2022.00264}, + DOI={10.1109/wacv51458.2022.00264}, + booktitle={2022 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, + publisher={IEEE}, + author={Mathew, Minesh and Bagal, Viraj and Tito, Ruben and Karatzas, Dimosthenis and Valveny, Ernest and Jawahar, C. 
class Infovqa(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the InfographicVQA (infovqa) dataset.

  Annotations and images are read from the local `_INFOVQA_PATH` folder; nothing
  is downloaded by the builder itself.
  """

  VERSION = tfds.core.Version('1.1.0')
  RELEASE_NOTES = {
      '1.0.0': 'First release.',
      '1.1.0': 'Add multi-span permutations to the val split answers.',
  }

  def _info(self):
    """Returns the dataset metadata (features, homepage, citation)."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question_id': tfds.features.Scalar(np.int32),
            'filename': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'question': tfds.features.Text(),
            'answers': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,
        homepage='https://www.docvqa.org/datasets/infographicvqa',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns one example generator per split ('train', 'val', 'test')."""
    return {split: self._generate_examples(split)
            for split in ('train', 'val', 'test')}

  def _generate_examples(self, split):
    """Yields (key, example) tuples for the given split.

    Args:
      split: one of the keys of `_ANNOTATIONS`; selects the annotation file.

    Yields:
      `(questionId, example)` pairs, where the example dict matches the
      features declared in `_info`. Entries without an 'answers' field get an
      empty answer list.
    """
    annot_fname = os.path.join(_INFOVQA_PATH, _ANNOTATIONS[split])
    # JSON is UTF-8 by specification; do not depend on the platform's default
    # text encoding, which would garble or reject non-ASCII text elsewhere.
    with open(annot_fname, 'r', encoding='utf-8') as f:
      data = json.load(f)

    for x in data['data']:
      yield x['questionId'], {
          'question_id': x['questionId'],
          'filename': x['image_local_name'],
          'image': os.path.join(_INFOVQA_PATH, 'images', x['image_local_name']),
          'question': x['question'],
          # Only the val split gets its multi-span answers permuted.
          'answers': maybe_permute(x.get('answers', []), split),
      }
def cached_download(url, dest=None, verbose=True):
  """Download `url` to a local file and return its path, with caching.

  The payload is first streamed into a uniquely named temporary file next to
  the final location and only moved into place once it is complete. An
  interrupted download, or two threads fetching the same URL at once, can
  therefore never leave a truncated file that later calls would mistake for a
  valid cache entry.

  Args:
    url: str, the URL to fetch; anything `urllib.request.urlopen` accepts.
    dest: optional str, the cache folder. Defaults to a "bv" folder inside the
      system temp directory. The folder is created if it does not exist.
    verbose: bool, whether to print a progress line before downloading.

  Returns:
    The path of the cached file: the cache folder joined with the md5 hex
    digest of `url`.

  Raises:
    Whatever `urllib.request.urlopen` or the file system raise; in that case no
    partial file is left behind.
  """
  # Compute a name based on the URL, so we can check if we already downloaded
  # it before.
  dest = dest or os.path.join(tempfile.gettempdir(), "bv")
  os.makedirs(dest, exist_ok=True)
  dest = os.path.join(dest, hashlib.md5(url.encode()).hexdigest())

  # NOTE: we should use last-modified header to know whether to re-download.
  if os.path.isfile(dest):
    return dest

  if verbose:
    print(f"\rRetrieving {url} into {dest}", end="", flush=True)

  # Same directory as the target, so the final rename stays on one file system
  # and is atomic.
  fd, partial = tempfile.mkstemp(
      dir=os.path.dirname(dest),
      prefix=os.path.basename(dest) + ".",
      suffix=".part")
  try:
    with os.fdopen(fd, "wb") as fout:
      with urllib.request.urlopen(url) as fin:
        # Stream in 1 MiB chunks instead of holding the whole file in memory.
        for chunk in iter(lambda: fin.read(1 << 20), b""):
          fout.write(chunk)
    os.replace(partial, dest)
  finally:
    # Only still present if something above failed before the rename.
    if os.path.exists(partial):
      os.remove(partial)
  return dest
def _indices(self, *, process_split=True, process_index=None):
  """Returns the example indices to use, as a list.

  Args:
    process_split: if False, every index of `self.examples` is returned.
      Otherwise the index range is cut into `jax.process_count()` nearly equal
      contiguous pieces and only one piece is returned.
    process_index: which piece to return when splitting; defaults to
      `jax.process_index()`.

  Returns:
    A list of integer indices into `self.examples`.
  """
  all_idxs = np.arange(len(self.examples))

  if process_split:
    which = process_index if process_index is not None else jax.process_index()
    shards = np.array_split(all_idxs, jax.process_count())
    return list(shards[which])

  return list(all_idxs)
+ assert not process_split or len(self.examples) >= jax.process_count(), ( + "Process splitting the data with fewer examples than processes!?") + + my_idxs = self._indices(process_split=process_split) + if not ordered: + np.random.shuffle(my_idxs) + + dataset = tf.data.Dataset.from_generator( + generator=lambda: ({"id": str(i), **self.examples[i]} for i in my_idxs), + output_signature={ + "id": _guess_signature("0"), + **{k: _guess_signature(v) for k, v in self.examples[0].items()}, + }) + + def _read_files(example): + for k in self.fopen_keys: + example[k] = tf.io.read_file(example[k]) + return example + dataset = dataset.map(_read_files) + + return dataset + + @property + @overrides.overrides + def total_examples(self): + return len(self.examples) + + @overrides.overrides + def num_examples_per_process(self): + return [len(self._indices(process_index=pid)) + for pid in range(jax.process_count())] + + +def _guess_signature(value): + return tf.TensorSpec.from_tensor(tf.constant(value)) diff --git a/big_vision/datasets/nocaps/nocaps.py b/big_vision/datasets/nocaps/nocaps.py new file mode 100644 index 0000000000000000000000000000000000000000..fab90082bf3852100c59230ee23ddeb923144792 --- /dev/null +++ b/big_vision/datasets/nocaps/nocaps.py @@ -0,0 +1,160 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Implements nocaps val/test set in TFDS structure. 
class NoCaps(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the nocaps val/test sets."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata: features, homepage and citation."""
    features = tfds.features.FeaturesDict({
        'image/id': tf.int64,
        'image_filepath': tfds.features.Text(),
        'url': tfds.features.Text(),
        'image': tfds.features.Image(encoding_format='jpeg'),
        'texts': tfds.features.Sequence(tfds.features.Text()),
    })
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=features,
        # No (input, target) pair is declared, so `as_supervised=True` is
        # not available for this dataset.
        supervised_keys=None,
        homepage='https://nocaps.org/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns a dict mapping split name ('val', 'test') to its generator."""

    def group_by_id(data, image_dir):
      # Collect all captions of an image under its id.
      captions = collections.defaultdict(list)
      for ann in data.get('annotations', []):
        captions[ann['image_id']].append(ann['caption'])

      examples = {}
      for img in data['images']:
        path = os.path.join(_FILEPATH, image_dir, img['file_name'])
        examples[img['id']] = {
            'image/id': img['id'],
            'image_filepath': path,
            'url': img['coco_url'],
            'image': path,
            # Images without any annotation get a single placeholder text.
            'texts': captions.get(img['id'], ['N/A']),
        }
      return examples

    def load(json_path, image_dir):
      with open(json_path) as f:
        return group_by_id(json.load(f), image_dir)

    val_examples = load(_VAL_FILES, 'validation')
    test_examples = load(_TEST_FILES, 'test')
    return {
        'val': self._generate_examples(val_examples),
        'test': self._generate_examples(test_examples),
    }

  def _generate_examples(self, data):
    """Yields the examples whose image file can be read and decoded.

    Args:
      data: dict mapping image id to an example dict in the format declared by
        `_info`.

    Yields:
      (key, example) tuples. Examples whose image is missing or is not a
      decodable JPEG are logged, together with their download URL, and
      skipped.
    """
    for image_id, example in data.items():
      try:
        # Decode once purely to surface broken files early; the result is
        # discarded and tfds.features.Image does the real encoding later.
        raw = tf.io.read_file(example['image_filepath'])
        np.array(tf.image.decode_jpeg(raw))
      except tf.errors.InvalidArgumentError:
        logging.error('Unable to decode: curl -O %s', example['url'])
      except tf.errors.NotFoundError:
        logging.error('File not found: curl -O %s', example['url'])
      else:
        yield image_id, example
+ +Download the required files from https://okvqa.allenai.org/download.html: + +mkdir -p /tmp/tfds +cd /tmp/tfds/ +wget http://images.cocodataset.org/zips/train2014.zip +wget http://images.cocodataset.org/zips/val2014.zip +wget https://okvqa.allenai.org/static/data/mscoco_train2014_annotations.json.zip +wget https://okvqa.allenai.org/static/data/mscoco_val2014_annotations.json.zip +wget https://okvqa.allenai.org/static/data/OpenEnded_mscoco_train2014_questions.json.zip +wget https://okvqa.allenai.org/static/data/OpenEnded_mscoco_val2014_questions.json.zip +unzip val2014.zip +unzip train2014.zip +unzip OpenEnded_mscoco_train2014_questions.json.zip +unzip OpenEnded_mscoco_val2014_questions.json.zip +unzip mscoco_train2014_annotations.json.zip +unzip mscoco_val2014_annotations.json.zip + +Then, run conversion locally (make sure to install tensorflow-datasets for the +`tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=okvqa + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load('okvqa', split='val', data_dir='/tmp/tfds') +""" + +import json +import os +from typing import Any +import numpy as np +import tensorflow_datasets as tfds + +_DESCRIPTION = """ +OKVQA addresses the task of VQA with outside knowledge. +This version of the dataset contains: +- Questions + Answers from OKVQA. +- Images from COCO. 
+""" + +_CITATION = """ +@InProceedings{okvqa, +author = {Kenneth Marino and Mohammad Rastegari and Ali Farhadi and Roozbeh Mottaghi}, +title = {OK-VQA: A Visual Question Answering Benchmark Requiring External Knowledge}, +booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR)}, +year = {2019}, +} +""" + +ANNOTATION_FILE = { + 'train': 'mscoco_train2014_annotations.json', + 'val': 'mscoco_val2014_annotations.json', +} +QUESTIONS_FILE = { + 'train': 'OpenEnded_mscoco_train2014_questions.json', + 'val': 'OpenEnded_mscoco_val2014_questions.json', +} +QUESTION_TYPES = { + 'one': 'Vehicles and Transportation', + 'two': 'Brands, Companies and Products', + 'three': 'Objects, Material and Clothing', + 'four': 'Sports and Recreation', + 'five': 'Cooking and Food', + 'six': 'Geography, History, Language and Culture', + 'seven': 'People and Everyday life', + 'eight': 'Plants and Animals', + 'nine': 'Science and Technology', + 'ten': 'Weather and Climate', + 'other': 'Other', +} + + +# When running locally (recommended), copy files as above an use these: +_OKVQA_PATH = '/media/scratch/okvqa' + + +class OkVqa(tfds.core.GeneratorBasedBuilder): + """Import COCO dataset for OKVQA with KAT features.""" + + VERSION = tfds.core.Version('1.0.0') + RELEASE_NOTES = {'1.0.0': 'Changed to array record format.'} + MANUAL_DOWNLOAD_INSTRUCTIONS = """ + In manual_dir/ you should have a directory okvqa which contains the + following files and directories: + From the OKVQA dataset: + - mscoco_train2014_annotations.json + - mscoco_val2014_annotations.json + - OpenEnded_mscoco_train2014_questions.json + - OpenEnded_mscoco_val2014_questions.json + - train2014.zip + - val2014.zip + """ + + def _info(self) -> tfds.core.DatasetInfo: + """Returns the dataset metadata.""" + features = tfds.features.FeaturesDict({ + 'image': tfds.features.Image(shape=(None, None, 3)), + 'image_id': tfds.features.Scalar(dtype=np.int64), + 'answer_type': tfds.features.Text(), + 'answers': 
tfds.features.Sequence(tfds.features.Text()), + 'answers_confidence': tfds.features.Tensor(shape=[10], dtype=np.bool_), + 'answers_raw': tfds.features.Sequence(tfds.features.Text()), + 'question_id': tfds.features.Scalar(dtype=np.int64), + 'question_type': tfds.features.Text(), + 'question_type_readable': tfds.features.Text(), + 'question': tfds.features.Text(), + }) + + return tfds.core.DatasetInfo( + builder=self, + features=features, + description=_DESCRIPTION, + supervised_keys=None, + homepage='https://okvqa.allenai.org/', + citation=_CITATION, + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager) -> ...: + """Call the function which defines the splits.""" + # data_dir = dl_manager.manual_dir + data_dir = _OKVQA_PATH + return { + 'train': self._generate_examples(data_dir, 'train'), + 'val': self._generate_examples(data_dir, 'val'), + } + + def _generate_examples(self, data_dir: str, split: str) -> ...: + annotations = get_okvqa_annotations(data_dir, split) + + for question_id, annotation in annotations.items(): + image_id = annotation['image_id'] + + # Sanity check. 
+ if len(annotation['answers']) != 10: + num_answers = len(annotation['answers']) + raise ValueError( + f'The number of answers for {image_id} is not 10 but {num_answers}') + + feature_dict = { + 'image': self.get_image_path(data_dir, split, image_id), + 'image_id': image_id, + 'answer_type': annotation['answer_type'], + 'answers': [a['answer'] for a in annotation['answers']], + 'answers_confidence': _get_answer_confidence(annotation['answers']), + 'answers_raw': [a['raw_answer'] for a in annotation['answers']], + 'question_id': annotation['question_id'], + 'question_type': annotation['question_type'], + 'question_type_readable': QUESTION_TYPES[annotation['question_type']], + 'question': annotation['question'], + } + yield f'{question_id}', feature_dict + + def get_image_path(self, data_dir: str, split: str, image_id: int) -> str: + subdir = {'train': 'train2014', 'val': 'val2014'}[split] + return f'{data_dir}/{subdir}/COCO_{subdir}_{image_id:012d}.jpg' + + +def _get_answer_confidence(answers: list[dict[str, str]]) -> np.ndarray: + """Get OKVQA answer confidences as bool.""" + confidences = [] + for a in answers: + confidence = a['answer_confidence'] + if confidence == 'yes': + confidences.append(True) + elif confidence == 'no': + confidences.append(False) + else: + raise ValueError(f'Unknown confidence: {confidence}') + return np.array(confidences, dtype=bool) + + +def _read_json( + data_dir: str, file: str, key: str +) -> dict[int, dict[str, Any]]: + with open(os.path.join(data_dir, file)) as f: + data = json.load(f) + questions = {d['question_id']: d for d in data[key]} + return questions + + +def get_okvqa_annotations( + data_dir: str, split: str +) -> dict[int, dict[str, Any]]: + """Return okvqa annotations (quesions and answers) as dictionary.""" + questions = _read_json(data_dir, QUESTIONS_FILE[split], 'questions') + annotations = _read_json(data_dir, ANNOTATION_FILE[split], 'annotations') + + assert len(annotations) == len(questions) + for question_id, 
question in questions.items(): + assert question['image_id'] == annotations[question_id]['image_id'] + assert question['question_id'] == annotations[question_id]['question_id'] + annotations[question_id]['question'] = question['question'] + + return annotations diff --git a/big_vision/datasets/pope/pope.py b/big_vision/datasets/pope/pope.py new file mode 100644 index 0000000000000000000000000000000000000000..3f266d6ce21d218186736e434db328744a377389 --- /dev/null +++ b/big_vision/datasets/pope/pope.py @@ -0,0 +1,145 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Implements POPE test-set in TFDS structure. + +It's small data, so simple to run locally. 
First, copy the data to local disk: +First download json files from https://github.com/AoiDragon/POPE; then download +MSCOCO (val 2014) images from https://cocodataset.org/#download + + mkdir -p /tmp/data/pope/ + mkdir -p /tmp/data/pope/pope/ + mkdir -p /tmp/data/pope/images/ + git clone https://github.com/AoiDragon/POPE.git + cp POPE/output/coco/* /tmp/data/pope/pope/ + wget http://images.cocodataset.org/zips/val2014.zip + unzip val2014.zip + cp -r val2014/ /tmp/data/pope/images/ + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=pope + +Example to load: + + import tensorflow_datasets as tfds + dataset_random = tfds.load('pope/pope_random', split='test', data_dir='/tmp/tfds') + dataset_popular = tfds.load('pope/pope_popular', split='test', data_dir='/tmp/tfds') + dataset_adversarial = tfds.load('pope/pope_adversarial', split='test', data_dir='/tmp/tfds') + +""" +import json +import os + +import numpy as np +import tensorflow_datasets as tfds + + +_DESCRIPTION = """POPE dataset.""" + +# pylint: disable=line-too-long +_CITATION = """ +@inproceedings{li-etal-2023-evaluating, + title = "Evaluating Object Hallucination in Large Vision-Language Models", + author = "Li, Yifan and + Du, Yifan and + Zhou, Kun and + Wang, Jinpeng and + Zhao, Xin and + Wen, Ji-Rong", + editor = "Bouamor, Houda and + Pino, Juan and + Bali, Kalika", + booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing", + month = dec, + year = "2023", + address = "Singapore", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2023.emnlp-main.20", + doi = "10.18653/v1/2023.emnlp-main.20", + pages = "292--305", + abstract = "Inspired by the superior language abilities of large language models (LLM), large vision-language models (LVLM) have been recently proposed by integrating powerful LLMs for 
class POPE(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the POPE test sets (one builder config per set)."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'First release.'}
  BUILDER_CONFIGS = [
      POPEConfig(name='pope_random', description='Random set'),
      POPEConfig(name='pope_popular', description='Popular set'),
      POPEConfig(name='pope_adversarial', description='Adversarial set'),
  ]

  def _info(self):
    """Returns the dataset metadata (features, homepage, citation)."""
    features = tfds.features.FeaturesDict({
        'question_id': tfds.features.Scalar(np.int32),
        'image/filename': tfds.features.Text(),
        'image': tfds.features.Image(encoding_format='png'),
        'question': tfds.features.Text(),
        'answer': tfds.features.Text(),
        'thing': tfds.features.Text(),
    })
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=features,
        supervised_keys=None,
        homepage='https://github.com/AoiDragon/POPE',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns the single 'test' split of the selected builder config."""
    config_name = self.builder_config.name
    return {'test': self._generate_examples('test', config_name)}

  def _generate_examples(self, split: str, source: str):
    """Yields (key, example) tuples read from one POPE annotation file.

    Args:
      split: split name; not used when locating or reading the data.
      source: builder config name, used to pick `pope/coco_{source}.json`
        below `_POPE_PATH`.

    Yields:
      `(idx, example)` where `idx` is the zero-based line number of the record
      in the annotation file, which holds one JSON object per line.
    """
    annot_fname = os.path.join(_POPE_PATH, f'pope/coco_{source}.json')

    with open(annot_fname, 'r') as f:
      records = [json.loads(line) for line in f]

    for idx, record in enumerate(records):
      question = record['text']
      # Strip the question template so that only the queried object remains.
      thing = question
      for fragment in ('Is there an ', 'Is there a ', ' in the image?'):
        thing = thing.replace(fragment, '')
      yield idx, {
          'question_id': idx,
          'image/filename': record['image'],
          'image': os.path.join(_POPE_PATH, 'images/val2014/', record['image']),
          'question': question,
          'answer': record['label'],
          'thing': thing,
      }
+ +# pylint: disable=line-too-long +r"""Unbatch RefCOCO, RefCOCO+, RefCOCOg datasets in TFDS structure.""" + +# Based on tensorflow_datasets/datasets/ref_coco + +import io +import os +import pickle + +import numpy as np +import PIL.Image +import pycocotools.coco +import tensorflow_datasets as tfds + +_ROOT_PATH = '/tmp/data/' + + +class RefCocoConfig(tfds.core.BuilderConfig): + """Config to specify each RefCoco variant.""" + + def __init__(self, dataset, dataset_partition, **kwargs): + name = f'{dataset}_{dataset_partition}' + super(RefCocoConfig, self).__init__(name=name, **kwargs) + self.dataset = dataset + self.dataset_partition = dataset_partition + + +_DESCRIPTION = """RefCOCO, RefCOCO+, RefCOCOg datasets. + +Images, boxes and segmentations are from the original COCO dataset +(Lin et al, ECCV 2014). The referential segmentations are from two different +sources: + +1) RefCOCOg (Mao et al, CVPR 2016): + - https://github.com/mjhucla/Google_Refexp_toolbox + - This is the split used in the "refcocog_google" dataset. Note that this + split has overlapping images in train/validation. The same split is also + provided in 2). + +2) Source of RefCOCO and RefCOCO+ (Yu et al, ECCV 2016): + - https://github.com/lichengunc/refer + - Apache License 2.0 + - Provides all the splits used for generation of these datasets, including the + "refcocog_google" split that is identical with the split from 1). + +For convenience, we provide an additional dataset "refcocox_combined" that +combines the datasets "refcoco_unc", "refcocoplus_unc", and "refcocog_umd", +unifying "testA" and "testB" into a single "test" split, and removing any images +from "train" that appear either in "validation" or "test". + +Also for convenience, every split is unrolled twice (at the "objects" level and +at the "object/refs" level) and saved as "{split}_flat". 
+""" + +# pylint: disable=line-too-long +_CITATION = r""" +@inproceedings{DBLP:conf/cvpr/MaoHTCY016, + author = {Junhua Mao and + Jonathan Huang and + Alexander Toshev and + Oana Camburu and + Alan L. Yuille and + Kevin Murphy}, + title = {Generation and Comprehension of Unambiguous Object Descriptions}, + booktitle = {2016 {IEEE} Conference on Computer Vision and Pattern Recognition, + {CVPR} 2016, Las Vegas, NV, USA, June 27-30, 2016}, + pages = {11--20}, + publisher = {{IEEE} Computer Society}, + year = {2016}, + url = {https://doi.org/10.1109/CVPR.2016.9}, + doi = {10.1109/CVPR.2016.9}, + timestamp = {Fri, 24 Mar 2023 00:02:52 +0100}, + biburl = {https://dblp.org/rec/conf/cvpr/MaoHTCY016.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} + +@inproceedings{DBLP:conf/eccv/YuPYBB16, + author = {Licheng Yu and + Patrick Poirson and + Shan Yang and + Alexander C. Berg and + Tamara L. Berg}, + editor = {Bastian Leibe and + Jiri Matas and + Nicu Sebe and + Max Welling}, + title = {Modeling Context in Referring Expressions}, + booktitle = {Computer Vision - {ECCV} 2016 - 14th European Conference, Amsterdam, + The Netherlands, October 11-14, 2016, Proceedings, Part {II}}, + series = {Lecture Notes in Computer Science}, + volume = {9906}, + pages = {69--85}, + publisher = {Springer}, + year = {2016}, + url = {https://doi.org/10.1007/978-3-319-46475-6\_5}, + doi = {10.1007/978-3-319-46475-6\_5}, + timestamp = {Wed, 07 Dec 2022 23:10:23 +0100}, + biburl = {https://dblp.org/rec/conf/eccv/YuPYBB16.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} + +@article{DBLP:journals/corr/LinMBHPRDZ14, + author = {Tsung{-}Yi Lin and + Michael Maire and + Serge J. Belongie and + Lubomir D. Bourdev and + Ross B. Girshick and + James Hays and + Pietro Perona and + Deva Ramanan and + Piotr Doll{\'{a}}r and + C. 
Lawrence Zitnick}, + title = {Microsoft {COCO:} Common Objects in Context}, + journal = {CoRR}, + volume = {abs/1405.0312}, + year = {2014}, + url = {http://arxiv.org/abs/1405.0312}, + archivePrefix = {arXiv}, + eprint = {1405.0312}, + timestamp = {Mon, 13 Aug 2018 16:48:13 +0200}, + biburl = {https://dblp.org/rec/bib/journals/corr/LinMBHPRDZ14}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +""" + +# coco_data = json.load(open('annotations/instances_train2017.json')) +# [l['name'] for l in coco_data['licenses']] +LICENSES = [ + 'Attribution-NonCommercial-ShareAlike License', + 'Attribution-NonCommercial License', + 'Attribution-NonCommercial-NoDerivs License', + 'Attribution License', + 'Attribution-ShareAlike License', + 'Attribution-NoDerivs License', + 'No known copyright restrictions', + 'United States Government Work', +] +# _licenses_map = {l['id']: i for i, l in enumerate(coco_data['licenses'])} +_licenses_map = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7} + +# pyformat: disable +# [c['name'] for c in coco_data['categories']] +CATEGORIES = [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', + 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', + 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', + 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', + 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', + 'hair drier', 'toothbrush', +] +# 
sorted(set(c['supercategory'] for c in coco_data['categories'])) +SUPERCATEGORIES = [ + 'accessory', 'animal', 'appliance', 'electronic', 'food', 'furniture', + 'indoor', 'kitchen', 'outdoor', 'person', 'sports', 'vehicle', +] +# pyformat: enable + + +# Will be exported into directory `$TFDS_DATA_DIR/ref_coco_bv` +# If the class name was `RefCOCO` then it would be exported into +# `$TFDS_DATA_DIR/ref_coco`, which would collide with the default TFDS dataset +# also named `ref_coco` (which has precedence over `data_dir` builder arg). +class RefCocoBv(tfds.core.GeneratorBasedBuilder): + """DatasetBuilder for RefCoco datasets.""" + + VERSION = tfds.core.Version('1.4.0') + RELEASE_NOTES = { + '1.4.0': 'Added flat versions of all dataset splits.', + '1.3.0': 'Added "refcocox_combined" dataset.', + '1.2.0': 'Added "train_flat" splits.', + '1.1.0': 'Added more features (mask etc), nested "refs" in "objects".', + '1.0.0': 'Initial release.', + } + + MANUAL_DOWNLOAD_INSTRUCTIONS = """ + 1. Install https://pypi.org/project/pycocotools/. + + 2. Download data (requires ~20G for COCO images): + + (mkdir -p /tmp/tfds/downloads/manual && + cd /tmp/tfds/downloads/manual && + wget http://images.cocodataset.org/zips/train2017.zip && + wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip && + wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip && + wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip && + wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip && + for zip in *.zip; do unzip $zip; done + ) + + 3. 
Run the generation script with `TFDS_DATA_DIR=/tmp/tfds` + """ + + BUILDER_CONFIGS = [ + RefCocoConfig(dataset='refcoco', dataset_partition='unc'), + RefCocoConfig(dataset='refcoco', dataset_partition='google'), + RefCocoConfig(dataset='refcocoplus', dataset_partition='unc'), + RefCocoConfig(dataset='refcocog', dataset_partition='google'), + RefCocoConfig(dataset='refcocog', dataset_partition='umd'), + RefCocoConfig(dataset='refcocox', dataset_partition='combined'), + ] + + def _info(self) -> tfds.core.DatasetInfo: + return tfds.core.DatasetInfo( + builder=self, + features=tfds.features.FeaturesDict({ + 'id': tfds.features.Scalar(np.int32), + 'image': tfds.features.Image(encoding_format='jpeg'), + 'height': tfds.features.Scalar(np.int32), + 'width': tfds.features.Scalar(np.int32), + 'license': tfds.features.ClassLabel(names=LICENSES), + 'file_name': tfds.features.Text(), + 'flickr_url': tfds.features.Text(), + 'coco_url': tfds.features.Text(), + 'objects': tfds.features.Sequence({ + 'id': tfds.features.Scalar(np.int64), + 'area': tfds.features.Scalar(np.float32), + 'bbox': tfds.features.BBoxFeature(), + 'mask': tfds.features.Image(encoding_format='png'), + 'category': tfds.features.ClassLabel(names=CATEGORIES), + 'supercategory': tfds.features.ClassLabel( + names=SUPERCATEGORIES + ), + 'iscrowd': tfds.features.Scalar(np.bool_), + # refcoco, refcoco+, refcocog features: + 'refs': tfds.features.Sequence({ + 'id': tfds.features.Scalar(np.int32), + 'sentence': tfds.features.Text(), + }), + }), + }), + supervised_keys=None, # Set to `None` to disable + citation=_CITATION, + description=_DESCRIPTION, + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + allowed_splits = { + ('refcoco', 'google'): [ + tfds.Split.TRAIN, + tfds.Split.VALIDATION, + tfds.Split.TEST, + ], + ('refcoco', 'unc'): [ + tfds.Split.TRAIN, + tfds.Split.VALIDATION, + 'testA', + 'testB', + ], + ('refcocoplus', 'unc'): [ + tfds.Split.TRAIN, + tfds.Split.VALIDATION, + 'testA', 
+ 'testB', + ], + # Verified manually that image and annotation IDs match the ones in + # https://storage.googleapis.com/refexp/google_refexp_dataset_release.zip + ('refcocog', 'google'): [ + tfds.Split.TRAIN, + tfds.Split.VALIDATION, + ], + ('refcocog', 'umd'): [ + tfds.Split.TRAIN, + tfds.Split.VALIDATION, + tfds.Split.TEST, + ], + ('refcocox', 'combined'): [ + tfds.Split.TRAIN, + tfds.Split.VALIDATION, + tfds.Split.TEST, + ], + } + bc = self.builder_config + splits = allowed_splits[(bc.dataset, bc.dataset_partition)] + + data_dir = dl_manager.manual_dir + for url, components in ( + # pylint: disable=line-too-long + # pyformat: disable + ('http://images.cocodataset.org/zips/train2017.zip', ('train2017', '000000147328.jpg')), + ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip', ('annotations', 'instances_train2017.json')), + ('https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip', ('refcoco', 'refs(unc).p')), + ('https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip', ('refcoco+', 'refs(unc).p')), + ('https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip', ('refcocog', 'refs(umd).p')), + # pyformat: enable + # pylint: enable=line-too-long + ): + path = os.path.exists(os.path.join(data_dir, *components)) + if not path: + raise FileNotFoundError( + f'Could not find {path}: please download {url} and unzip into' + f' {data_dir}' + ) + + coco = pycocotools.coco.COCO( + os.path.join(data_dir, 'annotations', 'instances_train2017.json') + ) + + return { + split + suffix: self._generate_examples( + coco, data_dir, bc.dataset, bc.dataset_partition, split + suffix, + ) + for split in splits + for suffix in ('', '_flat') + } + + # Builder must overwrite all abstract methods. 
+ def _generate_examples( + self, coco, data_dir, dataset, dataset_partition, split): + return _generate_examples(coco, data_dir, dataset, dataset_partition, split) + + +def _get_ids(data_dir, dataset, dataset_partition, split): + """Returns `img_ids, ann_to_refs` for specified dataset/partition/split.""" + + def load(dataset, dataset_partition): + fname = f'refs({dataset_partition}).p' + path = os.path.join(data_dir, dataset, fname) + refcoco = pickle.load(open(path, 'rb')) + return refcoco + + if split == tfds.Split.VALIDATION: + split = 'val' + + if (dataset, dataset_partition) == ('refcocox', 'combined'): + refcoco = ( + load('refcocog', 'umd') + + load('refcoco', 'unc') + + load('refcoco+', 'unc') + ) + if split == 'test': + splits = ('test', 'testA', 'testB') + else: + splits = (split,) + + exclude_img_ids = set() + if split == 'train': + # Exclude all images with val/test annotations from train set. + exclude_img_ids = { + r['image_id'] for r in refcoco if r['split'] != 'train' + } + refcoco = [ + r + for r in refcoco + if r['split'] in splits and r['image_id'] not in exclude_img_ids + ] + + else: + if dataset == 'refcocoplus': + dataset = 'refcoco+' + refcoco = load(dataset, dataset_partition) + refcoco = [r for r in refcoco if r['split'] == split] + + img_ids = {r['image_id'] for r in refcoco} + ann_to_refs = {} + for r in refcoco: + for sent in r['sentences']: + ann_to_refs.setdefault(r['ann_id'], []).append(dict( + id=sent['sent_id'], + sentence=sent['sent'] + )) + + return img_ids, ann_to_refs + + +def _generate_examples(coco, data_dir, dataset, dataset_partition, split): + """Generates examples for a given split.""" + + flat = '_flat' in split + split = split.replace('_flat', '') + img_ids, ann_to_refs = _get_ids(data_dir, dataset, dataset_partition, split) + + for img_id in coco.getImgIds(): + + if img_id not in img_ids: + continue + img, = coco.loadImgs([img_id]) + + example = { + 'id': img_id, + 'image': os.path.join(data_dir, 'train2017', 
img['file_name']), + 'height': img['height'], + 'width': img['width'], + 'license': LICENSES[_licenses_map[img['license']]], + 'file_name': img['file_name'], + 'flickr_url': img['flickr_url'], + 'coco_url': img['coco_url'], + 'objects': [], + } + for ann in coco.loadAnns(coco.getAnnIds(img_id)): + refs = ann_to_refs.get(ann['id']) + if not refs: + continue + cat, = coco.loadCats([ann['category_id']]) + mask = coco.annToMask(ann).astype(np.bool_) + mask_buf = io.BytesIO() + PIL.Image.fromarray(mask).save(mask_buf, 'png') + mask_buf.seek(0) + object_ = { + 'id': ann['id'], + 'mask': mask_buf, + 'category': cat['name'], + 'supercategory': cat['supercategory'], + 'iscrowd': ann['iscrowd'], + 'area': ann['area'], + 'bbox': _convert_bbox(img, *ann['bbox']), + 'refs': refs, + } + if flat: + example['objects'] = [object_] + for ref_i, ref in enumerate(refs): + object_['refs'] = [ref] + mask_buf.seek(0) + yield f'{img_id}_{ann["id"]}_{ref_i}', example + else: + example['objects'].append(object_) + + if not flat: + yield img_id, example + + +def _convert_bbox(img, x, y, w, h): + return tfds.features.BBox( + ymin=y / img['height'], + xmin=x / img['width'], + ymax=(y + h) / img['height'], + xmax=(x + w) / img['width'], + ) diff --git a/big_vision/datasets/rsvqa_hr/rsvqa_hr.py b/big_vision/datasets/rsvqa_hr/rsvqa_hr.py new file mode 100644 index 0000000000000000000000000000000000000000..9f41612edfa9606572c316a68c8908061933e6c7 --- /dev/null +++ b/big_vision/datasets/rsvqa_hr/rsvqa_hr.py @@ -0,0 +1,193 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Implements RSVQA-HR dataset in TFDS. + +Remote sensing visual question answering task, using high-resolution airborne +image data at 15cm resolution per pixel. + +It's a small dataset at source (14G), so simple to run locally. +First, download and unzip the dataset from https://zenodo.org/records/6344367 +and place it in /tmp/data/rsvqa_hr. + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=rsvqa_hr + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load('rsvqa_hr', split='train', data_dir='/tmp/tfds') + +Dataset splits (all): + train: 625,340 examples/questions + val: 102,843 examples/questions + test: 222,684 examples/questions + test_2: 105,647 examples/questions (other area, unknown instrument) +Non-numeric data splits (nonum): + train: 371,834 examples/questions + val: 60,405 examples/questions + test: 131,468 examples/questions + test_2: 62,554 examples/questions + +Note: due to image duplication with each question, the dataset size is +significantly increased by the number of questions per image.
+ +Recommended training splits: + train: train + minitrain: train[:5%] + eval: val + full_train: train+val + test: test + +Image sizes: 512x512 +Number of answers per question: 1 +Question types distribution in train split: + - Area (area): 14.6% (integers, binned into {0m2, 1-10m2, 11-100m2, 101-1000m2, >1000m2}) + - Comparison(comp): 33.5% + - Count (count): 26.0% (integers, not binned, maximum number of objects is 89) + - Presence (presence): 26.0% +""" +import json +import os + +import numpy as np +import tensorflow_datasets as tfds + + +_DESCRIPTION = """RSVQA-HR dataset.""" + +# pylint: disable=line-too-long +_CITATION = """ +@article{Lobry_2020, + title={RSVQA: Visual Question Answering for Remote Sensing Data}, + volume={58}, + ISSN={1558-0644}, + url={http://dx.doi.org/10.1109/TGRS.2020.2988782}, + DOI={10.1109/tgrs.2020.2988782}, + number={12}, + journal={IEEE Transactions on Geoscience and Remote Sensing}, + publisher={Institute of Electrical and Electronics Engineers (IEEE)}, + author={Lobry, Sylvain and Marcos, Diego and Murray, Jesse and Tuia, Devis}, + year={2020}, + month=dec, pages={8555-8566} } +""" +# pylint: enable=line-too-long + +# When running locally (recommended), copy files as above an use these: +PATH = '/tmp/data/rsvqa_hr/' + + +class RsvqaHrConfig(tfds.core.BuilderConfig): + """Config to specify each variant.""" + + def __init__(self, nonum, **kwargs): + name = 'nonum' if nonum else 'all' + super(RsvqaHrConfig, self).__init__(name=name, **kwargs) + self.nonum = nonum + + +class RsvqaHr(tfds.core.GeneratorBasedBuilder): + """DatasetBuilder for RSVQA-HR dataset.""" + + VERSION = tfds.core.Version('1.0.2') + RELEASE_NOTES = { + '1.0.0': 'First release.', + '1.0.1': 'Rename binned values.', + '1.0.2': 'Removed explicit png image encoding.', + } + + BUILDER_CONFIGS = [ + RsvqaHrConfig(nonum=False), + RsvqaHrConfig(nonum=True), + ] + + def _info(self): + """Returns the metadata.""" + + return tfds.core.DatasetInfo( + builder=self, + 
description=_DESCRIPTION, + features=tfds.features.FeaturesDict({ + 'question_id': tfds.features.Scalar(np.int32), + 'filename': tfds.features.Text(), + 'image': tfds.features.Image(), + 'question': tfds.features.Text(), + 'question_type': tfds.features.Text(), + 'answers': tfds.features.Sequence(tfds.features.Text()), + 'raw_answers': tfds.features.Sequence(tfds.features.Text()), + }), + supervised_keys=None, + homepage='https://rsvqa.sylvainlobry.com/', + citation=_CITATION, + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + """Returns SplitGenerators.""" + return { + split: self._generate_examples(split) + for split in ('train', 'val', 'test', 'test_2') + } + + def _generate_examples(self, split): + """Yields (key, example) tuples.""" + if split == 'test_2': + split = 'test_phili' + questions_path = os.path.join(PATH + f'USGS_split_{split}_questions.json') + answers_path = os.path.join(PATH + f'USGS_split_{split}_answers.json') + images_path = os.path.join(PATH + 'Data') + + with open(questions_path, 'r') as f: + questions = json.loads(f.read())['questions'] + with open(answers_path, 'r') as f: + answers = json.loads(f.read())['answers'] + + for q, a in zip(questions, answers): + assert q['active'] == a['active'] + if not q['active']: + continue + if self.builder_config.nonum and q['type'] in ('area', 'count'): + continue + assert q['answers_ids'][0] == a['id'] + assert q['id'] == a['question_id'] + + filename = f'{q["img_id"]}.png' + yield q['id'], { + 'question_id': q['id'], + 'filename': filename, + 'image': os.path.join(images_path, filename), + 'question': q['question'], + 'question_type': q['type'], + 'answers': [bin_answer(a['answer'], q['type'])], + 'raw_answers': [a['answer']], + } + + +def bin_answer(answer, question_type): + """Bins answers into expected ranges.""" + if question_type == 'area': + area = int(answer[:-2]) + if area == 0: + return '0 m2' + elif area <= 10: + return 'between 1 m2 and 10 m2' + elif area <= 
100: + return 'between 11 m2 and 100 m2' + elif area <= 1000: + return 'between 101 m2 and 1000 m2' + else: + return 'more than 1000 m2' + return answer diff --git a/big_vision/datasets/rsvqa_lr/rsvqa_lr.py b/big_vision/datasets/rsvqa_lr/rsvqa_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2a963446ab2524141e55f07af4b64cf483c54e --- /dev/null +++ b/big_vision/datasets/rsvqa_lr/rsvqa_lr.py @@ -0,0 +1,198 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Implements RSVQA-LR dataset in TFDS. + +Remote sensing visual question answering task, using low-resolution satellite +(Sentinel-2) RGB channels data at 10m resolution per pixel. + +It's small dataset at source (200M), so simple to run locally. +First, download and unzip the dataset from https://zenodo.org/records/6344334 +and place it in /tmp/data/rsvqa_lr. 
+ +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=rsvqa_lr + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load('rsvqa_lr', split='train', data_dir='/tmp/tfds') + +Dataset splits (all): + train: 57223 examples/questions + val: 10005 examples/questions + test: 10004 examples/questions +Non-numeric data splits (nonum): + train: 39441 examples/questions + val: 6782 examples/questions + test: 6782 examples/questions + +Note: due to image duplication with each question, the dataset size is +significantly increased by the number of questions per image. + +Recommended training splits: + train: train + minitrain: train[:5%] + eval: val + full_train: train+val + test: test + +Image sizes: 256x256 +Number of answers per question: 1 +Question types distribution in train split: + - Comparison(comp): 39.4% + - Count (count): 29.9% (integers, binned into + {0, 1-10, 11-100, 101-1000, >1000}) + - Presence (presence): 29.7% + - Rural/Urban (rural_urban): 1% +""" +import io +import json +import os + +import numpy as np +import tensorflow_datasets as tfds + + +_DESCRIPTION = """RSVQA-LR dataset.""" + +# pylint: disable=line-too-long +_CITATION = """ +@article{Lobry_2020, + title={RSVQA: Visual Question Answering for Remote Sensing Data}, + volume={58}, + ISSN={1558-0644}, + url={http://dx.doi.org/10.1109/TGRS.2020.2988782}, + DOI={10.1109/tgrs.2020.2988782}, + number={12}, + journal={IEEE Transactions on Geoscience and Remote Sensing}, + publisher={Institute of Electrical and Electronics Engineers (IEEE)}, + author={Lobry, Sylvain and Marcos, Diego and Murray, Jesse and Tuia, Devis}, + year={2020}, + month=dec, pages={8555–8566} } +""" +# pylint: enable=line-too-long + +# When running locally (recommended), copy files as above and use these: +PATH =
'/tmp/data/rsvqa_lr/' + + +class RsvqaLrConfig(tfds.core.BuilderConfig): + """Config to specify each variant.""" + + def __init__(self, nonum, **kwargs): + name = 'nonum' if nonum else 'all' + super(RsvqaLrConfig, self).__init__(name=name, **kwargs) + self.nonum = nonum + + +class RsvqaLr(tfds.core.GeneratorBasedBuilder): + """DatasetBuilder for RSVQA-LR dataset.""" + + VERSION = tfds.core.Version('1.0.2') + RELEASE_NOTES = { + '1.0.0': 'First release.', + '1.0.1': 'Rename binned values.', + '1.0.2': 'Removed explicit png image encoding.', + } + + BUILDER_CONFIGS = [ + RsvqaLrConfig(nonum=False), + RsvqaLrConfig(nonum=True), + ] + + def _info(self): + """Returns the metadata.""" + + return tfds.core.DatasetInfo( + builder=self, + description=_DESCRIPTION, + features=tfds.features.FeaturesDict({ + 'question_id': tfds.features.Scalar(np.int32), + 'filename': tfds.features.Text(), + 'image': tfds.features.Image(), + 'question': tfds.features.Text(), + 'question_type': tfds.features.Text(), + 'answers': tfds.features.Sequence(tfds.features.Text()), + 'raw_answers': tfds.features.Sequence(tfds.features.Text()), + }), + supervised_keys=None, + homepage='https://rsvqa.sylvainlobry.com/', + citation=_CITATION, + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + """Returns SplitGenerators.""" + return { + split: self._generate_examples(split) + for split in ('train', 'val', 'test') + } + + def _generate_examples(self, split): + """Yields (key, example) tuples.""" + questions_path = os.path.join(PATH + f'LR_split_{split}_questions.json') + answers_path = os.path.join(PATH + f'LR_split_{split}_answers.json') + images_path = os.path.join(PATH + 'Images_LR') + + with open(questions_path, 'r') as f: + questions = json.loads(f.read())['questions'] + with open(answers_path, 'r') as f: + answers = json.loads(f.read())['answers'] + + for q, a in zip(questions, answers): + assert q['active'] == a['active'] + if not q['active']: + continue + if 
self.builder_config.nonum and q['type'] == 'count': + continue + assert q['answers_ids'] == [a['id']] + assert q['id'] == a['question_id'] + + filename = f'{q["img_id"]}.tif' + img = read_tif(os.path.join(images_path, filename)) + yield q['id'], { + 'question_id': q['id'], + 'filename': filename, + 'image': img, + 'question': q['question'], + 'question_type': q['type'], + 'answers': [bin_answer(a['answer'], q['type'])], + 'raw_answers': [a['answer']], + } + + +def bin_answer(answer, question_type): + """Bins answers into expected ranges.""" + if question_type == 'count': + count = int(answer) + if count == 0: + return '0' + elif count <= 10: + return 'between 1 and 10' + elif count <= 100: + return 'between 11 and 100' + elif count <= 1000: + return 'between 101 and 1000' + else: + return 'more than 1000' + return answer + + +def read_tif(path): + with open(path, 'rb') as f: + img = tfds.core.lazy_imports.tifffile.imread(io.BytesIO(f.read())) + return img.astype(np.uint8) diff --git a/big_vision/datasets/scicap/scicap.py b/big_vision/datasets/scicap/scicap.py new file mode 100644 index 0000000000000000000000000000000000000000..3aa714981785b6c5cda5a6e880160987e2a2e250 --- /dev/null +++ b/big_vision/datasets/scicap/scicap.py @@ -0,0 +1,205 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Creates TFDS dataset for SciCap. 
+ +Preparing the data: + 1) mkdir /tmp/data/scicap && cd /tmp/data/scicap + 2) wget 'https://www.dropbox.com/s/t1sjqesl0pynaxo/scicap_data.zip?dl=0' + 3) unzip -UU 'scicap_data.zip?dl=0' && rm 'scicap_data.zip?dl=0' + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=scicap + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load('scicap', split='train', data_dir='/tmp/tfds') +""" +# pylint: enable=line-too-long +import enum +import functools +import json +import os + +import tensorflow_datasets as tfds + + +_DESCRIPTION = """SciCap dataset.""" +_CITATION = """ +@article{hsu2021scicap, + title={SciCap: Generating captions for scientific figures}, + author={Hsu, Ting-Yao and Giles, C Lee and Huang, Ting-Hao'Kenneth'}, + journal={arXiv preprint arXiv:2110.11624}, + year={2021} +} +""" + +# When running locally (recommended), copy files as above an use these: +_SCICAP_DIR = "/tmp/data/scicap/scicap_data" + + +class ScicapSubset(enum.Enum): + """Versions of the SciCap dataset.""" + SINGLE_SENTENCE = "single_sentence" + FIRST_SENTENCE = "first_sentence" + LEQ_100_TOKENS = "leq_100_tokens" + +_SPLITS_TO_GENERATE = ["train", "test", "val"] +_CONFIG_TO_IDS_PATH = { + (ScicapSubset.SINGLE_SENTENCE, True): "Single-Sentence-Caption/Yes-Subfig", + (ScicapSubset.SINGLE_SENTENCE, False): "Single-Sentence-Caption/No-Subfig", + (ScicapSubset.FIRST_SENTENCE, True): "First-Sentence/Yes-Subfig", + (ScicapSubset.FIRST_SENTENCE, False): "First-Sentence/No-Subfig", + (ScicapSubset.LEQ_100_TOKENS, True): + "Caption-No-More-Than-100-Tokens/Yes-Subfig", + (ScicapSubset.LEQ_100_TOKENS, False): + "Caption-No-More-Than-100-Tokens/No-Subfig", +} +_SUBFIG_TO_PATH = { + True: "SciCap-Yes-Subfig-Img", False: "SciCap-No-Subfig-Img" +} + + +class ScicapConfig(tfds.core.BuilderConfig): + """"Configuration for SciCap caption length and subfigure 
inclusion.""" + + def __init__(self, *, subset: ScicapSubset, subfig: bool, **kwargs): + """Parameters specifying how the dataset will be processed. + + Args: + subset: Subset of the Scicap data (see enum above). + subfig: Whether or not figure with subfigures are included. + **kwargs: Passed on to the constructor of `BuilderConfig`. + """ + super(ScicapConfig, self).__init__(**kwargs) + self.subset = subset + self.subfig = subfig + + +@functools.cache +def _read_annotations(split: str, image_id: str): + """Reads annotations for a single file.""" + path = os.path.join(_SCICAP_DIR, "SciCap-Caption-All", split) + fname = os.path.join(path, image_id + ".json") + with open(fname, "r") as fin: + return json.load(fin) + + +class Scicap(tfds.core.GeneratorBasedBuilder): + """DatasetBuilder for the SciCap dataset.""" + + VERSION = tfds.core.Version("1.0.0") + RELEASE_NOTES = {"1.0.0": "First release."} + + BUILDER_CONFIGS = [ + ScicapConfig( + name="single_sentence_subfig_yes", + description="Single sentence caption with subfigures allowed.", + subset=ScicapSubset.SINGLE_SENTENCE, + subfig=True + ), + ScicapConfig( + name="single_sentence_subfig_no", + description="Single sentence caption with subfigures not allowed.", + subset=ScicapSubset.SINGLE_SENTENCE, + subfig=False + ), + ScicapConfig( + name="first_sentence_subfig_yes", + description="First sentence of captions with subfigures allowed.", + subset=ScicapSubset.FIRST_SENTENCE, + subfig=True + ), + ScicapConfig( + name="first_sentence_subfig_no", + description="First sentence of captions with subfigures not allowed.", + subset=ScicapSubset.FIRST_SENTENCE, + subfig=False + ), + ScicapConfig( + name="leq_100_tokens_subfig_yes", + description="Captions with <= 100 tokens with subfigures allowed.", + subset=ScicapSubset.LEQ_100_TOKENS, + subfig=True + ), + ScicapConfig( + name="leq_100_tokens_subfig_no", + description=("Captions with <= 100 tokens with subfigures" + " not allowed."), + subset=ScicapSubset.LEQ_100_TOKENS, 
+ subfig=False + ), + ] + + def _info(self): + """Returns the metadata.""" + + return tfds.core.DatasetInfo( + builder=self, + description=_DESCRIPTION, + features=tfds.features.FeaturesDict({ + "image/id": tfds.features.Text(), + "image/filename": tfds.features.Text(), + "image": tfds.features.Image(encoding_format="png"), + "caption/originally_extracted": tfds.features.Text(), + "caption/lowercase_and_token_and_remove_figure_index": + tfds.features.Text(), + "caption/normalized/basic_num": tfds.features.Text(), + "caption/normalized/advanced_equation_bracket": + tfds.features.Text(), + }), + supervised_keys=None, + homepage="https://github.com/tingyaohsu/SciCap", + citation=_CITATION, + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + """Returns SplitGenerators.""" + return {split: self._generate_examples(split) + for split in _SPLITS_TO_GENERATE} + + def _generate_examples(self, split: str): + """Yields (key, example) tuples from test set.""" + config_path = _CONFIG_TO_IDS_PATH[ + (self.builder_config.subset, self.builder_config.subfig)] + image_path = os.path.join( + _SCICAP_DIR, _SUBFIG_TO_PATH[self.builder_config.subfig], split) + id_list_fname = os.path.join( + _SCICAP_DIR, "List-of-Files-for-Each-Experiments", + config_path, split, "file_idx.json") + with open(id_list_fname, "r") as fin: + split_images = json.load(fin) + + for fname in split_images: + assert fname.endswith(".png") + image_id = fname[:-len(".png")] + annotations = _read_annotations(split, image_id) + yield fname, { + "image/id": image_id, + "image/filename": fname, + "image": os.path.join(image_path, fname), + "caption/originally_extracted": annotations["0-originally-extracted"], + "caption/lowercase_and_token_and_remove_figure_index": + annotations["1-lowercase-and-token-and-remove-figure-index"][ + "caption"], + "caption/normalized/basic_num": annotations["2-normalized"][ + "2-1-basic-num"]["caption"], + "caption/normalized/advanced_equation_bracket": + 
annotations["2-normalized"][ + "2-2-advanced-euqation-bracket"]["caption"] + } diff --git a/big_vision/datasets/science_qa/science_qa.py b/big_vision/datasets/science_qa/science_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb8189e57fbe9352d4b82044e8d7e44576052f5 --- /dev/null +++ b/big_vision/datasets/science_qa/science_qa.py @@ -0,0 +1,156 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Implements ScienceQA train/val/test-set in TFDS structure. 
# pylint: disable=line-too-long
r"""Implements ScienceQA train/val/test-set in TFDS structure.

First, download the science QA dataset from their website https://scienceqa.github.io/#download
  - mkdir -p /tmp/data/ScienceQA_DATA
  - From Google Drive: https://drive.google.com/corp/drive/folders/1w8imCXWYn2LxajmGeGH_g5DaL2rabHev
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
  - cd big_vision/datasets
  - env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=science_qa

Example to load:

  import tensorflow_datasets as tfds
  dataset = tfds.load(
      'science_qa', split='train',
      data_dir='/tmp/tfds')

"""
import json
import os

import numpy as np
import tensorflow_datasets as tfds


_DESCRIPTION = """Sci QA test-set."""

# pylint: disable=line-too-long
_CITATION = """
@inproceedings{lu2022learn,
    title={Learn to Explain: Multimodal Reasoning via Thought Chains for Science Question Answering},
    author={Lu, Pan and Mishra, Swaroop and Xia, Tony and Qiu, Liang and Chang, Kai-Wei and Zhu, Song-Chun and Tafjord, Oyvind and Clark, Peter and Ashwin Kalyan},
    booktitle={The 36th Conference on Neural Information Processing Systems (NeurIPS)},
    year={2022}
}
"""
# pylint: enable=line-too-long

# When running locally (recommended), copy files as above and use these:
_SCIQA_PATH = '/tmp/data/ScienceQA_DATA/'
# _IMAGE_COCO_PATH = '/tmp/data/val2014'

# Letters used to label the answer choices: choice 0 -> 'A', 1 -> 'B', ...
_ALPHABETS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'


class ScienceQA(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for ScienceQA dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'First release.'}

  def _info(self):
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question': tfds.features.Text(),
            'choices': tfds.features.Sequence(tfds.features.Text()),
            'answer': tfds.features.Scalar(np.int32),
            'hint': tfds.features.Text(),
            'task': tfds.features.Text(),
            'grade': tfds.features.Text(),
            'subject': tfds.features.Text(),
            'topic': tfds.features.Text(),
            'category': tfds.features.Text(),
            'skill': tfds.features.Text(),
            'lecture': tfds.features.Text(),
            'solution': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='png'),
            'indexed_choices': tfds.features.Text(),
            'indexed_answer': tfds.features.Text(),
        }),
        supervised_keys=None,
        homepage='https://github.com/lupantech/ScienceQA/tree/main',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {
        split: self._generate_examples(split)
        for split in ('train', 'test', 'val')
    }

  def _generate_examples(self, split):
    """Yields (key, example) tuples for the requested split.

    Reads `problems.json` from `_SCIQA_PATH` and emits every problem whose
    `split` field equals `split` and which has an image attached.

    Args:
      split: name of the split to generate ('train', 'val' or 'test').

    Yields:
      (int problem id, feature dict) tuples matching the features in `_info`.
    """
    annot_fname = os.path.join(_SCIQA_PATH, 'problems.json')

    with open(annot_fname, 'r') as f:
      data = json.load(f)

    for k, v in data.items():
      if v['split'] != split:
        continue
      # Science QA contains examples without an image as well. As this
      # conversion is for VQA tasks, we drop the examples without image.
      # TODO: Include the examples without image, and update the
      # downstream pipeline to skip the examples without image, instead of
      # doing it at pre-processing.
      if not v['image']:
        continue
      image = os.path.join(_SCIQA_PATH, split, k, v['image'])
      choices = v['choices']
      answer = v['answer']
      yield int(k), {
          'question': v['question'],
          'choices': choices,
          'answer': answer,
          # 'N/A' aligns with the original github implementation.
          'hint': v['hint'] or 'N/A',
          'task': v['task'],
          'grade': v['grade'],
          'subject': v['subject'],
          'topic': v['topic'],
          'category': v['category'],
          'skill': v['skill'],
          'lecture': v['lecture'],
          'solution': v['solution'],
          'image': image,
          # E.g. "(A) first choice, (B) second choice".
          'indexed_choices': ', '.join(
              f'({_ALPHABETS[i]}) {c}' for i, c in enumerate(choices)
          ),
          'indexed_answer': _ALPHABETS[int(answer)],
      }
# pylint: disable=line-too-long
r"""Creates TFDS dataset for Screen2words.


Preparing the data:
  1) mkdir /tmp/data/rico && cd /tmp/data/rico
  2) wget https://storage.googleapis.com/crowdstf-rico-uiuc-4540/rico_dataset_v0.1/unique_uis.tar.gz
  3) tar xvfz unique_uis.tar.gz && rm unique_uis.tar.gz
  4) git clone https://github.com/google-research-datasets/screen2words.git

Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):

  cd big_vision/datasets
  env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=screen2words

Example to load:

  import tensorflow_datasets as tfds
  dataset = tfds.load('screen2_words', split='train', data_dir='/tmp/tfds')
"""
# pylint: enable=line-too-long
import collections
import csv
import os

import numpy as np
import tensorflow_datasets as tfds


_DESCRIPTION = """Screen2words dataset."""
_CITATION = """
@inproceedings{wang2021screen2words,
  title={Screen2words: Automatic mobile UI summarization with multimodal
         learning},
  author={Wang, Bryan and
          Li, Gang and
          Zhou, Xin and
          Chen, Zhourong and
          Grossman, Tovi and
          Li, Yang},
  booktitle={The 34th Annual ACM Symposium on User Interface Software
             and Technology},
  pages={498--510},
  year={2021}
}
"""

# When running locally (recommended), copy files as above and use these:
_SCREEN2WORDS_DIR = "/tmp/data/rico/screen2words"
_RICO_DIR = "/tmp/data/rico/combined"


# Names of the splits to be generated.
_SPLITS_TO_GENERATE = ["train", "dev", "test"]


class Screen2Words(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the Screen2words dataset."""

  VERSION = tfds.core.Version("1.0.0")
  RELEASE_NOTES = {"1.0.0": "First release."}

  def _info(self):
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            "image/id": tfds.features.Scalar(np.int32),
            "image/filename": tfds.features.Text(),
            "image": tfds.features.Image(encoding_format="jpeg"),
            "summary": tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,
        homepage="https://github.com/google-research-datasets/screen2words",
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {split: self._generate_examples(split)
            for split in _SPLITS_TO_GENERATE}

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples for the requested split.

    Args:
      split: one of `_SPLITS_TO_GENERATE`; selects
        `split/{split}_screens.txt` inside `_SCREEN2WORDS_DIR`.

    Yields:
      (screen id, feature dict) tuples. `summary` holds every summary listed
      for that screen in `screen_summaries.csv` (empty if there is none).
    """
    id_list_fname = os.path.join(
        _SCREEN2WORDS_DIR, "split", f"{split}_screens.txt")
    with open(id_list_fname, "r") as fin:
      # Skip blank lines (e.g. a trailing empty line); int("") would raise.
      split_ids = [int(line) for line in map(str.strip, fin) if line]

    summaries_fname = os.path.join(_SCREEN2WORDS_DIR, "screen_summaries.csv")
    summaries = collections.defaultdict(list)
    with open(summaries_fname, "r") as fin:
      for entry in csv.DictReader(fin):
        summaries[int(entry["screenId"])].append(entry["summary"])

    for image_id in split_ids:
      filename = f"{image_id}.jpg"
      yield image_id, {
          "image/id": image_id,
          "image/filename": filename,
          "image": os.path.join(_RICO_DIR, filename),
          "summary": summaries[image_id],
      }
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Packed Sequence Op."""

# Forked from
# https://github.com/google/maxtext/blob/main/MaxText/sequence_packing.py.


from typing import Dict, Optional, List, Union

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE


def pack_dataset(dataset: tf.data.Dataset,
                 key2length: Union[int, Dict[str, int]],
                 keys: Optional[List[str]] = None) -> tf.data.Dataset:
  """Creates a 'packed' version of a dataset on-the-fly.

  Adapted from the mesh-tf implementation.
  This is meant to replace the irritation of having to create a separate
  "packed" version of a dataset to train efficiently on TPU.
  Each example in the output dataset represents several examples in the
  input dataset.
  For each key in the input dataset, two additional keys are created:
  <key>_seg: an int32 tensor identifying the parts
     representing the original example.
  <key>_pos: an int32 tensor identifying the position within the original
     example.
  Example:
  Two input examples get combined to form an output example.
  The input examples are:
  {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]}
  {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]}
  The output example is:
  {
                 "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
             "inputs_seg": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
             "inputs_pos": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
                "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
            "targets_seg": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
            "targets_pos": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
  }
  0 represents padding in both the inputs and the outputs.
  Sequences in the incoming examples are truncated to length "length", and the
  sequences in the output examples all have fixed (padded) length "length".

  Args:
    dataset: a tf.data.Dataset
    key2length: an integer, or a dict from feature-key to integer. An integer
      is applied to every key in `keys`.
    keys: a list of strings (e.g. ["inputs", "targets"]). Defaults to all keys
      of the dataset's element spec.

  Returns:
    a tf.data.Dataset

  Raises:
    ValueError: if a key is missing from the dataset, or its tensor is not
      one-dimensional.
  """
  shapes = tf.nest.map_structure(lambda spec: spec.shape, dataset.element_spec)
  if keys is None:
    keys = list(shapes.keys())
  for k in keys:
    if k not in shapes:
      raise ValueError(f"""Key {k} not found in dataset. Available keys are
      {shapes.keys()}""")
    if not shapes[k].is_compatible_with(tf.TensorShape([None])):
      raise ValueError('Tensors to be packed must be one-dimensional.')
  # make sure that the length dictionary contains all keys as well as the
  # keys suffixed by "_seg" and "_pos"
  if isinstance(key2length, int):
    key2length = {k: key2length for k in keys}
  else:
    key2length = dict(key2length)  # Make new dict, we'll edit in-place.
  for k in keys:
    for suffix in ['_seg', '_pos']:
      key2length[k + suffix] = key2length[k]

  # trim to length
  dataset = dataset.map(
      lambda x: {k: x[k][:key2length[k]] for k in keys},
      num_parallel_calls=AUTOTUNE)
  # Setting batch_size=length ensures that the concatenated sequences (if they
  # have length >=1) are sufficient to fill at least one packed example.
  batch_size = max(key2length.values())
  dataset = dataset.padded_batch(
      batch_size, padded_shapes={k: [-1] for k in keys})
  dataset = _pack_with_tf_ops(dataset, keys, key2length)

  # Set the Tensor shapes correctly since they get lost in the process.
  def my_fn(x):
    return {k: tf.reshape(v, [key2length[k]]) for k, v in x.items()}

  return dataset.map(my_fn, num_parallel_calls=AUTOTUNE)


def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str],
                      key2length: Dict[str, int]) -> tf.data.Dataset:
  """Helper-function for packing a dataset which has already been batched.

  Helper for pack_dataset(). Uses tf.while_loop.

  Args:
    dataset: a dataset containing padded batches of examples.
    keys: a list of strings
    key2length: a dict from feature-key to integer; must also contain the
      "<key>_pos" entries, which are read when padding the outputs.

  Returns:
    a dataset of packed examples with keys <key>, <key>_pos and <key>_seg.
  """
  # Zero-length int32 tensors: the starting state of a packed example.
  empty_example = {}
  for k in keys:
    empty_example[k] = tf.zeros([0], dtype=tf.int32)
    empty_example[k + '_pos'] = tf.zeros([0], dtype=tf.int32)
  keys_etc = empty_example.keys()

  def write_packed_example(partial, outputs):
    """Pads `partial` to full length, appends it to `outputs`, resets partial."""
    new_partial = empty_example.copy()
    new_outputs = {}
    for k in keys_etc:
      new_outputs[k] = outputs[k].write(
          outputs[k].size(),
          tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]))
    return new_partial, new_outputs

  def map_fn(x):
    """Internal function to map over the batched dataset.

    Consumes a batch of input examples and produces a variable number of output
    examples.

    Args:
      x: a padded batch of examples (dict of tensors indexed as x[k][i]).

    Returns:
      a dict of stacked packed tensors; unbatched by the caller.
    """
    partial = empty_example.copy()
    i = tf.zeros([], dtype=tf.int32)
    dynamic_batch_size = tf.shape(x[keys[0]])[0]
    outputs = {}
    for k in keys:
      outputs[k] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])
      outputs[k + '_pos'] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])

    def body_fn(i, partial, outputs):
      """Body function for while_loop.

      Args:
        i: integer scalar
        partial: dictionary of Tensor (partially-constructed example)
        outputs: dictionary of TensorArray

      Returns:
        A triple containing the new values of the inputs.
      """
      can_append = True
      one_example = {}
      for k in keys:
        val = tf.cast(x[k][i], tf.int32)
        # Keep as many leading elements as there are non-zero entries.
        # NOTE(review): this strips the batch padding only if zeros occur
        # solely at the end of a sequence - confirm for the tokenizer in use.
        val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
        one_example[k] = val
      # The example can be appended only if it fits for every key.
      for k in keys:
        can_append = tf.logical_and(
            can_append,
            tf.less_equal(
                tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]))

      def false_fn():
        # Does not fit: flush the current partial example and start a new one.
        return write_packed_example(partial, outputs)

      def true_fn():
        return partial, outputs

      partial, outputs = tf.cond(can_append, true_fn, false_fn)
      new_partial = {}
      for k in keys:
        new_seq = one_example[k][:key2length[k]]
        new_seq_len = tf.size(new_seq)
        new_partial[k] = tf.concat([partial[k], new_seq], 0)
        new_partial[k + '_pos'] = tf.concat(
            [partial[k + '_pos'],
             tf.range(new_seq_len)], 0)
      partial = new_partial
      return i + 1, partial, outputs

    # For loop over all examples in the batch.
    i, partial, outputs = tf.while_loop(
        cond=lambda *_: True,
        body=body_fn,
        loop_vars=(i, partial, outputs),
        shape_invariants=(
            tf.TensorShape([]),
            {k: tf.TensorShape([None]) for k in keys_etc},
            {k: tf.TensorShape(None) for k in keys_etc},
        ),
        maximum_iterations=dynamic_batch_size)
    # Flush whatever is left in the last partial example.
    _, outputs = write_packed_example(partial, outputs)
    packed = {k: outputs[k].stack() for k in keys_etc}
    for k in keys:
      # A position of 0 marks a segment start, so the running count of zeros
      # in "_pos" is the segment id; the second factor zeroes out positions
      # where the packed value itself is 0 (padding).
      packed[k + '_seg'] = (
          tf.cumsum(
              tf.cast(tf.equal(packed[k + '_pos'], 0), tf.int32), axis=1) *
          tf.cast(tf.not_equal(packed[k], 0), tf.int32))
    return packed

  dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
  return dataset.unbatch()
# pylint: disable=line-too-long
r"""Implements ST-VQA dataset in TFDS.

It's small data, so simple to run locally.
First, download and unzip the dataset from https://rrc.cvc.uab.es/?ch=11
and place it in /tmp/data/stvqa.

Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):

  cd big_vision/datasets
  env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=stvqa

Example to load:

  import tensorflow_datasets as tfds
  dataset = tfds.load('stvqa', split='train', data_dir='/tmp/tfds')

Dataset splits:
  train: 23446 examples/questions (subset of original train)
  val: 2628 examples/questions (subset of original train)
  test: 4070 examples/questions (no answers)

Note: original source data has no val/holdout split, and we therefore split the
original train split (26074 examples/questions) by ourselves into train & val
splits.

Recommended training splits:
  train: train
  minitrain: train[:5%]
  eval: val
  fulltrain: train+val
"""
import json
import os

from big_vision.datasets.stvqa import val_ids
import numpy as np
import tensorflow_datasets as tfds

# Image paths held out of the original train split to form the `val` split.
# Stored as a frozenset so the per-example membership test is O(1).
_VAL_IDS = frozenset(val_ids.PSEUDO_VAL_IMAGE_PATHS)

_DESCRIPTION = """ST-VQA dataset."""

# pylint: disable=line-too-long
_CITATION = """
@inproceedings{Biten_2019,
   title={Scene Text Visual Question Answering},
   url={http://dx.doi.org/10.1109/ICCV.2019.00439},
   DOI={10.1109/iccv.2019.00439},
   booktitle={2019 IEEE/CVF International Conference on Computer Vision (ICCV)},
   publisher={IEEE},
   author={Biten, Ali Furkan and Tito, Ruben and Mafla, Andres and Gomez, Lluis and Rusinol, Marcal and Jawahar, C.V. and Valveny, Ernest and Karatzas, Dimosthenis},
   year={2019},
   month=oct }
"""
# pylint: enable=line-too-long

# When running locally (recommended), copy files as above and use these:
_STVQA_PATH = '/tmp/data/stvqa/'


class Stvqa(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for ST-VQA dataset."""

  VERSION = tfds.core.Version('1.2.0')
  RELEASE_NOTES = {
      '1.0.0': 'First release.',
      '1.1.0': 'Switch to COCO high-res images and lower-case answers.',
      '1.2.0': 'Rename pseudo splits and remove lower-case answers.',
  }

  def _info(self):
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question_id': tfds.features.Scalar(np.int32),
            'filename': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'question': tfds.features.Text(),
            'answers': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,
        homepage='https://rrc.cvc.uab.es/?ch=11',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {split: self._generate_examples(split)
            for split in ('train', 'val', 'test')}

  def _generate_examples(self, split):
    """Yields (key, example) tuples.

    `train` and `val` are both read from the source train annotations and
    separated by whether the image path is in `_VAL_IDS`; `test` is read from
    the source test annotations.

    Args:
      split: 'train', 'val' or 'test'.

    Yields:
      (question_id, feature dict) tuples. `answers` is an empty list when the
      source entry has no answers.
    """
    src_split = 'test' if split == 'test' else 'train'
    annot_fname = os.path.join(_STVQA_PATH, f'{src_split}_task_3.json')
    images_path = (
        'test_task3_images' if src_split == 'test' else 'train_images')

    with open(annot_fname, 'r') as f:
      data = json.load(f)

    for x in data['data']:
      file_path = x['file_path']
      in_val = file_path in _VAL_IDS
      if split == 'val' and not in_val:
        continue
      if split == 'train' and in_val:
        continue
      image_path = os.path.join(_STVQA_PATH, images_path, file_path)
      # Always use high-res COCO images from train2014 directory.
      if file_path.startswith('coco-text'):
        image_path = image_path.replace(
            os.path.join(images_path, 'coco-text'), 'train2014')
      yield x['question_id'], {
          'question_id': x['question_id'],
          'filename': file_path,
          'image': image_path,
          'question': x['question'],
          'answers': x.get('answers', []),
      }
"""Import TallyQA into TFDS format. Uses Visual Genome and COCO images.

It's small data, so simple to run locally. First, download all the data:

  mkdir /tmp/data/ ; cd /tmp/data
  wget http://images.cocodataset.org/zips/{train2014,val2014}.zip
  wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip
  wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip
  wget https://github.com/manoja328/tallyqa/blob/master/tallyqa.zip?raw=true
  unzip '*.zip'

Then, update the PATHs below and run conversion locally like so (make sure to
install tensorflow-datasets for the `tfds` util):

  cd big_vision/datasets
  env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=tallyqa

Example to load:
  import tensorflow_datasets as tfds
  dataset = tfds.load('tallyqa', split='train', data_dir='/tmp/tfds')

The test split distinguishes between simple and complex questions. The train
split does not contain this information. We therefore set issimple to `-1` in
the train split to indicate it is not known.
"""

import json

import numpy as np
import tensorflow_datasets as tfds


_TALLYQA_PATH = '/tmp/data/tallyQA/'
_VISUAL_GENOME_PATH = '/tmp/data/visual_genome/'

_COCO_PATH = '/tmp/data/coco/'


_DESCRIPTION = """
TallyQA: Answering Complex Counting Questions
Most counting questions in visual question answering (VQA) datasets are simple
and require no more than object detection. Here, we study algorithms for complex
counting questions that involve relationships between objects, attribute
identification, reasoning, and more. To do this, we created TallyQA, the world's
largest dataset for open-ended counting.
"""

_CITATION = """
@inproceedings{acharya2019tallyqa,
  title={TallyQA: Answering Complex Counting Questions},
  author={Acharya, Manoj and Kafle, Kushal and Kanan, Christopher},
  booktitle={AAAI},
  year={2019}
}
"""

_HOMEPAGE = 'https://github.com/manoja328/TallyQA_dataset'


class TallyQA(tfds.core.GeneratorBasedBuilder):
  """Import TallyQA dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}
  MANUAL_DOWNLOAD_INSTRUCTIONS = """
  There are three parts which should be downloaded:
  * TallyQA (train / test json files)
  * Visual Genome images (needed for train and test split)
  * COCO (2014) train / val images (only needed for train split)
  """

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata."""
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({
            'image': tfds.features.Image(shape=(None, None, 3)),
            'image_id': tfds.features.Scalar(dtype=np.int32),
            'image_source': tfds.features.Text(),
            'question': tfds.features.Text(),
            'question_id': tfds.features.Scalar(dtype=np.int32),
            'answer': tfds.features.Scalar(dtype=np.int32),
            'issimple': tfds.features.Scalar(dtype=np.int32),
        }),
        description=_DESCRIPTION,
        supervised_keys=None,
        homepage=_HOMEPAGE,
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager) -> ...:
    """Maps each split name to its example generator."""
    del dl_manager
    return {
        name: self._generate_examples(split=name)
        for name in ('train', 'test')
    }

  def _generate_examples(self, split: str) -> ...:
    """Yields (key, example) tuples read from `{split}.json`.

    Args:
      split: 'train' or 'test'.

    Yields:
      ("<image_id> / <question_id>", feature dict) tuples.

    Raises:
      ValueError: if an image path has none of the known source prefixes.
    """
    with open(f'{_TALLYQA_PATH}/{split}.json', 'r') as f:
      entries = json.load(f)

    # `issimple` is only present in the test split.
    is_test = split == 'test'

    for entry in entries:
      # The TallyQA images come from two sources: Visual Genome and COCO.
      # The prefix of the relative path tells which one.
      rel_path = entry['image']
      if rel_path.startswith('VG_100K'):
        image_path = _VISUAL_GENOME_PATH + rel_path
      elif rel_path.startswith(('train2014', 'val2014')):
        image_path = _COCO_PATH + rel_path
      else:
        raise ValueError(f'Unknown image path: {rel_path}')

      example = {
          'image': image_path,
          'image_id': entry['image_id'],
          'image_source': entry['data_source'],
          'question': entry['question'],
          'question_id': entry['question_id'],
          'answer': int(entry['answer']),
          # -1 marks "not known" in the train split.
          'issimple': entry['issimple'] if is_test else -1,
      }
      key = f'{example["image_id"]} / {example["question_id"]}'
      yield key, example
# pylint: disable=line-too-long
r"""Implements textcaps train/val/test sets in TFDS structure.

It's small data, so simple to run locally. First, copy the data to local disk:

  mkdir -p /tmp/data/textcaps
  cd /tmp/data/textcaps
  curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_train.json
  curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_val.json
  curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_test.json
  curl -O https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
  curl -O https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
  unzip train_val_images.zip
  rm train_val_images.zip
  unzip test_images.zip
  rm test_images.zip

Then, run conversion locally (make sure to install tensorflow-datasets for the
`tfds` util):

  cd big_vision/datasets
  env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=textcaps


Example to load:

  import tensorflow_datasets as tfds
  dataset = tfds.load('text_caps', split='val', data_dir='/tmp/tfds')
"""
import collections
import json
import os

from absl import logging
import numpy as np
import tensorflow_datasets as tfds


_DESCRIPTION = """TextCaps dataset."""

# pylint: disable=line-too-long
_CITATION = (
    '@inproceedings{sidorov2019textcaps,'
    'title={TextCaps: a Dataset for Image Captioningwith Reading Comprehension},'
    'author={Sidorov, Oleksii and Hu, Ronghang and Rohrbach, Marcus and Singh, Amanpreet},'
    'journal={European Conference on Computer Vision},'
    'year={2020}}')
# pylint: enable=line-too-long

# When running locally (recommended), copy files as above and use these:
_FILEPATH = '/tmp/data/textcaps/'
_TRAIN_FILES = '/tmp/data/textcaps/TextCaps_0.1_train.json'
_VAL_FILES = '/tmp/data/textcaps/TextCaps_0.1_val.json'
_TEST_FILES = '/tmp/data/textcaps/TextCaps_0.1_test.json'


class TextCaps(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for TextCaps dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata.

    (tfds.core.DatasetInfo object)
    These are the features of your dataset like images, labels, etc.
    """

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'image/id': tfds.features.Text(),
            'image_filepath': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'texts': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,  # Set to `None` to disable
        homepage='https://textvqa.org/textcaps/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    def group_by_id(data, image_dir):
      """Merges the per-caption entries of `data` into one example per image.

      Args:
        data: list of annotation dicts with `image_id`, `image_name` and
          optionally `caption_str`.
        image_dir: directory below `_FILEPATH` that holds the images.

      Returns:
        A dict from image id to example dict; `texts` collects the non-empty
        captions of that image.
      """
      grouped = collections.defaultdict(list)
      for ex in data:
        grouped[ex['image_id']].append(ex)

      # Build a fresh dict instead of overwriting `grouped` while iterating it.
      examples = {}
      for image_id, exs in grouped.items():
        image_names = {ex['image_name'] for ex in exs}
        # All entries of one image id must refer to the same image file.
        assert len(image_names) == 1
        image_filepath = os.path.join(
            _FILEPATH, image_dir, str(exs[0]['image_name'])+'.jpg')
        examples[image_id] = {
            'image/id': image_id,
            'image_filepath': image_filepath,
            'image': image_filepath,
            'texts': [
                ex['caption_str'] for ex in exs if ex.get('caption_str')
            ],
        }
      return examples

    # Returns the Dict[split names, Iterator[Key, Example]]
    # Train and val images share the `train_images` directory.
    with open(_TRAIN_FILES) as f:
      train_data = group_by_id(json.load(f)['data'], 'train_images')
    with open(_VAL_FILES) as f:
      val_data = group_by_id(json.load(f)['data'], 'train_images')
    with open(_TEST_FILES) as f:
      test_data = group_by_id(json.load(f)['data'], 'test_images')
    return {
        'train': self._generate_examples(train_data),
        'val': self._generate_examples(val_data),
        'test': self._generate_examples(test_data),
    }

  def _generate_examples(self, data):
    """Yields the pre-built examples of one split.

    Args:
      data: a dictionary from image id to example dict, as built in
        `_split_generators`.

    Yields:
      (key, example) tuples from dataset. The example has format specified in
      the above DatasetInfo.
    """
    yield from data.items()
+ # curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_train.json + # curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_val.json + # curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_test.json + unzip train_val_images.zip + rm train_val_images.zip + unzip test_images.zip + rm test_images.zip + # Background: at https://textvqa.org/dataset/ it says: + # "Note: Some of the images in OpenImages are rotated, + # please make sure to check the Rotation field in the Image IDs files + # for train and test." + curl -O https://storage.googleapis.com/openimages/2018_04/train/train-images-boxable-with-rotation.csv + curl -O https://storage.googleapis.com/openimages/2018_04/test/test-images-with-rotation.csv + mv train-images-boxable-with-rotation.csv train_images/rotation.csv + mv test-images-with-rotation.csv test_images/rotation.csv + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=textvqa + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load('textvqa', split='train', data_dir='/tmp/tfds') +""" +import json +import os + +from absl import logging +import numpy as np +import pandas as pd +import tensorflow as tf +import tensorflow_datasets as tfds + + +_DESCRIPTION = """TextVqa dataset.""" + +# pylint: disable=line-too-long +_CITATION = ( + '@inproceedings{singh2019towards,' + 'title={Towards VQA Models That Can Read},' + 'author={Singh, Amanpreet and Natarjan, Vivek and Shah, Meet and Jiang, Yu and Chen, Xinlei and Parikh, Devi and Rohrbach, Marcus},' + 'booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},' + 'pages={8317-8326},' + 'year={2019}}' + ) +# pylint: enable=line-too-long + +# When running locally (recommended), copy files as above and use these: +_FILEPATH = '/tmp/data/textvqa/' +_TRAIN_FILES = 
class TextVqa(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for textvqa dataset.

  Every question is one example. Images that the per-folder rotation CSV lists
  with a non-zero rotation are rotated and re-encoded while the examples are
  generated; all other images are passed along as their original bytes.
  """

  VERSION = tfds.core.Version('1.0.1')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
      '1.0.1': 'Undo rotation for known rotated images.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata.

    (tfds.core.DatasetInfo object)
    These are the features of your dataset like images, labels, etc.
    """

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'image/id': tfds.features.Scalar(np.int32),
            'image_filepath': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'question_id': tfds.features.Scalar(np.int32),
            'question': tfds.features.Text(),
            'answers': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,  # Set to `None` to disable
        homepage='https://textvqa.org/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns a dict mapping split name to a (key, example) generator.

    Args:
      dl_manager: unused; annotations and images are read from local paths.
    """

    def json_to_examples(data, image_dir):
      """Builds {question_id: example}, each carrying the image's rotation."""
      logging.info('Processing %d items in %s', len(data), image_dir)
      rot = pd.read_csv(os.path.join(_FILEPATH, image_dir, _ROTATION_CSV))
      # A missing (NaN) rotation entry is treated as "not rotated".
      rotation_by_id = {
          row.ImageID: 0 if np.isnan(row.Rotation) else int(row.Rotation)
          for row in rot.itertuples()
      }

      examples = {}
      for v in data:
        image_id = str(v['image_id'])
        image_filepath = os.path.join(_FILEPATH, image_dir, image_id + '.jpg')
        question_id = v['question_id']
        examples[question_id] = {
            # NOTE(review): 'image/id' is filled with the question id, not the
            # image id used for the file name -- confirm this is intended.
            'image/id': question_id,
            'image_filepath': image_filepath,
            'image': image_filepath,
            # Consumed (and removed) by `_generate_examples`.
            'rotation': rotation_by_id[image_id],
            'question_id': question_id,
            'question': v['question'],
            'answers': v.get('answers', []),  # No answers in test set.
        }
      return examples

    # Returns the Dict[split names, Iterator[Key, Example]]
    with open(_TRAIN_FILES) as f:
      train_data = json_to_examples(json.load(f)['data'], 'train_images')
    with open(_VAL_FILES) as f:
      # Validation images are stored in the train_images folder.
      val_data = json_to_examples(json.load(f)['data'], 'train_images')
    with open(_TEST_FILES) as f:
      test_data = json_to_examples(json.load(f)['data'], 'test_images')
    return {
        'train': self._generate_examples(train_data),
        'val': self._generate_examples(val_data),
        'test': self._generate_examples(test_data),
    }

  def _generate_examples(self, data):
    """Yields examples with the image bytes loaded and rotation undone.

    Args:
      data: a dict mapping a question id to an example dict that additionally
        holds a 'rotation' entry (in degrees, a multiple of 90).

    Yields:
      (key, example) tuples from dataset. The example has format specified in
      the above DatasetInfo; the 'rotation' entry is not passed on. The input
      dicts are left untouched.
    """
    for k, v in data.items():
      # Close the file right away instead of leaking one handle per example.
      with open(v['image_filepath'], 'rb') as f:
        image_bytes = f.read()

      rotation = v['rotation']
      if rotation != 0:
        # If the image is rotated, we undo the rotation here and re-encode.
        assert rotation % 90 == 0
        turns = rotation // 90
        image = tf.image.decode_jpeg(image_bytes)
        image_bytes = tf.io.encode_jpeg(
            tf.image.rot90(image, turns), quality=100
        ).numpy()
      # If no rotation was needed, we just pass along the unchanged bytes.

      # Now all rotation has been accounted for, so the (now obsolete)
      # rotation info is dropped rather than passed on as a feature.
      example = {name: value for name, value in v.items() if name != 'rotation'}
      example['image'] = image_bytes

      yield k, example
class DataSource(ds_core.DataSource):
  """Use TFDS as a data source."""

  def __init__(self, name, split, data_dir=None, skip_decode=("image",)):
    """Looks up the builder and picks this process' share of `split`.

    Args:
      name: TFDS dataset name, or "from_data_dir" to load from `data_dir`.
      split: the split string covering all processes.
      data_dir: optional directory handed to TFDS.
      skip_decode: names of features whose decoding is skipped.
    """
    self.split = split
    self.skip_decode = skip_decode
    self.builder = _get_builder(name, data_dir)
    # Each host is responsible for a fixed subset of data
    per_process = tfds.even_splits(split, jax.process_count())
    self.process_split = per_process[jax.process_index()]

  @overrides.overrides
  def get_tfdata(
      self, ordered=False, *, process_split=True, allow_cache=True, **kw):
    """Returns the tf.data for this source.

    Args:
      ordered: when True, files are not shuffled.
      process_split: when True use only this process' sub-split, otherwise the
        full split.
      allow_cache: when True the dataset is obtained through the cached
        getter, so identical requests share one dataset object.
      **kw: forwarded to the dataset getter.
    """
    # The tf.data may use a lot of RAM, so we need to expose the option of not
    # keeping this in memory when we use lots of input pipelines, such as when
    # having many ephemeral evaluators.
    getter = _cached_get_dataset if allow_cache else _get_dataset
    chosen_split = self.process_split if process_split else self.split
    return getter(
        self.builder,
        self.skip_decode,
        split=chosen_split,
        shuffle_files=not ordered,
        **kw)

  @property
  @overrides.overrides
  def total_examples(self):
    """Number of examples in the full (all-process) split."""
    split_info = self.builder.info.splits[self.split]
    return split_info.num_examples

  @overrides.overrides
  def num_examples_per_process(self):
    """Returns a list with the number of examples of each process' sub-split."""
    split_infos = self.builder.info.splits
    return [
        split_infos[sub_split].num_examples
        for sub_split in tfds.even_splits(self.split, jax.process_count())
    ]


@functools.cache
def _get_builder(dataset, data_dir):
  """Returns the (cached) TFDS builder for `dataset`."""
  if dataset == "from_data_dir":
    return tfds.builder_from_directory(data_dir)
  return tfds.builder(dataset, data_dir=data_dir, try_gcs=True)
def _get_dataset(builder, skip_decode, **kw):
  """Returns a tf.data to be used.

  Args:
    builder: the TFDS builder to read from.
    skip_decode: names of features to leave undecoded; names that are not
      features of the dataset are ignored.
    **kw: forwarded to `builder.as_dataset`, except for `shuffle_seed` which
      is moved into the read config.

  Returns:
    The dataset, where every example additionally carries an int32 "_id"
    derived from a hash of its "tfds_id".
  """
  read_config_kw = {}
  if "shuffle_seed" in kw:
    read_config_kw["shuffle_seed"] = kw.pop("shuffle_seed")

  read_config = tfds.ReadConfig(
      skip_prefetch=True,  # We prefetch after pipeline.
      try_autocache=False,  # We control this, esp. for few-shot.
      add_tfds_id=True,
      **read_config_kw,
  )
  known_features = builder.info.features
  decoders = {
      feature: tfds.decode.SkipDecoding()
      for feature in skip_decode
      if feature in known_features
  }
  dataset = builder.as_dataset(
      read_config=read_config, decoders=decoders, **kw)

  def _hash_tfds_id(example):
    hashed = tf.strings.to_hash_bucket_strong(
        example["tfds_id"],
        np.iinfo(np.uint32).max,  # Max value
        [3714561454027272724, 8800639020734831960])  # Magic.
    example["_id"] = tf.bitcast(hashed, tf.int32)[0]  # good device dtype.
    return example

  return dataset.map(_hash_tfds_id)


# Cached variant: building the dataset may well take 1-2min on large datasets,
# and the same one may be requested multiple times (eg various evaluators).
_cached_get_dataset = functools.cache(_get_dataset)
class VizWizVQA(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for VizWizVQA dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'First release.'}

  def _info(self):
    """Returns the metadata."""
    text = tfds.features.Text
    features = tfds.features.FeaturesDict({
        'question': text(),
        'image/filename': text(),
        'image': tfds.features.Image(encoding_format='jpeg'),
        'answers': tfds.features.Sequence(text()),
        # can be "yes" "no" and "maybe" strings
        'answer_confidences': tfds.features.Sequence(text()),
        'answerable': tfds.features.Scalar(np.int32),
        'question_id': tfds.features.Scalar(np.int32),
    })
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=features,
        supervised_keys=None,
        homepage='https://vizwiz.org/tasks-and-datasets/vqa/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    generators = {}
    for split in ('val', 'train', 'test',):
      generators[split] = self._generate_examples(split)
    return generators

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples for the given split.

    Args:
      split: 'train', 'val' or 'test'; selects both the annotation file and
        the image folder below `_VIZWIZVQA_PATH`.

    Yields:
      (image file name, example) tuples. For the 'test' split the answer
      lists are empty and 'answerable' is -1.
    """
    annot_fname = os.path.join(_VIZWIZVQA_PATH, 'annotations', f'{split}.json')
    with open(annot_fname, 'r') as f:
      entries = json.load(f)

    for entry in entries:
      image_file = entry['image']

      if split == 'test':
        answers, answer_confidences, answerable = [], [], -1
      else:
        # A couple of answers in the train set are empty strings.
        kept = [answer for answer in entry['answers'] if answer['answer']]
        answers = [answer['answer'] for answer in kept]
        answer_confidences = [answer['answer_confidence'] for answer in kept]
        answerable = entry['answerable']

      # The id is the number after the last '_' of the file name, with the
      # 4-character extension stripped.
      stem = image_file[:-4]
      question_id = int(stem.split('_')[-1])

      yield entry['image'], {
          'question': entry['question'],
          'image/filename': image_file,
          'question_id': question_id,
          'image': os.path.join(_VIZWIZVQA_PATH, split, image_file),
          'answers': answers,
          'answer_confidences': answer_confidences,
          'answerable': answerable,
      }
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Import VQAv2 into TFDS format. Uses coco-2014 images. + +It's small data, so simple to run locally. First, download all the data: + + mkdir /tmp/data/ ; cd /tmp/data + wget http://images.cocodataset.org/zips/{train2014,val2014,test2015}.zip + wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_{Train,Val,Test}_mscoco.zip + wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_{Train,Val}_mscoco.zip + unzip '*.zip' + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=vqa + +It runs at around 750 examples/sec, so takes around 25min for the 1.2M questions. +Each question is an example; images are repeated, a bit wasteful, but disk is cheap. 
class Vqa(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for VQAv2 dataset."""

  VERSION = tfds.core.Version('3.0.0')
  RELEASE_NOTES = {'3.0.0': 'Format as needed for PaliGemma'}

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the metadata."""
    features = tfds.features.FeaturesDict({
        'image/id': np.int32,
        'image/filename': tfds.features.Text(),
        'image': tfds.features.Image(encoding_format='jpeg'),
        'question_id': np.int32,
        'question_type': tfds.features.Text(),
        'question_text': tfds.features.Text(),
        'answer_type': tfds.features.Text(),
        'answers': tfds.features.Sequence(tfds.features.Text()),
        'answer_confidences': tfds.features.Sequence(
            tfds.features.ClassLabel(names=['no', 'maybe', 'yes'])),
        'top_answer': tfds.features.Text(),
    })
    return tfds.core.DatasetInfo(
        builder=self,
        description='The VQAv2 dataset.',
        features=features,
        homepage='https://visualqa.org/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    generate = self._generate_examples
    return {
        'train': generate('train2014'),
        'validation': generate('val2014'),
        'test': generate('test2015'),
        'test-dev': generate('test-dev2015', 'test2015'),
    }

  def _generate_examples(self, split, image_folder=None):
    """Yields (question id, example) tuples for one split.

    Args:
      split: name used in the question/annotation file names. Splits whose
        name contains 'test' have no annotations file.
      image_folder: folder (and file name infix) of the images; defaults to
        `split`.

    Yields:
      (key, example) tuples. Without annotations, the answer-related fields
      are empty and questions whose image file does not exist are skipped.
    """
    image_folder = image_folder or split
    has_annotations = 'test' not in split

    # The questions file has fields image_id, question, question_id.
    questions_path = os.path.join(
        _VQAV2_PATH, f'v2_OpenEnded_mscoco_{split}_questions.json')
    with open(questions_path) as f:
      questions = json.load(f)['questions']

    # The annotations file has fields: image_id, question_id, answers,
    # answer_type, question_type, multiple_choice_answer.
    annots = {}
    if has_annotations:
      annotations_path = os.path.join(
          _VQAV2_PATH, f'v2_mscoco_{split}_annotations.json')
      with open(annotations_path) as f:
        for annotation in json.load(f)['annotations']:
          annots[annotation['question_id']] = annotation

    for question in questions:
      qid = question['question_id']
      image_id = question['image_id']
      question_text = question['question']
      fname = f'COCO_{image_folder}_{image_id:012d}.jpg'
      path = os.path.join(_IMAGE_PATH, image_folder, fname)

      if has_annotations:
        ann = annots[qid]
        answer_fields = {
            'question_type': ann['question_type'],
            'answer_type': ann['answer_type'],
            'answers': [a['answer'] for a in ann['answers']],
            'answer_confidences': [
                a['answer_confidence'] for a in ann['answers']],
            'top_answer': ann['multiple_choice_answer'],
        }
      else:
        # For test images, a few are from the wrong year...
        if not os.path.isfile(path):
          print(image_id)
          continue
        answer_fields = {
            'question_type': '',
            'answer_type': '',
            'answers': [],
            'answer_confidences': [],
            'top_answer': '',
        }

      yield qid, {
          'image/id': image_id,
          'question_id': qid,
          'question_text': question_text,
          'image/filename': fname,
          'image': path,
          **answer_fields,
      }
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Import widgetcap into TFDS format. + + Widget Captioning also requires images from the RICO dataset: + mkdir -p /tmp/data/rico_images ; cd /tmp/data/rico_images + wget + https://storage.googleapis.com/crowdstf-rico-uiuc-4540/rico_dataset_v0.1/unique_uis.tar.gz + tar xvfz unique_uis.tar.gz + rm unique_uis.tar.gz + + Widget Captioning: + mkdir -p /tmp/data/widget_captioning ; cd /tmp/data/widget_captioning + git clone https://github.com/google-research-datasets/widget-caption.git + cp widget-caption/widget_captions.csv ./ + cp widget-caption/split/*.txt ./ + rm -rf widget-caption + +Then, run conversion locally (make sure to install tensorflow-datasets for the +`tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=widgetcap + +Example to load: + + import tensorflow_datasets as tfds + dataset_augmented = tfds.load('widgetcap', split='train', + data_dir='/tmp/tfds') +""" +import csv +import json +import os + +import numpy as np +from PIL import Image +import tensorflow_datasets as tfds + +_DATASET_DIR = '/tmp/data/widget_captioning' +# Dataset property indicating the y-dim of the canvas +_RICO_CANVAS_Y = 2560 +_IMAGE_DIR = '/tmp/data/rico_images/combined' + +_CITATION = ( + '@inproceedings{Li2020WidgetCG,title={Widget Captioning: Generating Natural' + ' Language Description for MobileUser Interface Elements},author={Y.
class Widgetcap(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for widgetcap dataset.

  One example per row of `widget_captions.csv`: a screenshot, the bounding box
  of one UI node on it, and the captions written for that node.
  """

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Format as needed for PaliGemma'}

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description='The widgetcap dataset.',
        features=tfds.features.FeaturesDict({
            'image/id': tfds.features.Text(),
            'image/filename': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'texts': tfds.features.Sequence(tfds.features.Text()),
            'bbox': tfds.features.BBoxFeature(),
            'screen_id': tfds.features.Text(),
            'node_id': tfds.features.Text(),
            'height': np.int32,
            'width': np.int32,
        }),
        homepage='https://github.com/google-research-datasets/widget-caption',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {
        'train': self._generate_examples('train'),
        'dev': self._generate_examples('dev'),
        'test': self._generate_examples('test'),
    }

  def _generate_examples(self, split):
    """Yields (key, example) tuples from the dataset.

    Args:
      split: name of the split; `<split>.txt` in `_DATASET_DIR` lists the
        screen ids (one per line) that belong to it.
    """
    with open(os.path.join(_DATASET_DIR, split + '.txt')) as f:
      split_screen_ids = {line.strip() for line in f}

    with open(os.path.join(_DATASET_DIR, 'widget_captions.csv')) as f:
      for row in csv.DictReader(f):
        if row['screenId'] in split_screen_ids:
          yield self._get_example(
              row['screenId'], row['nodeId'], row['captions']
          )

  def _get_node_box(self, screen_id, node_id, height):
    """Returns the node's bounds rescaled from the canvas to the image.

    Args:
      screen_id: id of the screen; `<screen_id>.json` holds its view tree.
      node_id: dot-separated path; all components after the first are child
        indices followed from the root node.
      height: pixel height of the screenshot.

    Returns:
      A list with the node's 'bounds' values, each multiplied by
      `height / _RICO_CANVAS_Y`.
    """
    index_list = [int(i) for i in node_id.split('.')[1:]]
    with open(os.path.join(_IMAGE_DIR, screen_id + '.json')) as f:
      view = json.load(f)
    curr_node = view['activity']['root']
    for index in index_list:
      curr_node = curr_node['children'][index]
    # Return a concrete list rather than a lazy, one-shot `map` object.
    return [x * height / _RICO_CANVAS_Y for x in curr_node['bounds']]

  def _get_example(self, screen_id, node_id, captions):
    """Builds the (key, example) pair for one node of one screen.

    Args:
      screen_id: id of the screen; also the stem of its .jpg/.json files.
      node_id: path of the node inside the screen's view tree.
      captions: all captions of the node, separated by '|'.

    Returns:
      A tuple (image_id, example) where image_id is `<screen_id>_<node_id>`.
    """
    image_path = os.path.join(_IMAGE_DIR, screen_id + '.jpg')
    # Only the size is needed; close the file instead of leaking one handle
    # per CSV row.
    with Image.open(image_path) as image:
      width, height = image.size
    # get bounding box coordinates
    xmin, ymin, xmax, ymax = self._get_node_box(screen_id, node_id, height)

    image_id = f'{screen_id}_{node_id}'
    example = {
        'image/id': image_id,
        'image/filename': screen_id + '.jpg',
        'image': image_path,
        'texts': captions.split('|'),
        # The bbox is normalized by the image size.
        'bbox': tfds.features.BBox(
            ymin=ymin / height,
            xmin=xmin / width,
            ymax=ymax / height,
            xmax=xmax / width,
        ),
        'screen_id': screen_id,
        'node_id': node_id,
        'height': height,
        'width': width,
    }
    return image_id, example
+ +First, download the data: + mkdir -p /tmp/data/xgqa/annotations + wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_bn.json -P /tmp/data/xgqa/annotations + wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_de.json -P /tmp/data/xgqa/annotations + wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_en.json -P /tmp/data/xgqa/annotations + wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_id.json -P /tmp/data/xgqa/annotations + wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_ko.json -P /tmp/data/xgqa/annotations + wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_pt.json -P /tmp/data/xgqa/annotations + wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_ru.json -P /tmp/data/xgqa/annotations + wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_zh.json -P /tmp/data/xgqa/annotations + wget https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip -P /tmp/data/xgqa/ + unzip /tmp/data/xgqa/images.zip -d /tmp/data/xgqa/ + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=xgqa + +Example to load: + +import tensorflow_datasets as tfds +dataset = tfds.load( + 'xgqa', split='test_zs_en', + data_dir='/tmp/tfds') +""" +import json +import os + +import tensorflow_datasets as tfds + +_DESCRIPTION = """xGQA (uses GQA images).""" + +# pylint: disable=line-too-long +_CITATION = ( + '@inproceedings{pfeiffer-etal-2022-xgqa,' 
class XGQA(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for XGQA dataset.

  For every language there is one zero-shot test split plus the few-shot
  test, dev and train_{1,5,10,20,25,48} splits.
  """

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'First release.'}

  def _info(self):
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'example_id': tfds.features.Text(),
            'image/id': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'question': tfds.features.Text(),
            'answer': tfds.features.Text(),
        }),
        supervised_keys=None,
        homepage='https://github.com/adapter-hub/xGQA',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns a dict mapping split name to a (key, example) generator."""
    splits = {}
    for lang in LANGUAGES:
      splits[f'test_zs_{lang}'] = self._generate_examples(
          'test', 'zero_shot', lang)
      splits[f'test_fs_{lang}'] = self._generate_examples(
          'test', 'few_shot', lang)
      # The dev split must read the dev annotations; it previously re-read the
      # test file, which made dev an exact copy of test.
      splits[f'dev_fs_{lang}'] = self._generate_examples(
          'dev', 'few_shot', lang)
      for shots in (1, 5, 10, 20, 25, 48):
        splits[f'train_fs{shots}_{lang}'] = self._generate_examples(
            f'train_{shots}', 'few_shot', lang)
    return splits

  def _generate_examples(self, split, num_shots, lang):
    """Yields (key, example) tuples.

    Args:
      split: stem of the few-shot annotation file (e.g. 'test', 'dev',
        'train_5'). Not used for 'zero_shot', which has one file per language.
      num_shots: either 'few_shot' or 'zero_shot'.
      lang: language code, selects the annotation folder/file.

    Yields:
      (example_id, example) tuples, where example_id is
      `<question_id>_<lang>`.

    Raises:
      ValueError: if `num_shots` is neither 'few_shot' nor 'zero_shot'.
    """
    # Loads the questions for each image.
    if num_shots == 'few_shot':
      file_path = os.path.join(_DATA_PATH, 'annotations', 'few_shot', lang,
                               f'{split}.json')
    elif num_shots == 'zero_shot':
      file_path = os.path.join(_DATA_PATH, 'annotations', 'zero_shot',
                               f'testdev_balanced_questions_{lang}.json')
    else:
      raise ValueError(f'Unknown num_shots: {num_shots}')
    with open(file_path, 'r') as f:
      entries = json.load(f)

    # Make one entry per question-answer pair.
    for question_id, question_data in entries.items():
      example_id = f'{question_id}_{lang}'
      yield example_id, {
          'example_id': example_id,
          'image/id': question_data['imageId'],
          'image': os.path.join(_IMAGE_PATH, f'{question_data["imageId"]}.jpg'),
          'question': question_data['question'],
          'answer': question_data['answer'],
      }
+# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Generates XM3600 in a TFDS-ready structure. + +First, download the captions from https://google.github.io/crossmodal-3600/ and the images from https://cocodataset.org/#download. +The coco Karpathy split is available at http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip: + mkdir -p /tmp/data/xm3600 + wget https://google.github.io/crossmodal-3600/web-data/captions.zip -P /tmp/data/xm3600 + unzip /tmp/data/xm3600/captions.zip -d /tmp/data/xm3600/ + wget https://open-images-dataset.s3.amazonaws.com/crossmodal-3600/images.tgz -P /tmp/data/xm3600 + mkdir /tmp/data/xm3600/images + tar -xzf /tmp/data/xm3600/images.tgz -C /tmp/data/xm3600/images + +Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): + + cd big_vision/datasets + env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=xm3600 + +Example to load: + + import tensorflow_datasets as tfds + dataset = tfds.load( + 'xm3600', split='en', + data_dir='/tmp/tfds') +""" + +import json +import os.path + +import tensorflow_datasets as tfds + +_DESCRIPTION = """ +COCO image + captions, translated from English to 35 languages (English incl.). +""" + +# pylint: disable=line-too-long +_CITATION = """ +@inproceedings{thapliyal-etal-2022-crossmodal, + title = "Crossmodal-3600: A Massively Multilingual Multimodal Evaluation Dataset", + author = "Thapliyal, Ashish V.
and + Pont Tuset, Jordi and + Chen, Xi and + Soricut, Radu", + editor = "Goldberg, Yoav and + Kozareva, Zornitsa and + Zhang, Yue", + booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing", + month = dec, + year = "2022", + address = "Abu Dhabi, United Arab Emirates", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2022.emnlp-main.45", + doi = "10.18653/v1/2022.emnlp-main.45", + pages = "715--729", +} +""" +# pylint: enable=line-too-long + + +_CAPTIONS_PATH = '/tmp/data/xm3600' +_IMAGES_PATH = '/tmp/data/xm3600/images' + +XM3600_LANGUAGES = [ + 'ar', 'bn', 'cs', 'da', 'de', 'el', 'en', 'es', 'fa', 'fi', 'fil', 'fr', + 'he', 'hi', 'hr', 'hu', 'id', 'it', 'ja', 'ko', 'mi', 'nl', 'no', 'pl', + 'pt', 'quz', 'ro', 'ru', 'sv', 'sw', 'te', 'th', 'tr', 'uk', 'vi', 'zh' +] + + +class Xm3600(tfds.core.GeneratorBasedBuilder): + """DatasetBuilder for XM3600 dataset.""" + + VERSION = tfds.core.Version('1.0.1') + RELEASE_NOTES = { + '1.0.0': 'First release.', + '1.0.1': 'Add captions/tokenized feature to compute metrics (eg CIDEr).', + } + + def _info(self): + """Returns the metadata.""" + + return tfds.core.DatasetInfo( + builder=self, + description=_DESCRIPTION, + features=tfds.features.FeaturesDict({ + 'image/id': tfds.features.Text(), + 'image': tfds.features.Image(encoding_format='jpeg'), + 'captions': tfds.features.Sequence(tfds.features.Text()), + 'captions/tokenized': tfds.features.Sequence(tfds.features.Text()), + 'language': tfds.features.Text(), + }), + supervised_keys=None, + homepage='https://google.github.io/crossmodal-3600/', + citation=_CITATION, + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + """Returns SplitGenerators.""" + return {lang: self._generate_examples(lang) for lang in XM3600_LANGUAGES} + + def _generate_examples(self, split: str): + """Yields (key, example) tuples from dataset.""" + language = split + + annot_fname = 
os.path.join(_CAPTIONS_PATH, 'captions.jsonl') + data = {} + tok_data = {} + with open(annot_fname, 'r') as f: + for line in f: + j = json.loads(line) + image_id = f'{j["image/key"]}_{language}' + captions = j[language]['caption'] + data[image_id] = captions + tok_data[image_id] = j[language]['caption/tokenized'] + + for image_id, captions in data.items(): + yield image_id, { + 'image/id': image_id, + 'image': os.path.join(_IMAGES_PATH, f'{image_id.split("_")[0]}.jpg'), + 'captions': captions, + 'captions/tokenized': tok_data[image_id], + 'language': language, + } diff --git a/big_vision/evaluators/__init__.py b/big_vision/evaluators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/evaluators/classification.py b/big_vision/evaluators/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..263ead8f5027f4b8e640b9ba42a72b3cbc33adf2 --- /dev/null +++ b/big_vision/evaluators/classification.py @@ -0,0 +1,76 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluator for the classfication task.""" +# pylint: disable=consider-using-from-import + +import functools + +from big_vision.evaluators import common +import big_vision.utils as u +import jax +import jax.numpy as jnp + + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. 
+API = 'jit' + + +# To avoid re-compiling the function for every new instance of the same +# evaluator on a different dataset! +@functools.cache +def get_eval_fn(predict_fn, loss_name): + """Produces eval function, also applies pmap.""" + @jax.jit + def _eval_fn(train_state, batch, labels, mask): + logits, *_ = predict_fn(train_state, batch) + + # Ignore the entries with all zero labels for evaluation. + mask *= labels.max(axis=1) + + loss = getattr(u, loss_name)( + logits=logits, labels=labels, reduction=False) + loss = jnp.sum(loss * mask) + + top1_idx = jnp.argmax(logits, axis=1) + # Extracts the label at the highest logit index for each image. + top1_correct = jnp.take_along_axis( + labels, top1_idx[:, None], axis=1)[:, 0] + ncorrect = jnp.sum(top1_correct * mask) + nseen = jnp.sum(mask) + return ncorrect, loss, nseen + return _eval_fn + + +class Evaluator: + """Classification evaluator.""" + + def __init__(self, predict_fn, loss_name, label_key='labels', **kw): + self.get_data_iter, self.steps = common.eval_input_pipeline(**kw) + self.eval_fn = get_eval_fn(predict_fn, loss_name) + self.label_key = label_key + + def run(self, train_state): + """Computes all metrics.""" + ncorrect, loss, nseen = 0, 0, 0 + for _, batch in zip(range(self.steps), self.get_data_iter()): + labels, mask = batch.pop(self.label_key), batch.pop('_mask') + batch_ncorrect, batch_losses, batch_nseen = jax.device_get( + self.eval_fn(train_state, batch, labels, mask)) + ncorrect += batch_ncorrect + loss += batch_losses + nseen += batch_nseen + yield ('prec@1', ncorrect / nseen) + yield ('loss', loss / nseen) diff --git a/big_vision/evaluators/common.py b/big_vision/evaluators/common.py new file mode 100644 index 0000000000000000000000000000000000000000..42dcdbb4b52a5208673821b9c68df246709fcf6d --- /dev/null +++ b/big_vision/evaluators/common.py @@ -0,0 +1,228 @@ +# Copyright 2024 Big Vision Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for evaluators in general.""" + +import dataclasses +import functools +import importlib +import json +import os +from typing import Any, Callable + +from absl import flags +from big_vision import input_pipeline +from big_vision.datasets import core as ds_core +from big_vision.pp import builder as pp_builder +import big_vision.utils as u +import flax +import jax +import numpy as np + +from tensorflow.io import gfile + + +def from_config(config, predict_fns, + write_note=lambda s: s, + get_steps=lambda key, cfg: cfg[f"{key}_steps"], + devices=None): + """Creates a list of evaluators based on `config`.""" + evaluators = [] + specs = config.get("evals", {}) + + for name, cfg in specs.items(): + write_note(name) + + # Pop all generic settings off so we're left with eval's kwargs in the end. + cfg = cfg.to_dict() + module = cfg.pop("type", name) + pred_key = cfg.pop("pred", "predict") + pred_kw = cfg.pop("pred_kw", None) + prefix = cfg.pop("prefix", f"{name}/") + cfg.pop("skip_first", None) + logsteps = get_steps("log", cfg) + for typ in ("steps", "epochs", "examples", "percent"): + cfg.pop(f"log_{typ}", None) + + # Use same batch_size as eval by default, to reduce fragmentation. + # TODO: eventually remove all the deprecated names... 
+ cfg["batch_size"] = cfg.get("batch_size") or config.get("batch_size_eval") or config.get("input.batch_size") or config.get("batch_size") # pylint: disable=line-too-long + + module = importlib.import_module(f"big_vision.evaluators.{module}") + + if devices is not None: + cfg["devices"] = devices + + api_type = getattr(module, "API", "pmap") + if api_type == "pmap" and "devices" in cfg: + raise RuntimeError( + "You are seemingly using the old pmap-based evaluator, but with " + "jit-based train loop, see (internal link) for more details.") + if api_type == "jit" and "devices" not in cfg: + raise RuntimeError( + "You are seemingly using new jit-based evaluator, but with " + "old pmap-based train loop, see (internal link) for more details.") + + try: + predict_fn = predict_fns[pred_key] + except KeyError as e: + raise ValueError( + f"Unknown predict_fn '{pred_key}'. Available predict_fns are:\n" + + "\n".join(predict_fns)) from e + if pred_kw is not None: + predict_fn = _CacheablePartial(predict_fn, flax.core.freeze(pred_kw)) + evaluator = module.Evaluator(predict_fn, **cfg) + evaluators.append((name, evaluator, logsteps, prefix)) + + return evaluators + + +@dataclasses.dataclass(frozen=True, eq=True) +class _CacheablePartial: + """partial(fn, **kwargs) that defines hash and eq - to help with jit caches. + + This is particularly common in evaluators when one has many evaluator + instances that run on difference slices of data. + + Example: + + ``` + f1 = _CacheablePartial(fn, a=1) + jax.jit(f1)(...) + jax.jit(_CacheablePartial(fn, a=1))(...) # fn won't be retraced. + del f1 + jax.jit(_CacheablePartial(fn, a=1))(...) # fn will be retraced. 
+ ``` + """ + fn: Callable[..., Any] + kwargs: flax.core.FrozenDict + + def __call__(self, *args, **kwargs): + return functools.partial(self.fn, **self.kwargs)(*args, **kwargs) + + +def eval_input_pipeline( + data, pp_fn, batch_size, devices, keep_on_cpu=(), + cache="pipeline", prefetch=1, warmup=False, +): + """Create an input pipeline in the way used by most evaluators. + + Args: + data: The configuration to create the data source (like for training). + pp_fn: A string representing the preprocessing to be performed. + batch_size: The batch size to use. + devices: The devices that the batches are sharded and pre-fetched onto. + keep_on_cpu: See input_pipeline.start_global. Entries in the batch that + should be kept on the CPU, hence could be ragged or of string type. + cache: One of "none", "pipeline", "raw_data", "final_data". Determines what + part of the input stream should be cached across evaluator runs. They use + more and more RAM, but make evals faster, in that order. + - "none": Entirely re-create and destroy the input pipeline each run. + - "pipeline": Keep the (tf.data) pipeline object alive across runs. + - "raw_data": Cache the full raw data before pre-processing. + - "final_data": Cache the full raw data after pre-processing. + prefetch: How many batches to fetch ahead. + warmup: Start fetching the first batch at creation time (right now), + instead of once the iteration starts. + + Returns: + A tuple (get_iter, steps), the first element is a function that returns + the iterator to be used for an evaluation, the second one is how many steps + should be iterated for doing one evaluation. 
+ """ + assert ( + cache is None + or cache.lower() in ("none", "pipeline", "raw_data", "final_data") + ), f"Unknown value for cache: {cache}" + data_source = ds_core.get(**data) + tfdata, steps = input_pipeline.make_for_inference( + data_source.get_tfdata(ordered=True, allow_cache=cache.lower() != "none"), + batch_size=batch_size, + num_ex_per_process=data_source.num_examples_per_process(), + preprocess_fn=pp_builder.get_preprocess_fn(pp_fn, str(data)), + cache_final=cache == "raw_data", + cache_raw=cache == "final_data") + get_data_iter = lambda: input_pipeline.start_global( + tfdata, devices, prefetch, keep_on_cpu, warmup) + + # Possibly create one persistent iterator: + if cache in ("pipeline", "raw_data", "final_data"): + data_iter = get_data_iter() + get_data_iter = lambda: data_iter + + return get_data_iter, steps + + +def process_sum(tree): + """Sums the pytree across all processes.""" + if jax.process_count() == 1: # Avoids corner-cases on donuts. + return tree + + with jax.transfer_guard_device_to_host("allow"): + gathered = jax.experimental.multihost_utils.process_allgather(tree) + return jax.tree.map(functools.partial(np.sum, axis=0), gathered) + + +def resolve_outfile(outfile, split="", **kw): + if not outfile: + return None + + # A caveat: when workdir doesn't exist but is in the `outfile`, we should + # skip. This is common in small runs or runlocal debuggings. 
+ if "{workdir}" in outfile and not flags.FLAGS.workdir: + return None + + return outfile.format( + workdir=flags.FLAGS.workdir, + split="".join(c if c not in "[]%:" else "_" for c in split), + step=getattr(u.chrono, "prev_step", None), + **kw, + ) + + +def multiprocess_write_json(outfile, jobj): # jobj = "json object" + """Write a single json file combining all processes' `jobj`s.""" + if not outfile: + return + + outfile = resolve_outfile(outfile) + gfile.makedirs(os.path.dirname(outfile)) + + if isinstance(jobj, list): + combine_fn = list.extend + elif isinstance(jobj, dict): + combine_fn = dict.update + else: + raise TypeError(f"Can only write list or dict jsons, but got {type(jobj)}") + + # First, each process writes its own file. + with gfile.GFile(outfile + f".p{jax.process_index()}", "w+") as f: + f.write(json.dumps(jobj)) + + u.sync() # Wait for all files to be written; `with` above does close/flush. + + # Have process 0 collect, concat, and write final output. + all_json = type(jobj)() + if jax.process_index() == 0: + for pid in range(jax.process_count()): + with gfile.GFile(outfile + f".p{pid}", "r") as f: + combine_fn(all_json, json.loads(f.read())) + with gfile.GFile(outfile, "w+") as f: + f.write(json.dumps(all_json)) + + # Cleanup time + u.sync() + gfile.remove(outfile + f".p{jax.process_index()}") + + return all_json diff --git a/big_vision/evaluators/fewshot_lsr.py b/big_vision/evaluators/fewshot_lsr.py new file mode 100644 index 0000000000000000000000000000000000000000..1b7019ad3fa58936975b631206947b3b33ecdc67 --- /dev/null +++ b/big_vision/evaluators/fewshot_lsr.py @@ -0,0 +1,245 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for few-shot evaluation.""" +# pylint: disable=consider-using-from-import,g-importing-member + +import functools + +import big_vision.datasets.core as ds_core +import big_vision.input_pipeline as input_pipeline +import big_vision.pp.builder as pp_builder +import big_vision.utils as u +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding as Sharding +from jax.sharding import PartitionSpec as P +import numpy as np + +BIAS_CONSTANT = 100.0 + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = "jit" + + +# Setup function for few-shot regression on CPU to avoid "polluting" the TPU. +@u.jit_cpu(static_argnums=(2,)) +def _precompute_cache(x, y, num_classes): + """Cache quantities to speed-up the computation of L2-regularized least-sq.""" + # Whiten + mean = jnp.mean(x, axis=0, keepdims=True) + std = jnp.std(x, axis=0, keepdims=True) + 1e-5 + x = (x - mean) / std + + # Add a constant feature for the bias, large so it's almost unregularized: + x = jnp.pad(x, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT) + + # To one-hot representation rescaled into {-1, 1} + y = 2.0 * jax.nn.one_hot(y, num_classes) - 1.0 + + num_points, dim = x.shape + # Let N be the number of points, D the dimension and C the number of classes. + # We have x of shape (N, D) and y of shape (N, C). 
+ # For least-squares, we can compute + # + # (A) when N >= D, (x^T x + l2 Id)^{-1} x^T y + # (B) when D > N, x^T (x x^T + l2 Id)^{-1} y + # + # We pre-compute the eigen-decomposition of either x^T x or x x^T which + # becomes q diag(eigs) q^T with q unitary matrix either (D, D) or (N, N) + # and eigs a vector (D,) or (N,). + # + # For any l2 > 0, we can compute (x^T x + l2 Id)^{-1} or (x x^T + l2 Id)^{-1} + # by simply computing q (diag(eigs) + l2 Id)^{-1} q^T. + # (SVD would be more natural here, but it proved slower, so we use eigh) + # + # Both cases (A) and (B) can be viewed as lhs (diag(eigs) + l2 Id)^{-1} rhs, + # where lhs/rhs are pre-computed left/right-hand sides to specify. + # + # Detailed evaluation in terms of time and fewshot metrics can be found in + # (internal link) + # + # Implemented by Rodolphe Jenatton. + if num_points >= dim: + eigs, q = jnp.linalg.eigh(x.T @ x) + rhs = q.T @ (x.T @ y) + lhs = q + else: + eigs, q = jnp.linalg.eigh(x @ x.T) + rhs = q.T @ y + lhs = x.T @ q + + cache = { + "eigs": eigs, + "rhs": rhs, + "lhs": lhs, + "mean": mean, + "std": std + } + return cache + + +@u.jit_cpu() +def _eig_fewshot_acc_fn(cache, x_test, y_test, l2_reg): + """Computes (x,y) linear regression accuracy on (x_test, y_test).""" + + x_test = (x_test - cache["mean"]) / cache["std"] + x_test = jnp.pad(x_test, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT) + + rhs = cache["rhs"] + lhs = cache["lhs"] + eigs = cache["eigs"] + + # See comments in _precompute_cache for context about the formula. 
+ scaling = 1.0 / (eigs + l2_reg * jnp.ones_like(eigs)) + scaling = scaling.reshape((1, -1)) + w = (lhs * scaling) @ rhs + # Predict test-set values and measure their accuracy + preds = jnp.argmax(x_test @ w, axis=1) + return jnp.mean(preds == y_test) + + +class Evaluator: + """Class for few-shot evaluation.""" + + def __init__(self, predict_fn, batch_size, + datasets, shots, l2_reg, + pp_train, pp_eval, display_first, + representation_layer=None, num_seeds=3, + label_key="label", mask_key="_mask", data_dir=None, *, + devices): + self.datasets = datasets + self.shots = shots + self.l2_reg = l2_reg + self.batch_size = batch_size + self.pp_tr = pp_train + self.pp_te = pp_eval + self.display_first = display_first + self._datasets = {} # Cache for tfds data. Persists while object is alive. + self._repr = {} # Cache for precomputed repr. Persists within the run call. + self.num_seeds = num_seeds + self.label_key = label_key + self.mask_key = mask_key + self.data_dir = data_dir + self.devices = devices + self.mesh = jax.sharding.Mesh(devices, ("devices",)) + self.repr_fn = self.get_representation_fn( + predict_fn, representation_layer) + + def get_representation_fn(self, predict_fn, representation_layer): + # `out_shardings=Sharding(self.mesh, P())` will "all_gather" the outputs. + @functools.partial(jax.jit, out_shardings=Sharding(self.mesh, P())) + def _repr_fn(train_state, batch, labels, mask): + zimg, *_, out = predict_fn(train_state, batch) + if representation_layer is not None: + rep = u.tree_get(out, representation_layer) + else: + rep = zimg + return rep, labels, mask + return _repr_fn + + # Setup input pipeline. + def _get_dataset(self, dataset, train_split, test_split): + """Lazy-loads given dataset.""" + key = (dataset, train_split, test_split) + try: + return self._datasets[key] + except KeyError: + # NOTE: only supporting TFDS data for now for bwd compat/lazyness. 
+ train_data = ds_core.get( + name=dataset, split=train_split, data_dir=self.data_dir + ) + test_data = ds_core.get( + name=dataset, split=test_split, data_dir=self.data_dir + ) + train_ds, batches_tr = input_pipeline.make_for_inference( + train_data.get_tfdata(ordered=True), + num_ex_per_process=train_data.num_examples_per_process(), + batch_size=self.batch_size, + preprocess_fn=pp_builder.get_preprocess_fn(self.pp_tr)) + test_ds, batches_te = input_pipeline.make_for_inference( + test_data.get_tfdata(ordered=True), + num_ex_per_process=test_data.num_examples_per_process(), + batch_size=self.batch_size, + preprocess_fn=pp_builder.get_preprocess_fn(self.pp_te)) + + num_classes = train_data.builder.info.features[self.label_key].num_classes + return self._datasets.setdefault( + key, (train_ds, batches_tr, test_ds, batches_te, num_classes)) + + def _get_repr(self, params, data, steps): + """Compute representation for the whole dataset.""" + pre_logits_list = [] + labels_list = [] + for batch, _ in zip( + input_pipeline.start_global(data, self.devices, 0), range(steps)): + labels, mask = batch.pop(self.label_key), batch.pop(self.mask_key) + pre_logits, labels, mask = jax.device_get(self.repr_fn( + params, batch, labels, mask)) + mask = mask.astype(bool) + pre_logits_list.append(pre_logits[mask]) + labels_list.append(labels[mask]) + pre_logits = np.concatenate(pre_logits_list, axis=0) + labels = np.concatenate(labels_list, axis=0) + + return pre_logits, labels + + def compute_fewshot_metrics(self, train_state, seed, + dataset, train_split, test_split): + """Compute few-shot metrics on one dataset.""" + if dataset in self._repr: + repr_train, labels_train, repr_test, labels_test, num_classes = ( + self._repr[dataset]) + else: + train_ds, steps_tr, test_ds, steps_te, num_classes = self._get_dataset( + dataset, train_split, test_split) + repr_train, labels_train = self._get_repr(train_state, train_ds, steps_tr) + repr_test, labels_test = self._get_repr(train_state, test_ds, 
steps_te) + self._repr[dataset] = (repr_train, labels_train, + repr_test, labels_test, + num_classes) + + # Collect where we have samples of which classes. + rng = np.random.default_rng(seed) + class_indices = [rng.permutation(np.where(labels_train == cls_i)[0]) + for cls_i in range(num_classes)] + + results = {} + for shots in self.shots: + all_idx = [indices[:shots] for indices in class_indices] + all_idx = np.concatenate(all_idx, axis=0) + x = u.put_cpu(repr_train[all_idx]) + y = u.put_cpu(labels_train[all_idx]) + repr_test, labels_test = u.put_cpu((repr_test, labels_test)) + + # Note the code is optimized to solve multiple LSR tasks for changing l2 + # strength, even though we currently used the fixed l2_reg constant. + cache = _precompute_cache(x, y, num_classes) + acc = _eig_fewshot_acc_fn( + cache, repr_test, labels_test, u.put_cpu(self.l2_reg)) + results[shots] = jax.device_get(acc) + + return results + + def run(self, train_state): + """New API executed in terms of old API.""" + self._repr = {} + for seed in range(self.num_seeds): + for name, dataset_args in self.datasets.items(): + result = self.compute_fewshot_metrics(train_state, seed, *dataset_args) + for shots, v in result.items(): + prefix = "a/" if (name, shots) in self.display_first else "z/" + suffix = f"-seed-{seed}" + yield f"{prefix}{name}_{shots}shot{suffix}", v diff --git a/big_vision/evaluators/mean.py b/big_vision/evaluators/mean.py new file mode 100644 index 0000000000000000000000000000000000000000..d11590667053764bce73897eeaca4ea0d815efef --- /dev/null +++ b/big_vision/evaluators/mean.py @@ -0,0 +1,80 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluator for computing mean of per-example metrics. + +This evaluator can be used in two ways: + 1. Create a new evaluator with reduced boilerplate by inheriting from it. + 2. For quick prototyping, use this with predict_fns which return the metrics. +""" +from functools import partial +from typing import Mapping + +from big_vision.evaluators import common + +import jax +import jax.numpy as jnp +import numpy as np + + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = 'jit' + + +# Note: global to avoid jax re-compiling across different evaluator instances. +@partial(jax.jit, static_argnums=0) +def _run_predict_fn(predict_fn, train_state, batch): + """Sum per-example metrics weighted by `_mask`.""" + mask = batch['_mask'] + metrics = predict_fn(train_state, batch) + # Sanity check output format of predict_fn. + assert isinstance(metrics, Mapping), 'predict_fn must return a dict' + for y in jax.tree.leaves(metrics): + if y.shape != mask.shape: + raise ValueError( + f'Expected per-example metrics of shape {mask.shape} found ' + f'{jax.tree.map(lambda x: x.shape, metrics)}.') + metrics = {**metrics, '_mask': mask} + return jax.tree.map(lambda x: jnp.sum(jnp.where(mask, x, 0)), metrics) + + +class Evaluator: + """Report the mean of per-example metrics computed by predict_fn. + + `predict_fn(params, batch)` must return a dict from metric name to + per-example metrics of shape [batch_size]. 
+ """ + + def __init__(self, predict_fn, **kw): + self.get_data_iter, self.steps = common.eval_input_pipeline(**kw) + self.predict_fn = partial(_run_predict_fn, predict_fn) + + def run(self, train_state): + """Computes all metrics.""" + metrics = [] + + # Compute batch metrics without blocking. + for _, batch in zip(range(self.steps), self.get_data_iter()): + batch_metrics = self.predict_fn(train_state, batch) + metrics.append(batch_metrics) + + # Transfer metrics (blocking). + metrics = jax.device_get(metrics) + + # Accumulate metrics across batches. + metrics_sum = jax.tree.map(lambda *x: np.sum(x), *metrics) + mask_sum = metrics_sum.pop('_mask') + for key, value_sum in metrics_sum.items(): + yield (key, value_sum / mask_sum) diff --git a/big_vision/evaluators/proj/cappa/perplexity.py b/big_vision/evaluators/proj/cappa/perplexity.py new file mode 100644 index 0000000000000000000000000000000000000000..2ce69398be4243304ef62a7a1276172a4e648787 --- /dev/null +++ b/big_vision/evaluators/proj/cappa/perplexity.py @@ -0,0 +1,50 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluator for perplexity of a model.""" +from big_vision.evaluators import mean +import big_vision.utils as u +import jax.numpy as jnp + + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. 
+API = 'jit' + + +def perplexity(predict_fn, normalize_by_seqlen): + """Returns a function that computes perplexity.""" + + def _perplexity_fn(train_state, batch, pad_token=0, **kw): + logits, _ = predict_fn(train_state, batch, **kw) + + # Ignore perplexity on the padding label. + weights = jnp.where(batch['labels'] != pad_token, 1, 0).astype(jnp.float32) + if batch.get('label_masks') is not None: + weights = weights * batch['label_masks'] + + losses = u.weighted_softmax_xent( + logits=logits, labels=batch['labels'], + weights=weights, label_smoothing=0.0, + reduction=False, normalize=normalize_by_seqlen) + + return {'perplexity': losses} + return _perplexity_fn + + +class Evaluator(mean.Evaluator): + """Perplexity evaluator.""" + + def __init__(self, predict_fn, *a, normalize_by_seqlen=False, **kw): + super().__init__(perplexity(predict_fn, normalize_by_seqlen), *a, **kw) diff --git a/big_vision/evaluators/proj/cappa/scoring_classifier.py b/big_vision/evaluators/proj/cappa/scoring_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..60906bacac2e6c0a2cabd106e7d6d82a06c08b8e --- /dev/null +++ b/big_vision/evaluators/proj/cappa/scoring_classifier.py @@ -0,0 +1,63 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Scoring classifier. + +This one is based on a generative perspective for image classification. 
+Here we input the image as well as all the tokenized labels to compute their +perplexity and select the one with minimum loss as the prediction. +""" +import functools +from big_vision.datasets.imagenet import class_names as imagenet_class_names +from big_vision.evaluators import mean +from big_vision.pp import builder as pp_builder +import jax.numpy as jnp +import numpy as np + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = "jit" + + +CLASS_NAMES = { + "imagenet2012": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, +} + + +# As a separate function to cache result across instances. +@functools.lru_cache(maxsize=None) +def get_classes(dataset_name, pp_txt): + """Load the class label strings and tokenize them using pp_txt.""" + pp_fn = pp_builder.get_preprocess_fn(pp_txt, log_data=False) + return np.array([pp_fn({"label": name})["labels"] + for name in CLASS_NAMES[dataset_name]]) + + +def scoring(predict_fn, tokenized_labels): + + def _scoring_fn(train_state, batch, *a, **kw): + batch = {"_label_tokens": tokenized_labels, **batch} + scores = predict_fn(train_state, batch, *a, **kw) + predictions = jnp.argmax(scores, axis=-1) + return {"prec@1": predictions == batch["label"]} + + return _scoring_fn + + +class Evaluator(mean.Evaluator): + """Evaluator for classification accuracy based on scoring all classes.""" + + def __init__(self, predict_fn, data, pp_fn, pp_txt, *a, **kw): + cls_tokens = get_classes(data["name"], pp_txt) + super().__init__(scoring(predict_fn, cls_tokens), data, pp_fn, *a, **kw) diff --git a/big_vision/evaluators/proj/distill/distance.py b/big_vision/evaluators/proj/distill/distance.py new file mode 100644 index 0000000000000000000000000000000000000000..2cc35391fe3e8f92ac88b1ba21e137d6685b88ed --- /dev/null +++ b/big_vision/evaluators/proj/distill/distance.py @@ -0,0 +1,151 @@ +# Copyright 2024 Big Vision Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluator for the classfication task.""" +from functools import partial, lru_cache + +from big_vision import input_pipeline +import big_vision.datasets.core as ds_core +import big_vision.pp.builder as pp_builder +import big_vision.utils as u + +import einops +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P +import numpy as np + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. 
API = 'jit'


def dist(student, teacher, kind, feat_axis=-1,
         epsilon=1e-12, t=1, ls=0.0, k=1):
  """Distance function used for distillation.

  Args:
    student: student representation.
    teacher: teacher representation; `student - teacher` is always computed,
      so the two must be broadcast-compatible.
    kind: which distance to compute:
      'euclidean': sqrt of the summed squared difference (plus `epsilon`).
      'l2': summed squared difference.
      'hard': cross-entropy of the student against the teacher's argmax
        pseudo-labels, optionally label-smoothed by `ls`.
      'kl': `t**2`-scaled cross-entropy (with `kl=True`) between the
        temperature-scaled student logits and teacher softmax.
      'logsoftmax_euclidean': euclidean distance between log-softmaxes.
      'agree': number of the student's top-`k` indices equal to the teacher's
        top-1 index.
    feat_axis: the axis that gets reduced.
    epsilon: added under the square root of the euclidean variants.
    t: temperature, only used by 'kl'.
    ls: label smoothing amount, only used by 'hard'.
    k: how many student predictions 'agree' considers.

  Returns:
    The distance, reduced over `feat_axis`.

  Raises:
    ValueError: if `kind` is not one of the supported names.
  """
  diff = student - teacher
  if kind == 'euclidean':
    return jnp.sqrt(jnp.sum(diff * diff, axis=feat_axis) + epsilon)
  elif kind == 'l2':
    return jnp.sum(diff * diff, axis=feat_axis)
  elif kind == 'hard':
    pseudolabels = jnp.argmax(teacher, feat_axis)
    pl = u.onehot(pseudolabels, teacher.shape[feat_axis])
    if ls:
      # Spread `ls` uniformly over the non-target classes.
      pl = (1.0 - ls) * pl + (ls / (pl.shape[-1] - 1)) * (1.0 - pl)
    return u.softmax_xent(logits=student, labels=pl,
                          reduction=False, kl=True, axis=feat_axis)
  elif kind == 'kl':
    # The teacher distribution has to be normalised over the same axis that
    # the cross-entropy reduces, not over the default last axis.
    return t**2 * u.softmax_xent(
        logits=student / t,
        labels=jax.nn.softmax(teacher / t, axis=feat_axis),
        reduction=False, kl=True, axis=feat_axis)
  elif kind == 'logsoftmax_euclidean':
    logsoftmax_diff = (
        jax.nn.log_softmax(student, axis=feat_axis) -
        jax.nn.log_softmax(teacher, axis=feat_axis))
    return jnp.sqrt(
        jnp.sum(logsoftmax_diff * logsoftmax_diff, axis=feat_axis) + epsilon)
  elif kind == 'agree':
    def get_top_k(arr, k, ax):
      # `top_k` works on the last axis, so move `ax` there and back.
      return jax.lax.top_k(arr.swapaxes(ax, -1), k)[1].swapaxes(ax, -1)
    return (get_top_k(student, k, feat_axis) ==
            get_top_k(teacher, 1, feat_axis)).sum(feat_axis)
  else:
    # Raise explicitly: an `assert` would vanish under `python -O`.
    raise ValueError(f'Unknown kind of distance {kind}.')


@lru_cache(None)
def get_dist_fn(**kw):
  """Returns `dist` with `kw` bound; cached so equal configs share one object."""
  return partial(dist, **kw)


# To avoid re-compiling the function for every new instance of the same
# evaluator on a different dataset!
@lru_cache(None)
def get_eval_fn(student_teacher_fwd, what, mesh, distances):
  """Produces the jit-compiled eval function.

  Args:
    student_teacher_fwd: called as `student_teacher_fwd(train_state, batch)`;
      must return `((_, out_student), (_, out_teacher))`.
    what: pair of names looked up via `u.tree_get` in the student and teacher
      outputs respectively.
    mesh: device mesh; all outputs are requested with an empty PartitionSpec.
    distances: iterable of callables `(repr_student, repr_teacher) -> array`.

  Returns:
    A function `(train_state, batch, mask) -> (list_of_distances, mask)`.
  """
  @partial(jax.jit, out_shardings=NamedSharding(mesh, P()))
  def _eval_fn(train_state, batch, mask):
    (_, out_s), (_, out_t) = student_teacher_fwd(train_state, batch)
    repr_s = u.tree_get(out_s, what[0])
    repr_t = u.tree_get(out_t, what[1])

    # Let's flatten any non-vectors (eg feature-maps).
    repr_s = einops.rearrange(repr_s, 'b ... -> b (...)')
    repr_t = einops.rearrange(repr_t, 'b ... -> b (...)')

    # NOTE: we're gathering and returning all ; if this becomes too slow, we
    # can change to compute and return summary stats later on.
    all_ds = [dist_fn(repr_s, repr_t) for dist_fn in distances]
    return all_ds, mask

  return _eval_fn


class Evaluator:
  """Distillation distance evaluator."""

  def __init__(
      self,
      student_teacher_fwd,
      data,
      pp_fn,
      distances,
      what=('logits', 'logits'),
      *,
      devices,
      **data_kw,
  ):
    """Builds the input pipeline and the (cached) jitted eval function.

    Args:
      student_teacher_fwd: forwarded to `get_eval_fn`.
      data: keyword arguments for `ds_core.get`.
      pp_fn: preprocessing specification for `pp_builder.get_preprocess_fn`.
      distances: sequence of dicts, each holding keyword arguments for `dist`.
      what: which student/teacher outputs to compare, see `get_eval_fn`.
      devices: jax devices used for the data mesh.
      **data_kw: `prefetch` (default 1) is consumed here, everything else is
        forwarded to `input_pipeline.make_for_inference`.
    """
    data = ds_core.get(**data)
    pp_fn = pp_builder.get_preprocess_fn(pp_fn)
    prefetch = data_kw.pop('prefetch', 1)
    self.ds, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True),
        pp_fn,
        num_ex_per_process=data.num_examples_per_process(),
        **data_kw,
    )
    self.data_iter = input_pipeline.start_global(self.ds, devices, prefetch)
    dist_fns = tuple(get_dist_fn(**cfg) for cfg in distances)
    # Metric name prefix, e.g. "kind=kl_t=2" for dict(kind='kl', t=2).
    self.dist_names = [
        '_'.join(f'{key}={val}' for key, val in cfg.items())
        for cfg in distances
    ]
    mesh = jax.sharding.Mesh(devices, ('data',))
    self.eval_fn = get_eval_fn(student_teacher_fwd, what, mesh, dist_fns)

  def run(self, train_state):
    """Computes all metrics.

    Args:
      train_state: passed through to the eval function.

    Yields:
      For every configured distance: `(name/all, per-example values)`,
      `(name/avg, mean)`, `(name/min, min)` and `(name/max, max)`.
    """
    all_ds = [[] for _ in self.dist_names]
    for _, batch in zip(range(self.steps), self.data_iter):
      mask = batch.pop('_mask')
      batch_ds, batch_ms = self.eval_fn(train_state, batch, mask)
      # Bring the results to the host as numpy and keep only the entries whose
      # mask is 1.
      batch_ms = np.array(batch_ms)
      for i, val in enumerate(batch_ds):
        all_ds[i].append(np.array(val)[batch_ms == 1])
    for name, ds in zip(self.dist_names, all_ds):
      ds = np.concatenate(ds)
      yield f'{name}/all', ds
      yield f'{name}/avg', np.mean(ds)
      yield f'{name}/min', np.min(ds)
      yield f'{name}/max', np.max(ds)
# Temporary global flag to facilitate backwards compatibility.
API = 'jit'

ROOT = os.environ.get('COCO_DATA_DIR', '.')

PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json'
PANOPTIC_2017 = {
    'train': f'{ROOT}/panoptic_train2017.json',
    'validation': f'{ROOT}/panoptic_val2017.json',
}

PANOPTIC_GT_ZIP = {
    'train': f'{ROOT}/panoptic_train2017.zip',
    'validation': f'{ROOT}/panoptic_val2017.zip',
}


# Note: global to avoid jax re-compiling across different evaluator instances.
@functools.cache
def _get_predict_fn(predict_fn, mesh=None):
  """Wrapper for jit-compiled predict function."""

  # `out_shardings` annotation is needed because of the `all_gather` ops in the
  # pmap implementation.
  @functools.partial(jax.jit,
                     out_shardings=jax.sharding.NamedSharding(
                         mesh, jax.sharding.PartitionSpec()))
  def _run_predict_fn(train_state, batch):
    """Run predict_fn and gather all outputs on all devices."""
    y = predict_fn(train_state, batch)
    res = {
        'image/id': batch['image/id'],
        'mask': batch['_mask'],
        # Channel 0: semantics, channel 1: instances.
        'y': jnp.stack([y['semantics'], y['instances']], axis=-1),
    }
    return res
  return _run_predict_fn


class Evaluator:
  """Panoptic segmentation evaluator: calls official COCO API."""

  def __init__(
      self,
      predict_fn,
      pp_fn,
      batch_size,
      data=None,
      cache_final=True,
      cache_raw=False,
      prefetch=1,
      save_dir=None,
      *,
      devices,
  ):
    """Panoptic segmentation evaluator: calls official COCO API.

    Args:
      predict_fn: jit-compilable function, which accepts arbitrary dictionaries
        of parameters and data, where the data dictionary is produced by the
        `pp_fn`. Its output must provide `semantics` and `instances` entries,
        which get stacked into a 2-channel mask.
      pp_fn: Preprocessing function, specified as string.
      batch_size: Batch size.
      data: Dict specifying name and split of the data set. Defaults to the
        standard COCO (2017).
      cache_final: Whether to cache the data after preprocessing - see
        input_pipeline for details.
      cache_raw: Whether to cache the raw data - see input_pipeline for details.
      prefetch: Number of batches to prefetch
      save_dir: Directory to save the results in. May contain a `{workdir}`
        placeholder, which is filled with `flags.FLAGS.workdir`.
      devices: List of jax devices.
    """
    self.predict_fn = _get_predict_fn(
        predict_fn, jax.sharding.Mesh(devices, ('devices',)))

    data_specs = dict(name='coco/2017_panoptic',
                      data_dir=None, split='validation')
    data_specs.update(data or {})
    data = ds_core.get(**data_specs)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_global(
        self.dataset, devices, prefetch)

    # Only process 0 runs conversion to png and calls into coco api.
    if jax.process_index() == 0:
      self.result_dir = tempfile.TemporaryDirectory()
      (self.gt_folder, self.gt_json, self.categories_json,
       self.remap, self.size_map) = _prepare_ground_truth(
           data_specs['name'], data_specs['split'],
           data_specs.get('data_dir'))
      if save_dir:
        self.save_dir = save_dir.format(workdir=flags.FLAGS.workdir)
        gfile.makedirs(self.save_dir)
      else:
        self.save_dir = None

  def _compute_png_predictions(
      self, train_state: Any) -> Any:
    """Computes predictions and converts them to png to optimize memory use.

    Args:
      train_state: passed through to the predict function.

    Returns:
      On process 0 the temporary directory object holding one
      `<image_id>.png` per masked-in example; `None` on all other processes.
    """
    count = 0
    logging.info('Panoptic eval: running inference.')
    for batch in itertools.islice(self.data_iter, self.steps):
      # Every process has to take part in the forward pass.
      out = self.predict_fn(train_state, batch)

      if jax.process_index():
        continue

      out = jax.device_get(out)
      mask = out['mask']
      pan_recs = out['y'][mask]
      ids = out['image/id'][mask]

      for pan_rec, image_id in zip(pan_recs, ids):
        sem = pan_rec[..., 0]
        ins = pan_rec[..., 1]

        # Map predicted class ids to COCO category ids.
        sem_remapped = np.array(sem)
        for v in np.unique(sem):
          sem_remapped[sem == v] = self.remap[v]
        sem = sem_remapped

        pan_mask = np.stack([sem, ins, np.zeros_like(sem)], axis=-1)
        pan_mask = utils.put_cpu(pan_mask)
        pan_mask = _resize_nearest(pan_mask, self.size_map[image_id])
        pan_mask_png = tf.io.encode_png(pan_mask.astype('uint8')).numpy()

        fname = f'{self.result_dir.name}/{image_id:012d}.png'
        with open(fname, 'wb') as f:
          f.write(pan_mask_png)
        count += 1

      logging.log_every_n_seconds(
          logging.INFO, 'Panoptic eval: processed %i examples so far.', 30,
          count)

    if jax.process_index():
      return None

    logging.info('Panoptic eval: inference done. Processed %d examples.', count)
    return self.result_dir

  def run(self, train_state):
    """Run panoptic segmentation evaluation.

    Args:
      train_state: pytree containing the model parameters.

    Yields:
      Tuples consisting of metric name and value.
    """
    # Note result_dir is constant, but files inside are mutated.
    result_dir = self._compute_png_predictions(train_state)

    if jax.process_index():
      return

    if self.save_dir:
      gfile.RecursivelyCopyDir(result_dir.name, self.save_dir, overwrite=True)

    with tempfile.TemporaryDirectory() as pred_folder, \
         tempfile.NamedTemporaryFile(mode='w') as pred_json:

      logging.info('Panoptic eval: running conversion.')
      converter.converter(
          source_folder=result_dir.name,
          images_json_file=self.gt_json,
          categories_json_file=self.categories_json,
          segmentations_folder=pred_folder,
          predictions_json_file=pred_json.name)
      logging.info('Panoptic eval: conversion done.')

      logging.info('Panoptic eval: running metrics computation.')
      res = evaluation.pq_compute(gt_json_file=self.gt_json,
                                  gt_folder=self.gt_folder,
                                  pred_json_file=pred_json.name,
                                  pred_folder=pred_folder)
      logging.info('Panoptic eval: metrics computation done.')

    for k in ['All', 'Stuff', 'Things']:
      for m in ['pq', 'rq', 'sq']:
        yield f'{k}_{m}', res[k][m]


def _prepare_ground_truth(dataset, split, data_dir):
  """Picks the zip-file path for default COCO, the dataset path otherwise."""
  if dataset == 'coco/2017_panoptic' and data_dir is None:
    return _prepare_ground_truth_from_zipfiles(split)
  else:
    return _prepare_ground_truth_from_dataset(dataset, split, data_dir)


@functools.lru_cache(maxsize=None)
def _prepare_ground_truth_from_dataset(dataset, split, data_dir):
  """Prepare ground truth from a tf.data.Dataset.

  Args:
    dataset: TFDS-compatible dataset specification.
    split: Data set split to use.
    data_dir: Folder containing the data

  Returns:
    A tuple containing the folder containing the ground-truth data, the path of
    the ground truth annotations json, the path of the categories json, a map
    for remapping, and a map mapping image id to image size.
  """
  tfds_dataset = tfds.builder(
      dataset, data_dir=data_dir).as_dataset(split=split)

  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  with gfile.GFile(categories_json, 'rb') as f:
    categories = json.loads(f.read())

  # Build map from tfds class ids to COCO class ids. The categories are already
  # loaded, so the file does not need to be opened again.
  remap = {0: 0, **{(i + 1): x['id'] for i, x in enumerate(categories)}}

  gt_folder = tempfile.mkdtemp()
  gfile.makedirs(gt_folder)
  size_map = {}
  annotations = []
  images = []
  for example in tfds_dataset:
    image_id = int(example['image/id'])
    panoptic_image = example['panoptic_image']
    ann_ids = example['panoptic_objects']['id']
    ann_labels = example['panoptic_objects']['label']
    ann_iscrowd = example['panoptic_objects']['is_crowd']
    ann_area = example['panoptic_objects']['area']

    fname = f'{image_id:012d}.png'
    with gfile.GFile(os.path.join(gt_folder, fname), 'wb') as f:
      f.write(tf.io.encode_png(panoptic_image).numpy())

    size_map[image_id] = (panoptic_image.shape[0], panoptic_image.shape[1])

    segments_info = []
    for i in range(len(ann_ids)):
      segments_info.append({
          'id': int(ann_ids[i]),
          'category_id': remap[int(ann_labels[i] + 1)],
          'iscrowd': int(ann_iscrowd[i]),
          'area': int(ann_area[i]),
      })

    annotations.append({
        'file_name': str(fname),
        'image_id': int(image_id),
        'segments_info': segments_info
    })
    images.append({
        'id': image_id,
        'file_name': f'{image_id:012d}.jpg',
    })

  # Write annotations.json needed for pq_compute.
  gt_json = os.path.join(gt_folder, 'annotations.json')
  with gfile.GFile(gt_json, 'wb') as f:
    f.write(json.dumps({
        'images': images,
        'annotations': annotations,
        'categories': categories,
    }))

  return gt_folder, gt_json, categories_json, remap, size_map


def _prepare_ground_truth_from_zipfiles(split):
  """Prepare ground truth from coco zip files.

  Args:
    split: dataset split to prepare ground truth for.

  Returns:
    A tuple containing the folder containing the ground-truth data, the path of
    the filtered ground truth annotations json, the path of the categories
    json, a map for remapping, and a map mapping image id to image size.

  Raises:
    ValueError: if the split is neither a 'train' nor a 'validation' split.
  """
  split_prefix = split.split('[')[0]
  if split_prefix not in ('train', 'validation'):
    raise ValueError(f'Split {split} not supported')

  # The following 4 calls are cached. This allows to save significant time
  # in use cases like sweeping predict_fn hparams on the same run.
  gt_json = _make_local_copy(PANOPTIC_2017[split_prefix])
  gt_folder = _make_local_unzip_copy(PANOPTIC_GT_ZIP[split_prefix])
  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  image_ids = _list_image_ids('coco/2017_panoptic', split)

  gt_folder = os.path.join(
      gt_folder, 'panoptic_val2017'
      if split_prefix == 'validation' else 'panoptic_train2017')

  # Build map from tfds class ids to COCO class ids.
  remap = {0: 0}
  with gfile.GFile(categories_json, 'r') as f:
    remap = {**remap, **{(i + 1): x['id'] for i, x in enumerate(json.load(f))}}

  # Filters gt_json to contain only annotations for images in dataset.
  with gfile.GFile(gt_json) as f:
    data = json.load(f)
  logging.info(
      'Panoptic eval: pre-filter %d annotations.',
      len(data['annotations'])
  )
  data['images'] = [x for x in data['images'] if x['id'] in image_ids]
  data['annotations'] = [
      x for x in data['annotations'] if x['image_id'] in image_ids
  ]
  logging.info(
      'Panoptic eval: post-filter %d annotations.',
      len(data['annotations'])
  )
  # Write through the temp file's own handle so it is closed deterministically
  # (delete=False keeps the file around for the evaluator).
  with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
    json.dump(data, f)
    filtered_gt_json = f.name

  # Precompute images sizes.
  size_map = {x['id']: (x['height'], x['width']) for x in data['images']}

  return gt_folder, filtered_gt_json, categories_json, remap, size_map


@functools.lru_cache(maxsize=None)
def _list_image_ids(dataset, split):
  """Returns the frozenset of `image/id` values in the given TFDS split."""
  d = tfds.load(dataset, split=split).map(lambda x: x['image/id'])
  return frozenset(d.as_numpy_iterator())


@functools.lru_cache(maxsize=None)
def _make_local_copy(fname) -> str:
  """Copies `fname` into a fresh local temp file and returns its path."""
  start = time.monotonic()
  # Only reserve the name here; close the handle before the file is
  # overwritten so no file descriptor is leaked.
  with tempfile.NamedTemporaryFile(delete=False) as local_file:
    local_name = local_file.name
  gfile.copy(fname, local_name, overwrite=True)
  logging.info('Copy %s in %d seconds.', fname, time.monotonic() - start)
  return local_name


@functools.lru_cache(maxsize=None)
def _make_local_unzip_copy(fname) -> str:
  """Extracts the zip archive `fname` into a fresh temp dir, returns the dir."""
  start = time.monotonic()
  folder = tempfile.mkdtemp()
  with tempfile.NamedTemporaryFile() as tmp_zip_file:
    gfile.copy(fname, tmp_zip_file.name, overwrite=True)
    with zipfile.ZipFile(tmp_zip_file.name, 'r') as f:
      f.extractall(folder)
  logging.info('Copy %s in %d seconds.', fname, time.monotonic() - start)
  return folder


@utils.jit_cpu(static_argnums=(1,))
def _resize_nearest(image, shape):
  """Nearest-neighbour resize to `shape`, keeping the trailing channel dim."""
  return jax.image.resize(image, shape + image.shape[-1:], 'nearest')
/dev/null +++ b/big_vision/evaluators/proj/givt/nyu_depth.py @@ -0,0 +1,191 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluation for NYU depth. + +jax.jit-compatible fork of the evaluator from evaluators/proj/uvim. + +At evaluation time the ground truth is cropped and clipped. Values outside of +the test crop or clipping range are not included in eval calculations. + +In this evaluator, it is assumed that the ground truth is already cropped, so the +entire image is evaluated. However, the evaluator does perform the clipping. + +Reference implementations: + https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blo(internal link)a0f341244260ff61541191a613dd74bc/depth/datasets/nyu.py + https://github.com/vinvino02/GLPDepth/blob/7f3c78df4ecd6e7c79fd0c4b73c95d61f4aa2121/code/utils/metrics.py + https://github.com/shariqfarooq123/AdaBins/blob/2fb686a66a304f0a719bc53d77412460af97fd61/evaluate.py +""" + +import functools +import itertools + +from big_vision import input_pipeline +from big_vision import utils +from big_vision.datasets import core as ds_core +import big_vision.pp.builder as pp_builder +import jax +import jax.numpy as jnp +import numpy as np + +# Temporary global flag to facilitate backwards compatibility. +API = "jit" + + +# Note: global to avoid jax re-compiling across different evaluator instances. 
@functools.cache
def _get_predict_fn(predict_fn, mesh=None):
  """Wrapper for jit-compiled predict function."""

  # `out_shardings` annotation is needed because of the `all_gather` ops in the
  # pmap implementation.
  @functools.partial(jax.jit,
                     out_shardings=jax.sharding.NamedSharding(
                         mesh, jax.sharding.PartitionSpec()))
  def _run_predict_fn(train_state, batch):
    """Run predict_fn and gather all outputs on all devices."""
    pred = predict_fn(train_state, batch)
    return {"mask": batch["_mask"],
            # Drop the trailing singleton dimension of the ground truth.
            "gt": jnp.squeeze(batch["ground_truth"], axis=-1),
            "y": pred["depth"]}
  return _run_predict_fn


class Evaluator:
  """Evaluator for NYU depth."""

  def __init__(self,
               predict_fn,
               pp_fn,
               batch_size,
               data,
               cache_final=True,
               cache_raw=False,
               prefetch=1,
               min_depth=1e-3,
               max_depth=10,
               *,
               devices):
    """Evaluator for NYU depth.

    Args:
      predict_fn: jit-compilable function, accepts arbitrary dictionaries of
        parameters and data, where the data dictionary is produced by the
        `pp_fn` op. It is expected to output a dict with `depth` containing a
        2D array with the predicted depth. The prediction is resized to the
        ground_truth size with nearest neighbour.
      pp_fn: Preprocessing function, specified as string. `pp_fn` must also
        output a 'ground_truth' as a 2D array of ground truth. Further, it has
        to apply a crop, if one wants to compute metrics with the eval crop
        typically used for NYU Depth metrics.
      batch_size: Batch size.
      data: Dict of keyword arguments for `ds_core.get` (required).
      cache_final: Whether to cache the data after preprocessing - see
        input_pipeline for details.
      cache_raw: Whether to cache the raw data - see input_pipeline for details.
      prefetch: Number of batches to prefetch
      min_depth: Minimum depth value; ground truth at or below it is ignored.
      max_depth: Maximum depth value; ground truth at or above it is ignored.
      devices: List of jax devices.
    """
    self.min_depth = min_depth
    self.max_depth = max_depth
    self.predict_fn = _get_predict_fn(
        predict_fn, jax.sharding.Mesh(devices, ("devices",)))

    data = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_global(
        self.dataset, devices, prefetch)

  def run(self, train_state):
    """Run NYU depth eval.

    Args:
      train_state: pytree containing the model parameters.

    Yields:
      Tuples consisting of metric name and value.
    """
    rmses = []
    abs_res = []
    abs_logs = []
    d1s = []
    d2s = []
    d3s = []
    for batch in itertools.islice(self.data_iter, self.steps):
      # Every process has to take part in the forward pass.
      out = self.predict_fn(train_state, batch)

      if jax.process_index():  # Host0 gets all preds and does eval.
        continue

      out = jax.device_get(out)
      # Bool-index every entry with the mask to drop padding examples.
      # `jax.tree_util.tree_map` is used because the `jax.tree_map` alias is
      # deprecated and no longer available in recent JAX releases.
      mask = out["mask"]
      out = jax.tree_util.tree_map(lambda x, m=mask: x[m], out)

      for gt, pred in zip(out["gt"], out["y"]):
        # put_cpu and conversion to numpy arrays below to avoid unwanted
        # host-to-device transfers
        pred, gt = utils.put_cpu((pred, gt))
        pred = _resize_nearest(pred, (gt.shape[0], gt.shape[1]))
        pred, gt = np.array(pred), np.array(gt)
        valid_mask = np.logical_and(gt > self.min_depth, gt < self.max_depth)
        # Select the valid pixels once instead of once per metric.
        gt_valid, pred_valid = gt[valid_mask], pred[valid_mask]

        rmses.append(_compute_rmse(gt_valid, pred_valid))
        abs_res.append(_compute_abs_re(gt_valid, pred_valid))
        abs_logs.append(_compute_abs_log(gt_valid, pred_valid))
        d1s.append(_compute_delta(gt_valid, pred_valid, order=1))
        d2s.append(_compute_delta(gt_valid, pred_valid, order=2))
        d3s.append(_compute_delta(gt_valid, pred_valid, order=3))

    if jax.process_index():  # Host0 gets all preds and does eval.
      return

    yield "RMSE", np.mean(rmses)
    yield "abs_RE", np.mean(abs_res)
    yield "log10", np.mean(abs_logs)
    yield "delta1", np.mean(d1s)
    yield "delta2", np.mean(d2s)
    yield "delta3", np.mean(d3s)


@utils.jit_cpu(static_argnums=(1,))
def _resize_nearest(image, shape):
  """Nearest-neighbour resize of `image` to `shape`."""
  return jax.image.resize(image, shape, "nearest")


def _compute_rmse(gt, pred):
  """Root of the mean squared difference."""
  diff = gt - pred
  return np.sqrt(np.mean(np.power(diff, 2)))


def _compute_abs_re(gt, pred):
  """Mean absolute difference relative to the ground truth."""
  diff = np.abs(gt - pred)
  return np.mean(diff / gt)


def _compute_abs_log(gt, pred):
  """Mean absolute difference of the base-10 logarithms."""
  diff = np.abs(np.log10(gt) - np.log10(pred))
  return np.mean(diff)


def _compute_delta(gt, pred, order):
  """Fraction of entries with max(gt / pred, pred / gt) < 1.25**order."""
  rel_diff = np.maximum(gt / pred, pred / gt)
  return np.sum(rel_diff < 1.25**order) / rel_diff.size
# Temporary global flag to facilitate backwards compatibility.
API = 'jit'


# Note: global to avoid jax re-compiling across different evaluator instances.
@functools.cache
def _get_predict_fn(predict_fn, mesh=None):
  """Wrapper for jit-compiled predict function."""

  # `out_shardings` annotation is needed because of the `all_gather` ops in the
  # pmap implementation.
  @functools.partial(jax.jit,
                     out_shardings=jax.sharding.NamedSharding(
                         mesh, jax.sharding.PartitionSpec()))
  def _run_predict_fn(train_state, batch):
    """Run predict_fn and gather all outputs on all devices."""
    y = predict_fn(train_state, batch)
    return {'inputs': batch, 'outputs': y, 'mask': batch['_mask']}
  return _run_predict_fn


class Evaluator:
  """Save predictions in "{FLAGS.workdir}/{outfile}".

  Results can then be easily inspected in a notebook such as:

  ```
    results = utils.load_checkpoint("")
    inputs, outputs = (results["inputs"], results["outputs"])
  ```
  """

  def __init__(self, predict_fn, pp_fn, batch_size, data, outfile,
               cache_final=True, cache_raw=False, prefetch=1, *, devices):
    """Builds the input pipeline and the (cached) jitted predict function.

    Args:
      predict_fn: called as `predict_fn(train_state, batch)`.
      pp_fn: preprocessing specification for `pp_builder.get_preprocess_fn`.
      batch_size: batch size.
      data: dict of keyword arguments for `ds_core.get`.
      outfile: file name, joined onto `flags.FLAGS.workdir`.
      cache_final: forwarded to `input_pipeline.make_for_inference`.
      cache_raw: forwarded to `input_pipeline.make_for_inference`.
      prefetch: number of batches to prefetch.
      devices: list of jax devices.
    """
    self.predict_fn = _get_predict_fn(
        predict_fn, jax.sharding.Mesh(devices, ('devices',)))

    # Prepare data for each process and pad with zeros so all processes have the
    # same number of batches.
    data = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_global(
        self.dataset, devices, prefetch)

    self.path = os.path.join(flags.FLAGS.workdir, outfile)

  def run(self, train_state):
    """Compute all predictions, gather in main host and save in outfile."""
    count = 0
    outputs = []
    for batch in itertools.islice(self.data_iter, self.steps):
      # Every process has to take part in the forward pass.
      out = self.predict_fn(train_state, batch)
      if jax.process_index():
        continue

      out = jax.device_get(out)
      # Grab the mask up front: inside the tree map `x` is a single leaf and
      # has no 'mask' field. `jax.tree_util.tree_map` is used because the
      # `jax.tree_map` alias is deprecated and gone from recent JAX releases.
      mask = out['mask']
      out = jax.tree_util.tree_map(lambda x, m=mask: x[m], out)
      count += out['mask'].shape[0]
      out.pop('mask')
      outputs.append(out)

      logging.log_every_n_seconds(
          logging.INFO, 'Save predictions: processed %i examples so far.', 30,
          count)

    if jax.process_index():
      return

    logging.info('Save predictions: processed %d examples.', count)

    # Actually save in filesystem.
    outputs = jax.tree_util.tree_map(
        lambda *x: np.concatenate(x, axis=0), *outputs)
    names_and_vals, _ = u.tree_flatten_with_names(outputs)
    io_buffer = io.BytesIO()
    np.savez_compressed(io_buffer, **dict(names_and_vals))
    with gfile.GFile(self.path, 'wb') as f:
      f.write(io_buffer.getvalue())
    return

    # Never executed; its presence makes `run` a generator that yields nothing.
    yield None  # pylint: disable=unreachable
def _all_gather(z):
  """All gather and flatten first two dims."""
  gather_flat = lambda x: jnp.concatenate(jax.lax.all_gather(x, "batch"), 0)
  # `jax.tree_util.tree_map` instead of the deprecated `jax.tree_map` alias,
  # which is no longer available in recent JAX releases.
  return jax.tree_util.tree_map(gather_flat, z)


# To avoid re-compiling the function for every new instance of the same
# evaluator on a different dataset!
@functools.lru_cache(None)
def get_eval_fn(predict_fn, use_global_batch):
  """Produces eval function, also applies pmap.

  Args:
    predict_fn: called as `predict_fn(params, images, labels)`; must return
      `(zimg, ztxt, extras)` where `extras["t"]` is the temperature.
    use_global_batch: if true, embeddings and mask are all-gathered over the
      "batch" axis before the loss is computed.

  Returns:
    A pmapped function `(params, images, labels, mask) -> (c, l, n)` holding
    the psum'd masked correct-counts, masked losses and mask.
  """

  @functools.partial(jax.pmap, axis_name="batch")
  def _eval_fn(params, images, labels, mask):
    zimg, ztxt, extras = predict_fn(params, images, labels)

    if use_global_batch:
      zimg, ztxt, mask = _all_gather((zimg, ztxt, mask))

    # Temperature won't affect ranking for accuracy, but impacts loss magnitude.
    losses, measurements = u.bidirectional_contrastive_loss(
        zimg, ztxt, extras["t"], mask, reduction=False)
    l = jax.lax.psum(losses * mask, axis_name="batch")
    c = jax.lax.psum(measurements["ncorrect"] * mask, axis_name="batch")
    n = jax.lax.psum(mask, axis_name="batch")
    return c, l, n

  return _eval_fn


class Evaluator:
  """Contrastive evaluator."""

  def __init__(self, predict_fn, data, pp_fn, batch_size,
               use_global_batch, cache_final=True,
               cache_raw=False, prefetch=1, label_key="labels"):
    """Builds the input pipeline and the (cached) pmapped eval function.

    Args:
      predict_fn: forwarded to `get_eval_fn`.
      data: dict of keyword arguments for `ds_core.get`.
      pp_fn: preprocessing specification for `pp_builder.get_preprocess_fn`.
      batch_size: batch size.
      use_global_batch: forwarded to `get_eval_fn`.
      cache_final: forwarded to `input_pipeline.make_for_inference`.
      cache_raw: forwarded to `input_pipeline.make_for_inference`.
      prefetch: number of batches to prefetch.
      label_key: batch key that holds the labels passed to `predict_fn`.
    """
    data = ds_core.get(**data)
    pp_fn = pp_builder.get_preprocess_fn(pp_fn)
    self.ds, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), pp_fn, batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_input_pipeline(self.ds, prefetch)
    self.eval_fn = get_eval_fn(predict_fn, use_global_batch)
    self.label_key = label_key

  def run(self, params):
    """Computes all metrics.

    Args:
      params: passed through to the eval function.

    Yields:
      `("ncorrect_minibatch", correct / seen)` and `("loss", loss / seen)`.
    """
    l, c, nseen = 0, 0, 0
    for _, batch in zip(range(self.steps), self.data_iter):
      labels, mask = batch.pop(self.label_key), batch.pop("_mask")
      batch_ncorrect, batch_losses, batch_n = self.eval_fn(
          params, batch["image"], labels, mask)
      # All results are a replicated array shaped as follows:
      #   (local_devices, per_device_batch_size, elem_shape...)
      # with each local device's entry being identical as they got psum'd.
      # So let's just take the first one to the host as numpy.
      c += np.sum(np.array(batch_ncorrect[0]))
      l += np.sum(np.array(batch_losses[0]))
      nseen += np.sum(np.array(batch_n[0]))
    yield ("ncorrect_minibatch", c / nseen)
    yield ("loss", l / nseen)
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = "jit"


DATASET_NAMES = ("imagenet2012", "cifar100", "oxford_iiit_pet")
DEFAULT_OVERRIDES = (
    ("imagenet2012", (
        ("class_names", "clip"),
        ("split", "validation"),
    )),
    )


def _with_infinite_padding(dataset):
  """Adds "infinite padding" to the dataset.

  Real elements get `mask=True`; they are followed by an endless stream of
  all-zero filler elements with `mask=False`.
  """
  filler_element = tf.nest.map_structure(
      lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec)
  filler_element["mask"] = [False]
  filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)
  dataset = dataset.map(
      lambda features: dict(mask=True, **features),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset.concatenate(filler_dataset.repeat(None))


# This is needed so retrieval_test can replace dataset info.
def _get_dataset_info(builder):
  return builder.info


def prepare_datasets(img_dataset,
                     class_names,
                     *,
                     prompt_templates,
                     pp_img,
                     pp_txt,
                     cache_final=False,
                     pre_filter_fn=None,
                     class_name_offset=0):
  """Returns unbatched `ds_images, ds_texts` datasets.

  Args:
    img_dataset: dataset of image examples.
    class_names: iterable of class names; a name may hold several
      comma-separated aliases.
    prompt_templates: non-empty sequence of templates, each containing a
      single `{}` placeholder for the class name.
    pp_img: preprocessing string for images ('label' and 'image' are kept).
    pp_txt: preprocessing string for texts ('label' and 'labels' are kept).
    cache_final: whether to `.cache()` both resulting datasets.
    pre_filter_fn: optional predicate used to filter `img_dataset`.
    class_name_offset: added to the enumeration index to form the label.
  """

  assert prompt_templates, "Must specify prompt templates (e.g. simply ['{}'])"

  def expand_aliases(idx, class_name):
    # One element per comma-separated alias, all sharing the same label.
    aliases = tf.strings.split(class_name, ",")
    return tf.data.Dataset.from_tensor_slices((
        tf.repeat([idx + class_name_offset], len(aliases), axis=0),
        aliases,
    ))

  def add_prompts(idx, class_name):
    # One element per prompt template.
    return tf.data.Dataset.from_tensor_slices({
        "label": tf.repeat([idx], len(prompt_templates), axis=0),
        "class_name": tf.repeat([class_name], len(prompt_templates), axis=0),
        "prompt_template": prompt_templates,
    })

  def substitute_prompt(features):
    parts = tf.strings.split(features["prompt_template"], "{}")
    tf.debugging.assert_equal(len(parts), 2, features["prompt_template"])
    return {
        "label": features["label"],
        "texts": tf.strings.join([parts[0], features["class_name"], parts[1]])
    }

  if pre_filter_fn:
    img_dataset = img_dataset.filter(pre_filter_fn)
  ds_images = img_dataset.map(
      pp_builder.get_preprocess_fn(f"{pp_img}|keep('label', 'image')"))
  ds_texts = tf.data.Dataset.from_tensor_slices(list(class_names)).enumerate(
  ).flat_map(expand_aliases).flat_map(add_prompts).map(substitute_prompt).map(
      pp_builder.get_preprocess_fn(f"{pp_txt}|keep('label', 'labels')"))

  if cache_final:
    ds_images, ds_texts = ds_images.cache(), ds_texts.cache()

  return ds_images, ds_texts


def _shard_class_names(class_names, process_index, process_count):
  """Returns `(names_for_this_process, offset)`; last process gets remainder.

  Works on a copy, so the caller's sequence is never modified.
  """
  class_names = list(class_names)
  if len(class_names) < process_count:
    # See (internal link) for more details.
    class_names += [""] * (process_count - len(class_names))
  per_process = len(class_names) // process_count
  offset = per_process * process_index
  if process_index == process_count - 1:
    return class_names[offset:], offset
  return class_names[offset:offset + per_process], offset


def _split_and_batch(dataset_name, data_dir, class_names, batch_size, split,
                     get_ds):
  """Splits dataset, calls `get_ds` and returns padded + batched datasets."""
  assert not batch_size % jax.device_count(), (
      f"batch_size={batch_size} % jax.device_count()={jax.device_count()}")
  builder = tfds.builder(dataset_name, data_dir=data_dir)

  # Split class names (last process gets remainder).
  class_names, class_name_offset = _shard_class_names(
      class_names, jax.process_index(), jax.process_count())

  ds_images, ds_texts = get_ds(
      builder.as_dataset(split=tfds.split_for_jax_process(split)),
      class_names,
      class_name_offset=class_name_offset)
  return (
      _with_infinite_padding(ds_images).batch(batch_size),
      _with_infinite_padding(ds_texts).batch(batch_size),
  )


def _average_embeddings(embeddings, *, labels, num_classes, normalize):
  """Computes per-class averages of `embeddings`.

  Args:
    embeddings: 2D array with one embedding per row.
    labels: 1D array of integer class indices, one per embedding.
    num_classes: number of classes; every class needs at least one embedding.
    normalize: whether to divide each average by its norm (plus 1e-8).

  Returns:
    Array with `num_classes` rows holding the per-class mean embedding.
  """
  assert embeddings.ndim == 2, f"Expected {embeddings.ndim}==2"
  assert labels.ndim == 1, f"Expected {labels.ndim}==1"
  assert len(labels) == len(embeddings), (
      f"Expected {len(labels)}=={len(embeddings)}")

  byidx = [[] for _ in range(num_classes)]
  for label, embedding in zip(labels, embeddings):
    byidx[label].append(embedding)
  missing = set(range(num_classes)) - set(
      idx for idx, embs in enumerate(byidx) if len(embs))
  assert not missing, f"Classes without embeddings: {missing}"
  embeddings = [np.array(embedding).mean(axis=0) for embedding in byidx]
  embeddings = np.stack(embeddings)

  assert len(embeddings) == num_classes
  if normalize:
    embeddings /= 1e-8 + np.linalg.norm(embeddings, axis=1, keepdims=True)
  return embeddings
+ first_class_name_only=True, + dataset_overrides=DEFAULT_OVERRIDES, + async_delay=1): + """Initializes a new zero-shot classification evaluator. + + See `prepare_datasets()` for details on how the dataset is pre-processed. + + Args: + predict_fn: Prediction function with signature + `zimg, ztxt, out = predict_fn(params, batch)` (`batch` is a dict). + batch_size: Global batch size. + devices: list of devices. + dataset_names: Names of TFDS datasets to evaluate on. + data_dir: Optional argument to `tfds.builder()`. + class_names: Usually specified as a string that is interpreted by + `prompt_engineering.get_class_names()` to look up class names. + Alternatively, this attribute can be a list of class names (using "," + to separate multiple aliases). + split: Which dataset split to use for evaluation. + prompt_templates: Specifies which prompt templates to use. See module + big_vision.evaluators.proj.image_text.prompt_engineering + for valid values. + canonicalize: Whether class names and prompt templates should be + canonicalized. See `prompt_engineering.py` for details. + pp_img: Preprocessing string for images. Preprocessed features should + contain key "image" with value that can be batched and is suitable for + the "image" entry of the `predict_fn` batch. + pp_txt: Preprocessing string for texts. Can expect "texts" key as an input + (shape=[], dtype=string), and is expected to produce "labels" key that + is suitable for the "labels" entry of the `predict_fn` batch. + cache_final: Whether the preprocessed dataset should be cached. + pre_filter_fn: Predicate applied to the dataset for filtering records. + first_class_name_only: Whether only the first class name should be + considered (i.e. not using any aliases). + dataset_overrides: Mapping `dataset_name` to an optional dictionary that + can override parameters `dataset_name`, `data_dir`, `pp_img`, `pp_txt`, + `class_names`, `split`, `pre_filter_fn`, and the extra + `class_names_dataset_name`. 
+ Works with tuple/dict of tuples/dicts. + async_delay: How many steps to wait before checking if all hosts have + finished their batch. A value > 1 allows for more parallelized + processing, but will result in more unnecessary steps with padded data. + """ + t0 = time.monotonic() + self.datasets = {} + self.prompt_templates = prompt_engineering.get_prompt_templates( + prompt_templates, canonicalize=canonicalize) + self._axis_name = "batch" + dataset_overrides = {k: dict(v) for k, v in dict(dataset_overrides).items()} + + for dataset_name in dataset_names: + overrides = dataset_overrides.pop(dataset_name, {}) + dataset_name_ = overrides.pop("dataset_name", dataset_name) + data_dir_ = overrides.pop("data_dir", data_dir) + class_names_dataset_name = overrides.pop("class_names_dataset_name", + dataset_name_) + class_names_ = overrides.pop("class_names", class_names) + class_names_ = prompt_engineering.get_class_names( + dataset_name=class_names_dataset_name, + source=class_names_, + canonicalize=canonicalize) + pp_img_ = overrides.pop("pp_img", pp_img) + pp_txt_ = overrides.pop("pp_txt", pp_txt) + cache_final_ = overrides.pop("cache_final", cache_final) + split_ = overrides.pop("split", split) + pre_filter_fn_ = overrides.pop("pre_filter_fn", pre_filter_fn) + prompt_templates_ = overrides.pop("prompt_templates", prompt_templates) + canonicalize_ = overrides.pop("canonicalize", canonicalize) + prompt_templates_ = prompt_engineering.get_prompt_templates( + prompt_templates_, canonicalize=canonicalize_) + assert not overrides, f"Unknown overrides {dataset_name}: {overrides}" + + if first_class_name_only: + class_names_ = [name.split(",")[0] for name in class_names_] + ds_images, ds_texts = _split_and_batch( + dataset_name=dataset_name_, + data_dir=data_dir_, + class_names=class_names_, + batch_size=batch_size, + split=split_, + get_ds=functools.partial( + prepare_datasets, + pp_img=pp_img_, + pp_txt=pp_txt_, + cache_final=cache_final_, + pre_filter_fn=pre_filter_fn_, + 
prompt_templates=prompt_templates_)) + self.datasets[dataset_name] = dict( + images=ds_images, texts=ds_texts, class_names=class_names_, + dataset_name=dataset_name_, split=split_) + + assert not dataset_overrides, f"Extra overrides: {dataset_overrides}" + + def embed_texts(train_state, texts): + """Returns text embeddings.""" + _, ztxt, _ = predict_fn(train_state, {"labels": texts}) + return ztxt + + def count_correct(train_state, return_embeddings, *, mask, labels, image, + ztxt): + """Returns count of correct predictions (and optionally embeddings).""" + zimg, _, _ = predict_fn(train_state, {"image": image}) + best_txt = (zimg @ ztxt.T).argmax(axis=1) + # labels has format [[1, -1, -1], [5, -1, -1], [7, 2, -1], ...] + # so here we count "any" correct, such that the counting matches the + # multilabel scenario described in "are we done with imagenet" + # (http://arxiv.org/abs/2006.07159) section 3.1 + if labels.ndim == 1: + labels = labels[..., None] + assert labels.ndim == 2, labels.shape + matching = (best_txt[:, None] == labels).sum(axis=1) + correct = jnp.where(mask, (matching > 0).astype(jnp.int32), 0).sum() + correct = jnp.sum(correct) + if return_embeddings: + return correct, zimg + else: + return correct, None + + self.devices = devices + self.mesh = jax.sharding.Mesh(devices, ("devices",)) + + self._embed_texts_p = jax.jit( + embed_texts, out_shardings=NamedSharding(self.mesh, P())) + self._count_correct_p = jax.jit(count_correct, static_argnums=(1,), + out_shardings=NamedSharding(self.mesh, P())) + self._count_p = jax.jit(jnp.sum, + out_shardings=NamedSharding(self.mesh, P())) + self._all_gather_p = jax.jit( + lambda x: x, out_shardings=NamedSharding(self.mesh, P())) + + self._compiled = set() + assert async_delay > 0, f"async_delay must be >0, not {async_delay}" + self._async_delay = async_delay + logging.info("Initialized evaluator in %.1f seconds", time.monotonic() - t0) + + def _embed_texts(self, train_state, dataset_name): + """Returns per-class 
averaged text embeddings.""" + t0 = time.monotonic() + logging.info("Starting text embedding...") + ns = [] + embeddings = [] + data = {"label": [], "mask": []} + + ds_b = input_pipeline.start_global( + self.datasets[dataset_name]["texts"], self.devices) + for batch in ds_b: + ns.append(jax.device_get(self._count_p(batch["mask"]))) + if len(ns) >= self._async_delay and ns[-self._async_delay] == 0: + break + + embeddings.append(jax.device_get(self._embed_texts_p( + train_state, batch["labels"]))) + for name in data: + data[name].append(jax.device_get(self._all_gather_p(batch[name]))) + + if self._embed_texts_p not in self._compiled: + logging.info("Compiled text embeddings in %.1fs", time.monotonic() - t0) + t0 = time.monotonic() + self._compiled.add(self._embed_texts_p) + + ns = np.array(ns) + n = ns.sum() + data["embedding"] = embeddings + data = {k: np.concatenate(v, axis=0) for k, v in data.items()} + mask = data.pop("mask").astype(bool) + data = {k: v[mask] for k, v in data.items()} + data["average_embedding"] = _average_embeddings( + data["embedding"], + labels=data["label"], + num_classes=len(self.datasets[dataset_name]["class_names"]), + normalize=True) + + logging.info("Embedded %s text in %d steps - ...%s", dataset_name, len(ns), + ns[-10:]) + logging.info("Totalling %d text in %.1fs", n, time.monotonic() - t0) + logging.info("Total texts embeddings size %.1fM", + data["embedding"].nbytes / 1e6) + return data + + def evaluate(self, + train_state, + dataset_name, + *, + return_embeddings=False): + """Returns evaluation results.""" + texts = self._embed_texts(train_state, dataset_name) + ztxt_p = texts["average_embedding"] + ztxt_p = utils.reshard(ztxt_p, NamedSharding(self.mesh, P())) + + t0 = time.monotonic() + logging.info("Starting image embedding...") + + ns = [] + embeddings = [] + corrects = [] + data = {"mask": [], "label": []} if return_embeddings else {} + + ds_b = input_pipeline.start_global( + self.datasets[dataset_name]["images"], self.devices) 
+ for batch in ds_b: + ns.append(jax.device_get(self._count_p(batch["mask"]))) + if len(ns) >= self._async_delay and ns[-self._async_delay] == 0: + break + + labels = batch["label"] + correct_p, embs_p = self._count_correct_p( + train_state, + return_embeddings, + mask=batch["mask"], + labels=labels, + image=batch["image"], + ztxt=ztxt_p, + ) + corrects.append(jax.device_get(correct_p)) + if self._count_correct_p not in self._compiled: + logging.info("Compiled image embeddings in %.1fs", + time.monotonic() - t0) + t0 = time.monotonic() + self._compiled.add(self._count_correct_p) + + if return_embeddings: + embeddings.append(jax.device_get(self._all_gather_p(embs_p))) + for name in data: + data[name].append(jax.device_get(self._all_gather_p(batch[name]))) + + ns = np.array(ns) + n = ns.sum() + correct = np.array(corrects).sum() + + logging.info("Embedded %s image in %d steps - ...%s", dataset_name, len(ns), + ns[-10:]) + logging.info("Totalling %d image in %.1fs", n, time.monotonic() - t0) + ret = { + "accuracy": correct / n, + "correct": correct, + "count": n, + } + logging.info("Dataset %s, results %s", dataset_name, ret) + + if return_embeddings: + data["embedding"] = embeddings + data = {k: np.concatenate(v, axis=0) for k, v in data.items()} + logging.info("Total images embeddings size %.1fM", + data["embedding"].nbytes / 1e6) + mask = data.pop("mask").astype(bool) + ret["images"] = {k: v[mask] for k, v in data.items()} + ret["texts"] = texts + + return ret + + def run(self, train_state): + """Returns metrics.""" + return [(f"{dataset_name}_accuracy", + self.evaluate(train_state, dataset_name)["accuracy"]) + for dataset_name in self.datasets] diff --git a/big_vision/evaluators/proj/image_text/discriminative_classifier_test.py b/big_vision/evaluators/proj/image_text/discriminative_classifier_test.py new file mode 100644 index 0000000000000000000000000000000000000000..06f4fe28910191555779c21a838032ae8bc506cd --- /dev/null +++ 
b/big_vision/evaluators/proj/image_text/discriminative_classifier_test.py @@ -0,0 +1,237 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for discriminative_classifier.""" + +from unittest import mock + +from big_vision.evaluators.proj.image_text import discriminative_classifier +from big_vision.pp import ops_general # pylint: disable=unused-import +from big_vision.pp import ops_image # pylint: disable=unused-import +from big_vision.pp.registry import Registry +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + + +@Registry.register("preprocess_ops.test_texts2labels") +def _get_test_texts2labels(): + + def pp(features): + features["labels"] = tf.strings.to_number(features["texts"]) + return features + + return pp + + +@Registry.register("preprocess_ops.copy_from") +def _get_copy_from(**key_map): + + def copy_from(d): + d = dict(d) + for k1, k2 in key_map.items(): + d[k1] = d[k2] + return d + + return copy_from + + +class _Model(nn.Module): + + @nn.compact + def __call__(self, image, texts): + self.param("x", lambda _: 0.) + + def z(x): + if x is not None: + # Note that the returned vector is most similar with other vectors + # generated from the same underlying `x[:]`. + return jnp.stack([jnp.cos(x / 10.), jnp.sin(x / 10.)]).T + + if texts is not None: + texts %= 5 # For testing `pre_filter_fn` below. 
+ return z(image), z(texts), None + + + class DiscriminativeClassifierTest(tf.test.TestCase): + + def test_prepare_datasets(self): + + def generator(): + yield { + "image": tf.ones([5, 5, 3], tf.float32), + "label": 1, + } + yield { + "image": tf.ones([4, 4, 3], tf.float32), + "label": 2, + } + + ds = tf.data.Dataset.from_generator( + generator, + output_signature={ + "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32), + "label": tf.TensorSpec(shape=[], dtype=tf.int64), + }) + class_names = [ + "class1,class1a", + "class2", + ] + prompt_templates = [ + "test {}", + "test {} test", + ] + ds_img, ds_txt = discriminative_classifier.prepare_datasets( + ds, + class_names, + prompt_templates=prompt_templates, + pp_img="resize(2)", + pp_txt="copy_from(labels='texts')", + ) + + it_img = iter(ds_img) + batch = next(it_img) + self.assertAllEqual(1, batch["label"]) + self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"]) + batch = next(it_img) + self.assertAllEqual(2, batch["label"]) + self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"]) + + it_txt = iter(ds_txt) + batch = next(it_txt) + self.assertAllEqual(0, batch["label"]) + self.assertAllEqual("test class1", batch["labels"]) + batch = next(it_txt) + self.assertAllEqual(0, batch["label"]) + self.assertAllEqual("test class1 test", batch["labels"]) + batch = next(it_txt) + self.assertAllEqual(0, batch["label"]) + self.assertAllEqual("test class1a", batch["labels"]) + batch = next(it_txt) + self.assertAllEqual(0, batch["label"]) + self.assertAllEqual("test class1a test", batch["labels"]) + batch = next(it_txt) + self.assertAllEqual(1, batch["label"]) + self.assertAllEqual("test class2", batch["labels"]) + batch = next(it_txt) + self.assertAllEqual(1, batch["label"]) + self.assertAllEqual("test class2 test", batch["labels"]) + + def test_average_embeddings(self): + self.assertAllEqual(jnp.array([ + [2.], [4.], [8.], + ]), discriminative_classifier._average_embeddings( + embeddings=jnp.array([ + 1., 3., 3., 1., 
# label1 + 8., 0., # label2 + 32., 0., 0., 0., # label3 + ])[..., None], + labels=jnp.array([ + 0, 0, # label1 + 0, 0, # label1 (alias) + 1, 1, # label2 + 2, 2, # label3 + 2, 2, # label3 (alias) + ], jnp.int32), + num_classes=3, normalize=False)) + self.assertAllEqual( + jnp.array([ + [2**-.5, 2**-.5], + ]), + discriminative_classifier._average_embeddings( + embeddings=jnp.array([[2., 2.]]), + labels=jnp.array([0], jnp.int32), + num_classes=1, + normalize=True)) + + @mock.patch("big_vision.evaluators.proj." + "image_text.prompt_engineering.get_class_names") + @mock.patch("big_vision.evaluators.proj." + "image_text.prompt_engineering.get_prompt_templates") + @mock.patch("big_vision.evaluators.proj." + "image_text.discriminative_classifier._get_dataset_info") + def test_evaluate(self, get_dataset_info_mock, get_prompt_templates_mock, + get_class_names_mock): + per_device_batch_size = 10 # Make sure we have some unfiltered examples. + global_batch_size = per_device_batch_size * jax.device_count() + per_host_num_examples = int( + np.ceil(global_batch_size / jax.process_count())) + splits = { + "test": + tfds.core.SplitInfo( + name="test", shard_lengths=[per_host_num_examples], num_bytes=0) + } + + model = _Model() + params = model.init(jax.random.PRNGKey(0), None, None)["params"] + + prompt_templates = [ + "test prompt 1 {}", + "test prompt 2 {}", + ] + class_names = [ + f"test_class_{i}" for i in range(10) + ] + + get_prompt_templates_mock.return_value = prompt_templates + get_class_names_mock.return_value = class_names + get_dataset_info_mock.return_value.splits = splits + + def pre_filter_fn(features): + return features["label"] < 5 # matches `texts %= 5` above + + dataset_name = "cifar10_test" + with tfds.testing.mock_data(num_examples=per_host_num_examples): + evaluator = discriminative_classifier.Evaluator( + lambda p, b: model.apply({"params": p}, + b.get("image", None), + b.get("labels", None)), + dataset_names=[dataset_name], + prompt_templates="test_prompts", 
+ batch_size=global_batch_size, + devices=jax.devices(), + pp_img="copy_from(image='label')", + pp_txt="copy_from(labels='label')", + dataset_overrides={ + dataset_name: { + "dataset_name": "cifar10", + "class_names": "test_classes", + "pre_filter_fn": pre_filter_fn, + } + }, + first_class_name_only=True, + ) + results = evaluator.evaluate( + params, + dataset_name, + return_embeddings=True) + metrics = dict(evaluator.run(params)) + + # Assert all examples were processed. + self.assertLen(results["texts"]["embedding"], + len(class_names) * len(prompt_templates)) + self.assertLen(results["texts"]["average_embedding"], len(class_names)) + self.assertAllEqual( + sorted(results["texts"]["label"]), + [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9]) + # Note that above model makes perfect predictions by design. + self.assertEqual(1.0, results["accuracy"]) + self.assertEqual(1.0, metrics[f"{dataset_name}_accuracy"]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/big_vision/evaluators/proj/image_text/image_text_retrieval.py b/big_vision/evaluators/proj/image_text/image_text_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0f7b2d4b84b121b8a642ee301835d1dc59ed62 --- /dev/null +++ b/big_vision/evaluators/proj/image_text/image_text_retrieval.py @@ -0,0 +1,85 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Evaluates image-text retrieval results.""" +from typing import List, Mapping + +import numpy as np + +RECALL_THRESHOLDS = (1, 5, 10) + + +def text_to_image_retrieval_eval( + dist_matrix: np.ndarray, + text_image_correspondence: List[int]) -> Mapping[str, float]: + """Runs the text-to-image retrieval eval from the distance matrix. + + Args: + dist_matrix: Distance matrix between text and image embeddings (shape + N_IMAGES x N_TEXTS). + text_image_correspondence: Mapping between rows and columns of + `dist_matrix`, that is, a list of N_TEXTS integers n_i that represent that + the text embedding in column i corresponds to the image embedding in row + n_i. Please note that many texts can be assigned to the same image. For + instance, if we have 2 images and 4 texts (i.e. dist_matrix is 2x4), then + `text_image_correspondence = [0, 0, 1, 1]` means that the two first texts + correspond to the first image and the two last texts to the second image. + + Returns: + A dictionary with the Recall@k scores for k in RECALL_THRESHOLDS. + """ + per_text_ranks = dist_matrix.argsort(axis=0) + text_image_correspondence = np.array(text_image_correspondence) + + def recall_at(k): + wins = per_text_ranks[:k, :] == text_image_correspondence[None] + return wins.any(axis=0).mean() + + return { + f'Recall@{k}': recall_at(k) + for k in RECALL_THRESHOLDS + } + + +def image_to_text_retrieval_eval( + dist_matrix: np.ndarray, + text_image_correspondence: List[int]) -> Mapping[str, float]: + """Runs the image-to-text retrieval eval from the distance matrix. + + Args: + dist_matrix: Distance matrix between text and image embeddings (shape + N_IMAGES x N_TEXTS). + text_image_correspondence: Mapping between rows and columns of + `dist_matrix`, that is, a list of N_TEXTS integers n_i that represent that + the text embedding in column i corresponds to the image embedding in row + n_i. Please note that many texts can be assigned to the same image. 
For + instance, if we have 2 images and 4 texts (i.e. dist_matrix is 2x4), then + `text_image_correspondence = [0, 0, 1, 1]` means that the two first texts + correspond to the first image and the two last texts to the second image. + + Returns: + A dictionary with the Recall@k scores for k in RECALL_THRESHOLDS. + """ + per_image_ranks = dist_matrix.argsort(axis=1) + text_image_correspondence = np.array(text_image_correspondence) + + def recall_at(k): + top_k_images = text_image_correspondence[per_image_ranks[:, :k]] + wins = top_k_images == np.arange(len(per_image_ranks))[:, None] + return wins.any(axis=1).mean() + + return { + f'Recall@{k}': recall_at(k) + for k in RECALL_THRESHOLDS + } diff --git a/big_vision/evaluators/proj/image_text/image_text_retrieval_test.py b/big_vision/evaluators/proj/image_text/image_text_retrieval_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5125b17f77d9c0c3a15f0c358415334af0b6c216 --- /dev/null +++ b/big_vision/evaluators/proj/image_text/image_text_retrieval_test.py @@ -0,0 +1,86 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for image_text_retrieval.""" +from typing import Mapping + +from absl.testing import absltest +from absl.testing import parameterized +from big_vision.evaluators.proj.image_text import image_text_retrieval +import numpy as np + + +class ImTextRetrievalTest(parameterized.TestCase): + + @parameterized.parameters( + (np.array([[0.0, 0.0, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1], + [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4], + [0.5, 0.4, 0.1, 0.5, 0.0, 0.0, 0.8, 0.3], + [0.5, 0.4, 0.1, 0.5, 0.3, 0.2, 0.0, 0.0]]), { + 'Recall@1': 1.0, + 'Recall@5': 1.0, + 'Recall@10': 1.0 + }), # + (np.array([[0.8, 0.8, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1], + [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4], + [0.5, 0.4, 0.1, 0.5, 0.0, 0.8, 0.8, 0.3], + [0.5, 0.4, 0.1, 0.5, 0.4, 0.2, 0.3, 0.3]]), { + 'Recall@1': 0.5, + 'Recall@5': 0.75, + 'Recall@10': 1.0 + })) + def test_image_to_text_retrieval_eval(self, dist_matrix: np.ndarray, + expected: Mapping[str, float]): + """Checks `image_to_text_retrieval_eval`. + + Args: + dist_matrix: Distance matrix between image (rows) and text (columns). + expected: Expected eval results. + """ + self.assertEqual( + image_text_retrieval.image_to_text_retrieval_eval( + dist_matrix, [0, 0, 1, 1, 2, 2, 3, 3]), expected) + + @parameterized.parameters( + (np.array([[0.0, 0.0, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1], + [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4], + [0.5, 0.4, 0.1, 0.5, 0.0, 0.0, 0.8, 0.3], + [0.5, 0.4, 0.1, 0.5, 0.3, 0.2, 0.0, 0.0]]), { + 'Recall@1': 1.0, + 'Recall@5': 1.0, + 'Recall@10': 1.0 + }), # + (np.array([[0.8, 0.8, 0.1, 0.5, 0.1, 0.2, 0.1, 0.1], + [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4], + [0.5, 0.4, 0.1, 0.5, 0.0, 0.8, 0.8, 0.3], + [0.5, 0.4, 0.1, 0.5, 0.4, 0.2, 0.3, 0.3]]), { + 'Recall@1': 0.375, + 'Recall@5': 1.0, + 'Recall@10': 1.0 + })) + def test_image_text_retrieval(self, dist_matrix: np.ndarray, + expected: Mapping[str, float]): + """Checks `text_to_image_retrieval_eval`. 
+ + Args: + dist_matrix: Distance matrix between image (rows) and text (columns). + expected: Expected eval results. + """ + self.assertEqual( + image_text_retrieval.text_to_image_retrieval_eval( + dist_matrix, [0, 0, 1, 1, 2, 2, 3, 3]), expected) + + +if __name__ == '__main__': + absltest.main() diff --git a/big_vision/evaluators/proj/image_text/prompt_engineering.py b/big_vision/evaluators/proj/image_text/prompt_engineering.py new file mode 100644 index 0000000000000000000000000000000000000000..52450f6a7640d9865dbf19e11945281c9464657d --- /dev/null +++ b/big_vision/evaluators/proj/image_text/prompt_engineering.py @@ -0,0 +1,112 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for generating zero-shot prompts.""" + +import re +import string +from typing import Sequence + +from absl import logging +from big_vision.datasets.imagenet import class_names as imagenet_class_names +from big_vision.evaluators.proj.image_text import prompt_engineering_constants +import tensorflow_datasets as tfds + + +_CLASS_NAMES = { # For each dataset, maps from a source to its class names. 
+ "imagenet2012": { + "clip": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, + }, + "grand-vision:imagenet2012": { + "clip": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, + }, + "imagenet_a": { + "clip": [ + imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES[i] + for i in imagenet_class_names.IMAGENET_A_LABELSET + ] + }, + "imagenet_r": { + "clip": [ + imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES[i] + for i in imagenet_class_names.IMAGENET_R_LABELSET + ] + }, + "imagenet_v2": { + "clip": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, + }, +} + +_PROMPT_TEMPLATES = { + "class_name_only": ["{}"], + "clip_paper": prompt_engineering_constants.CLIP_PAPER_PROMPT_TEMPLATES, + "clip_best": prompt_engineering_constants.CLIP_BEST_PROMPT_TEMPLATES, +} + + +def get_class_names(*, dataset_name, source="dataset_info", canonicalize=True): + """Returns class name for `dataset_name` from `source`.""" + if isinstance(source, str): + if source.startswith("dataset_info:"): + name = source[len("dataset_info:"):] + class_names = tfds.builder(dataset_name).info.features[name].names + else: + class_names = _CLASS_NAMES[dataset_name][source] + else: + assert isinstance(source, Sequence) and all( + map(lambda s: isinstance(s, str), source)), source + class_names = source + if canonicalize: + class_names = [ + canonicalize_text(name, keep_punctuation_exact_string=",") + for name in class_names + ] + logging.info("Using %d class_names: %s", len(class_names), class_names) + return class_names + + +def get_prompt_templates(prompt_templates_name, + *, + canonicalize=True): + """Returns prompt templates.""" + prompts_templates = _PROMPT_TEMPLATES[prompt_templates_name] + if canonicalize: + prompts_templates = [ + canonicalize_text(name, keep_punctuation_exact_string="{}") + for name in prompts_templates + ] + logging.info("Using %d prompts_templates: %s", len(prompts_templates), + prompts_templates) + return prompts_templates + + +def canonicalize_text(text, *, 
keep_punctuation_exact_string=None): + """Returns canonicalized `text` (lowercase and punctuation removed). + + Args: + text: string to be canonicalized. + keep_punctuation_exact_string: If provided, then this exact string is kept. + For example providing '{}' will keep any occurrences of '{}' (but will + still remove '{' and '}' that appear separately). + """ + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans("", "", string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + text = re.sub(r"\s+", " ", text) + return text.strip() diff --git a/big_vision/evaluators/proj/image_text/prompt_engineering_constants.py b/big_vision/evaluators/proj/image_text/prompt_engineering_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..0997971f639711e592f83f71daafd0d6d47f6e49 --- /dev/null +++ b/big_vision/evaluators/proj/image_text/prompt_engineering_constants.py @@ -0,0 +1,110 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Constants used by the module `prompt_engineering` in the same directory.""" + +CLIP_PAPER_PROMPT_TEMPLATES = [ + 'a bad photo of a {}.', + 'a photo of many {}.', + 'a sculpture of a {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of the {}.', + 'a rendering of a {}.', + 'graffiti of a {}.', + 'a bad photo of the {}.', + 'a cropped photo of the {}.', + 'a tattoo of a {}.', + 'the embroidered {}.', + 'a photo of a hard to see {}.', + 'a bright photo of a {}.', + 'a photo of a clean {}.', + 'a photo of a dirty {}.', + 'a dark photo of the {}.', + 'a drawing of a {}.', + 'a photo of my {}.', + 'the plastic {}.', + 'a photo of the cool {}.', + 'a close-up photo of a {}.', + 'a black and white photo of the {}.', + 'a painting of the {}.', + 'a painting of a {}.', + 'a pixelated photo of the {}.', + 'a sculpture of the {}.', + 'a bright photo of the {}.', + 'a cropped photo of a {}.', + 'a plastic {}.', + 'a photo of the dirty {}.', + 'a jpeg corrupted photo of a {}.', + 'a blurry photo of the {}.', + 'a photo of the {}.', + 'a good photo of the {}.', + 'a rendering of the {}.', + 'a {} in a video game.', + 'a photo of one {}.', + 'a doodle of a {}.', + 'a close-up photo of the {}.', + 'a photo of a {}.', + 'the origami {}.', + 'the {} in a video game.', + 'a sketch of a {}.', + 'a doodle of the {}.', + 'a origami {}.', + 'a low resolution photo of a {}.', + 'the toy {}.', + 'a rendition of the {}.', + 'a photo of the clean {}.', + 'a photo of a large {}.', + 'a rendition of a {}.', + 'a photo of a nice {}.', + 'a photo of a weird {}.', + 'a blurry photo of a {}.', + 'a cartoon {}.', + 'art of a {}.', + 'a sketch of the {}.', + 'a embroidered {}.', + 'a pixelated photo of a {}.', + 'itap of the {}.', + 'a jpeg corrupted photo of the {}.', + 'a good photo of a {}.', + 'a plushie {}.', + 'a photo of the nice {}.', + 'a photo of the small {}.', + 'a photo of the weird {}.', + 'the cartoon {}.', + 'art of the {}.', + 'a drawing of the {}.', + 'a 
photo of the large {}.', + 'a black and white photo of a {}.', + 'the plushie {}.', + 'a dark photo of a {}.', + 'itap of a {}.', + 'graffiti of the {}.', + 'a toy {}.', + 'itap of my {}.', + 'a photo of a cool {}.', + 'a photo of a small {}.', + 'a tattoo of the {}.', + '{}', +] + +CLIP_BEST_PROMPT_TEMPLATES = [ + 'itap of a {}.', + 'a bad photo of the {}.', + 'a origami {}.', + 'a photo of the large {}.', + 'a {} in a video game.', + 'art of the {}.', + 'a photo of the small {}.', + '{}', +] diff --git a/big_vision/evaluators/proj/image_text/prompt_engineering_test.py b/big_vision/evaluators/proj/image_text/prompt_engineering_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e6833c60157d01ba9d787115e643b85506e57d4f --- /dev/null +++ b/big_vision/evaluators/proj/image_text/prompt_engineering_test.py @@ -0,0 +1,48 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class PromptEngineeringTest(absltest.TestCase):
  """Unit tests for the text helpers in `prompt_engineering`."""

  def test_canonicalize_text(self):
    """Checks `canonicalize_text` with and without a protected string."""
    canonicalize = prompt_engineering.canonicalize_text
    # Each entry: (input text, keyword arguments, expected canonical form).
    keep_braces = dict(keep_punctuation_exact_string="{}")
    cases = (
        ("test_test", {}, "test test"),
        ("test___test", {}, "test test"),
        ("test", {}, "test"),
        ("test.", {}, "test"),
        (" test ", {}, "test"),
        ("test\ntest", {}, "test test"),
        ("test test", {}, "test test"),
        ("test {}", {}, "test"),
        ("test {}", keep_braces, "test {}"),
        (" test {}...", keep_braces, "test {}"),
        ("test {} {} {}", keep_braces, "test {} {} {}"),
    )
    for raw, kwargs, expected in cases:
      self.assertEqual(canonicalize(raw, **kwargs), expected)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-host image->text and text->image retrieval evaluation. + +Example how to add to config: + + config.evals = {} + config.evals.retrieval = dict(log_steps=1200, type='proj.image_text.retrieval') + config.evals.retrieval.dataset = 'coco_captions' + config.evals.retrieval.txt_name = ('captions', 'text') + # Note that initial "decode|" is not needed. + config.evals.retrieval.pp_img = 'resize(224)|value_range(-1,1)' + # Raw text strings use key "texts" in feature dict. The evaluator expects + # tokenized text with key "labels". + config.evals.retrieval.pp_txt = ( + 'tokenize(max_len=16, eos="sticky", pad_value=1, inkey="texts", ' + ' outkey="labels")') + +Example to support precomputed data: +See `big_vision/configs/proj/image_text/lit.py`. +""" + +import functools +import operator +import time + +from absl import logging +from big_vision import input_pipeline +from big_vision.evaluators.proj.image_text import image_text_retrieval +import big_vision.pp.builder as pp_builder +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + + +# Temporary global flag to facilitate backwards compatibility. Will be removed +# by the end of year 2023. 
def _with_infinite_padding(dataset):
  """Appends an endless stream of masked-out filler examples to `dataset`.

  Every original example gains `mask=True`. After the last original example
  the returned dataset keeps yielding all-zero examples (built from
  `dataset.element_spec`) with `mask=False`, so it never runs out.

  Args:
    dataset: A `tf.data.Dataset` whose elements are feature dictionaries.

  Returns:
    The dataset with a "mask" feature, followed by infinite padding.
  """

  def _zero_example(spec):
    # Leading axis of size one so `from_tensor_slices` yields one element.
    return tf.zeros(spec.shape, spec.dtype)[None]

  def _mark_as_real(features):
    return dict(mask=True, **features)

  padding = tf.nest.map_structure(_zero_example, dataset.element_spec)
  padding["mask"] = [False]
  padding_ds = tf.data.Dataset.from_tensor_slices(padding)
  real_ds = dataset.map(
      _mark_as_real, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return real_ds.concatenate(padding_ds.repeat(None))
def _split_and_batch(dataset_name, batch_size, split, get_ds, data_dir=None):
  """Loads this process' part of `split` and returns padded + batched datasets.

  Args:
    dataset_name: TFDS name handed to `tfds.builder`.
    batch_size: Batch size of the returned datasets; must be divisible by
      `jax.device_count()`.
    split: Name of the split; it is sub-split per JAX process.
    get_ds: Callable invoked as `get_ds(dataset, offset=...)` that returns an
      `(images, texts)` pair of unbatched datasets.
    data_dir: Optional directory passed on to `tfds.builder`.

  Returns:
    A tuple `(ds_images, ds_texts)`, each infinitely padded and batched.
  """
  num_devices = jax.device_count()
  assert batch_size % num_devices == 0, (
      f"batch_size={batch_size} % jax.device_count()={num_devices}")

  tfds_builder = tfds.builder(dataset_name, data_dir=data_dir)
  split_size = _get_dataset_info(tfds_builder).splits[split].num_examples
  process_ds = tfds_builder.as_dataset(
      split=tfds.split_for_jax_process(split))
  # Each process shifts its ids by one full split size.
  images, texts = get_ds(
      process_ds, offset=jax.process_index() * split_size)

  return tuple(
      _with_infinite_padding(unbatched).batch(batch_size)
      for unbatched in (images, texts))
devices, + data_dir=None, + split="test", + cache_final=True): + """Initializes a new zero-shot image/text retrieval evaluator. + + See `prepare_datasets()` for details on how the dataset is pre-processed. + + Args: + predict_fn: Prediction function with signature + `zimg, ztxt, out = predict_fn(params, images, texts)` + dataset: The TFDS dataset name of the eval data. + pp_img: Preprocessing string for images. Preprocessed features should + contain key "image" with value that can be batched and is suitable for + `predict_fn(images)` input``. + pp_txt: Preprocessing string for texts. Can expect "texts" key as an input + (shape=[], dtype=string), and is expected to produce "labels" key that + is suitable for `predict_fn(texts)` input. + txt_name: The name of the feature of captions (can be a tuple to look up a + value in a nested feature dictionary). Expected shape=[None], + dtype=string. specified then items are used as lookup path. + batch_size: Global batch size. + devices: list of devices. + data_dir: Optional dir to load the TFDS dataset from. + split: The split of the eval data. + cache_final: Wether preprocessed dataset should be cached. 
  def _embed(self, name, train_state, ds, embed_fn, id_names):
    """Embeds feature `name` of every batch in `ds` using `embed_fn`.

    Iterates over `ds` until the batch before the current one contained no
    unmasked example, then drops all rows whose "mask" is false.

    Args:
      name: Feature name to be embedded (key into each batch).
      train_state: train_state for the predict_fn; passed as first argument to
        `embed_fn`.
      ds: The dataset. Every batch must provide `name`, "mask" and all
        `id_names`.
      embed_fn: A function called as `embed_fn(train_state, batch[name])` that
        returns the embeddings. Also used as the key for compile-time logging.
      id_names: An iterable of feature names that should be collected.

    Returns:
      A dictionary with "embeddings" and `id_names` as keys, restricted to the
      rows where "mask" is true.
    """
    ns = []  # Per-step count of unmasked examples.
    embeddings = []
    # "mask" is gathered like an id so padded rows can be dropped at the end.
    ids = {id_name: [] for id_name in list(id_names) + ["mask"]}

    t0 = time.time()

    ds_b = input_pipeline.start_global(ds, self.devices)
    for batch in ds_b:
      ns.append(jax.device_get(self._count_p(batch["mask"])))

      # Due to infinite padding, this loop will never end. We will stop once
      # all processes only process padded data. We don't check the latest
      # DeviceArray `ns[-1]` because we want to keep our computation async for
      # efficiency reasons.
      # NOTE(review): `jax.device_get` above already fetches the count in the
      # same step, so the lookback may no longer buy asynchrony - confirm.
      if len(ns) >= 2 and ns[-2] == 0:
        break

      embs = embed_fn(train_state, batch[name])
      if embed_fn not in self._compiled:
        # First call of this function: log the elapsed time once and restart
        # the timer so the final log excludes it.
        logging.info("Compiled %s embeddings in %.3fs", name, time.time() - t0)
        t0 = time.time()
        self._compiled.add(embed_fn)

      embeddings.append(jax.device_get(embs))
      for id_name in ids:
        ids[id_name].append(jax.device_get(self._all_gather_p(batch[id_name])))

    # Only access DeviceArray at end of loop for better efficiency.
    ns = np.array(ns)
    embeddings = np.concatenate(embeddings)
    ids = {k: np.concatenate(v) for k, v in ids.items()}
    masks = ids.pop("mask").astype(bool)
    logging.info("Processed %s in %d steps - ...%s", name, len(ns), ns[-10:])
    n = ns.sum()
    logging.info("Totalling %d %s in %.3fs", n, name, time.time() - t0)
    return {
        "embeddings": embeddings[masks],
        **{k: v[masks] for k, v in ids.items()},
    }
results[direction].items()] diff --git a/big_vision/evaluators/proj/image_text/retrieval_test.py b/big_vision/evaluators/proj/image_text/retrieval_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6a993df88819d8d626710bb8135444fba2130a --- /dev/null +++ b/big_vision/evaluators/proj/image_text/retrieval_test.py @@ -0,0 +1,178 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for retrieval.""" + +from unittest import mock + +from big_vision.evaluators.proj.image_text import retrieval +from big_vision.pp import ops_general # pylint: disable=unused-import +from big_vision.pp import ops_image # pylint: disable=unused-import +from big_vision.pp import registry +import chex +import flax.linen as nn +import jax +import jax.numpy as jnp +import tensorflow as tf +import tensorflow_datasets as tfds + + +def _get_test_texts2labels(): + + def pp(features): + features["labels"] = tf.strings.to_number(features["texts"]) + return features + + return pp + + +def _get_copy_from(**key_map): + + def copy_from(d): + d = dict(d) + for k1, k2 in key_map.items(): + d[k1] = d[k2] + return d + + return copy_from + + +class _Model(nn.Module): + + @nn.compact + def __call__(self, image, texts): + self.param("x", lambda _: 0.) + + def z(x): + if x is not None: + batch_size = len(x) + # Note that the returned vector is most similar with other vectors + # generated from the same underlying `x[:]`. 
+ x = jnp.concatenate([100 * jnp.ones([batch_size, 1]), x[:, None]], + axis=1) + return x / jnp.linalg.norm(x, axis=1)[:, None] + + return z(image), z(texts), None + + +def setUpModule(): + chex.set_n_cpu_devices(8) + + +class RetrievalTest(tf.test.TestCase): + + def test_prepare_datasets(self): + + def generator(): + yield { + "image": tf.ones([5, 5, 3], tf.float32), + "captions": { + "text": tf.constant(["11", "12"]) + } + } + yield { + "image": tf.ones([4, 4, 3], tf.float32), + "captions": { + "text": tf.constant(["21", "22", "23"]) + } + } + + ds = tf.data.Dataset.from_generator( + generator, + output_signature={ + "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32), + "captions": { + "text": tf.TensorSpec(shape=[None], dtype=tf.string), + }, + }) + with registry.temporary_ops(test_texts2labels=_get_test_texts2labels): + ds_img, ds_txt = retrieval.prepare_datasets( + ds, + pp_img="resize(2)", + pp_txt="test_texts2labels()", + txt_name=("captions", "text"), + ) + it_img = iter(ds_img) + it_txt = iter(ds_txt) + batch = next(it_img) + self.assertAllEqual(batch["id"], 0) + self.assertAllEqual(batch["image"], tf.ones([2, 2, 3])) + batch = next(it_img) + self.assertAllEqual(batch["id"], 1) + self.assertAllEqual(batch["image"], tf.ones([2, 2, 3])) + batch = next(it_txt) + self.assertAllEqual(batch["id"], 0) + self.assertAllEqual(batch["caption_i"], 0) + self.assertAllEqual(batch["labels"], 11.0) + batch = next(it_txt) + self.assertAllEqual(batch["id"], 0) + self.assertAllEqual(batch["caption_i"], 1) + self.assertAllEqual(batch["labels"], 12.0) + batch = next(it_txt) + self.assertAllEqual(batch["id"], 1) + self.assertAllEqual(batch["caption_i"], 0) + self.assertAllEqual(batch["labels"], 21.0) + batch = next(it_txt) + self.assertAllEqual(batch["id"], 1) + self.assertAllEqual(batch["caption_i"], 1) + self.assertAllEqual(batch["labels"], 22.0) + batch = next(it_txt) + self.assertAllEqual(batch["id"], 1) + self.assertAllEqual(batch["caption_i"], 2) + 
self.assertAllEqual(batch["labels"], 23.0) + + def test_evaluate(self): + per_device_batch_size = 2 + batch_size = per_device_batch_size * jax.device_count() + num_examples = 1 * batch_size + 1 + splits = { + "test": + tfds.core.SplitInfo( + name="test", shard_lengths=[num_examples], num_bytes=0) + } + + model = _Model() + params = model.init(jax.random.PRNGKey(0), None, None)["params"] + + with tfds.testing.mock_data(num_examples=num_examples): + info_mock = mock.Mock() + info_mock.splits = splits + with mock.patch.object(retrieval, "_get_dataset_info", + lambda _: info_mock): + with registry.temporary_ops(copy_from=_get_copy_from): + evaluator = retrieval.Evaluator( + lambda p, b: model.apply({"params": p}, + b.get("image", None), + b.get("labels", None)), + dataset="coco_captions", + batch_size=batch_size, + devices=jax.devices(), + txt_name=("captions", "text"), + pp_img="copy_from(image='id')", + pp_txt="copy_from(labels='id')", + ) + results = evaluator.evaluate(params) + + # Assert all examples were processed. + self.assertLen(results["images"]["embeddings"], num_examples) + self.assertLen(results["images"]["id"], num_examples) + # Assert no padding was processed (expects exactly one (=first) image.id=0 + self.assertEqual((results["images"]["id"] == 0).sum(), 1) + # Expect perfect ITR with above _Model()... + self.assertEqual(results["img2txt"]["Recall@1"], 1.0) + self.assertEqual(results["txt2img"]["Recall@5"], 1.0) + + +if __name__ == "__main__": + tf.test.main() diff --git a/big_vision/evaluators/proj/paligemma/perplexity.py b/big_vision/evaluators/proj/paligemma/perplexity.py new file mode 100644 index 0000000000000000000000000000000000000000..38b5ed077b698ec922785e8e87b462409ef737d2 --- /dev/null +++ b/big_vision/evaluators/proj/paligemma/perplexity.py @@ -0,0 +1,58 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# Cache the function such that it won't always recompile (in mean evaluator).
@functools.cache
def perplexity(
    predict_fn, key='labels', shift_labels=True):
  """Builds the per-batch loss function used by the perplexity evaluator.

  Args:
    predict_fn: Called as `predict_fn(train_state, batch, **kw)`; the first
      element of its result is used as logits.
    key: Name of the batch entry holding the labels.
    shift_labels: If true, the first position along axis 1 is dropped from
      both labels and weights before computing the loss.

  Returns:
    A function `(train_state, batch, **kw) -> {'sum': ..., 'avg': ...}` with
    the weighted softmax cross-entropy and that value divided by the
    per-example sum of weights.
  """

  def _perplexity_fn(train_state, batch, **kw):
    logits, _ = predict_fn(train_state, batch, **kw)

    targets = batch[key]
    # Without an explicit loss mask every position counts equally.
    loss_mask = batch.get('mask_loss', jnp.ones_like(targets))

    if shift_labels:
      targets, loss_mask = targets[:, 1:], loss_mask[:, 1:]

    summed = u.weighted_softmax_xent(
        logits=logits, labels=targets, weights=loss_mask,
        reduction=False, normalize=False)
    # Lower-bound the denominator so all-zero weights cannot divide by zero.
    denominator = jnp.clip(loss_mask.sum(axis=1), 2e-38)

    return {'sum': summed, 'avg': summed / denominator}

  return _perplexity_fn
+ super().__init__(perplexity(predict_fn, key, shift_labels), *a, **kw) diff --git a/big_vision/evaluators/proj/paligemma/transfers/chartqa.py b/big_vision/evaluators/proj/paligemma/transfers/chartqa.py new file mode 100644 index 0000000000000000000000000000000000000000..41dadb9e1fb148e83ecf4b09673d3346722c7b96 --- /dev/null +++ b/big_vision/evaluators/proj/paligemma/transfers/chartqa.py @@ -0,0 +1,139 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluator for ChartQA variants.""" + +import functools + +import big_vision.evaluators.common as c +import big_vision.pp.tokenizer +import big_vision.utils as u + + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = "jit" + + +class Evaluator: + """Evaluator for simple VQA tasks.""" + + def __init__( + self, predict_fn, tokenizer, to_lower=False, + outfile="{workdir}/{split}.json", + out_question_key="question_id", out_answer_key="answer", + *, data, devices, **kw): + self.get_data_iter, self.steps = c.eval_input_pipeline( + keep_on_cpu={"answer", "question_id"}, data=data, devices=devices, **kw) + + self.outfile = c.resolve_outfile(outfile, split=data.get("split")) + self.out_question_key = out_question_key + self.out_answer_key = out_answer_key + + # We'll need the tokenizer to detokenize the model outputs later. 
+ self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer) + self.postproc = (lambda s: s.lower()) if to_lower else lambda s: s + self.decode = functools.partial( + predict_fn, devices=devices, eos_token=self.tok.eos_token) + + def run(self, train_state): + """Does one evaluation run, yields metrics.""" + + accuracies = [] + relaxed_accuracies = [] + json_out = [] + for _, batch in zip(range(self.steps), self.get_data_iter()): + # (batch, seqlen) array of decoded generated tokens. + tokens = self.decode(train_state, batch) + + # (local_batch,) that indicates padding examples (0) vs real examples (1). + tokens = u.get_local_slice_from_fsarray(tokens) + ex_masks = u.get_local_slice_from_fsarray(batch["_mask"]) + + # Turn predictions into texts and then scores, one by one. + for i in range(len(tokens)): + if ex_masks[i] == 0: # Skip last-batch padding examples + continue + + answer = self.postproc(self.tok.to_str(tokens[i], stop_at_eos=True)) + + gt = self.postproc(batch["answer"][i]) + accuracies.append(float(answer == gt)) + relaxed_accuracies.append(_relaxed_match(gt, answer)) + json_out.append({ + self.out_question_key: batch["question_id"][i].item(), + self.out_answer_key: answer, + "gt": gt, + "relaxed_match": relaxed_accuracies[-1], + }) + + # At this point `accuracies` is a list of per-example scores. However, + # remember that each host holds a different subset of the examples! So if + # we were to just return the mean accuracy here, we would effectively only + # have evaluated on the main host's (who writes metrics) subset! + # So now, we need to compute global means. + # There is one more caveat: `process_sum` needs the summands on each host + # to have the same size. So we either need to include dummy values for + # the padding examples (last batch, annoying), or we only sum scalars as in + # sufficient statistics, which we do here. 
+ sum_accs, sum_relaxed_accs, num = c.process_sum( + [sum(accuracies), sum(relaxed_accuracies), len(accuracies)]) + + # Yielding metric_name, value means logging the metric. + yield "acc", sum_accs / num + yield "relaxed_acc", sum_relaxed_accs / num + yield "num", num # Just for sanity checks. + c.multiprocess_write_json(self.outfile, json_out) + + +def _to_float(text: str) -> float | None: + try: + if text.endswith("%"): + # Convert percentages to floats. + return float(text.rstrip("%")) / 100.0 + else: + return float(text) + except ValueError: + return None + + +def _relaxed_match(target: str, + prediction: str, + max_relative_error: float = 0.05) -> bool: + """Calculates relaxed correctness. + + The correctness tolerates certain error ratio defined by max_relative_error. + See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1: + “Following Methani et al. (2020), we use a relaxed accuracy measure for the + numeric answers to allow a minor inaccuracy that may result from the automatic + data extraction process. We consider an answer to be correct if it is within + 5% of the gold answer. For non-numeric answers, we still need an exact match + to consider an answer to be correct.” + + Args: + target: Target string. + prediction: Predicted string. + max_relative_error: Maximum relative error. + + Returns: + Whether the prediction was correct given the specified tolerance. + """ + prediction_float = _to_float(prediction) + target_float = _to_float(target) + # When the target is 0 is always required an exact match. 
+ if prediction_float is not None and target_float: + relative_error = abs(prediction_float - target_float) / abs(target_float) + return relative_error <= max_relative_error + else: + return prediction == target diff --git a/big_vision/evaluators/proj/paligemma/transfers/coco_caption.py b/big_vision/evaluators/proj/paligemma/transfers/coco_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..69721ff57dc061432c7e06e2cd3f1c5b7f9a4cde --- /dev/null +++ b/big_vision/evaluators/proj/paligemma/transfers/coco_caption.py @@ -0,0 +1,145 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluator for caption generation metrics used for the MS COCO dataset.""" +import collections +import functools +import os +import tempfile + +import big_vision.evaluators.common as c +import big_vision.input_pipeline +import big_vision.pp.builder +import big_vision.pp.tokenizer +import big_vision.utils as u + +from pycocoevalcap.bleu import bleu +from pycocoevalcap.cider import cider +from pycocoevalcap.meteor import meteor +from pycocoevalcap.rouge import rouge +from pycocoevalcap.spice import spice +from pycocoevalcap.tokenizer import ptbtokenizer + +import jax + +from tensorflow.io import gfile + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = "jit" + + +class Evaluator: + """Evaluator for caption generation metrics used for the MS COCO dataset. 
+ + See https://arxiv.org/pdf/1504.00325.pdf or the repository implementing it + https://github.com/tylin/coco-caption for details on the metrics. This code + uses the python3 pip package from: https://github.com/salaniz/pycocoevalcap + + Note that both the model caption and the ground truth reference captions are + further processed with the PTBTokenizer before computing scores. + + `predict_fn` accepts arbitrary dictionaries of parameters and data, where + the data dictionary is produced by the `pp_fn` op. It is expected to output a + dict containing tokenized captions. + + `pp_fn` must have fields: "image/id" and "captions". + """ + + def __init__( + self, predict_fn, tokenizer=None, + metrics=("cider",), # Default to only cider. We often just look at that. + preds_outfile="{workdir}/{name}_{split}_preds.json", + annot_outfile="{workdir}/{name}_{split}_annotations.json", + *, data, devices, **kw + ): + self.get_data_iter, self.steps = c.eval_input_pipeline( + keep_on_cpu={"image/id", "captions"}, data=data, devices=devices, **kw) + + self.preds_outfile = c.resolve_outfile( + preds_outfile, name=data.get("name"), split=data.get("split")) + self.annot_outfile = c.resolve_outfile( + annot_outfile, name=data.get("name"), split=data.get("split")) + + self.metrics = metrics + self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer) + self.decode = functools.partial( + predict_fn, devices=devices, eos_token=self.tok.eos_token) + + def run(self, train_state): + """Run eval.""" + gts = [] + res = [] + + for _, batch in zip(range(self.steps), self.get_data_iter()): + # (batch, seqlen) array of decoded generated tokens. 
+ tokens = self.decode(train_state, batch) + + # (local_batch,) + tokens = u.get_local_slice_from_fsarray(tokens) + ex_masks = u.get_local_slice_from_fsarray(batch["_mask"]) + + image_ids = batch["image/id"][ex_masks] + pred_captions = self.tok.to_str(tokens[ex_masks]) + + for image_id, caption in zip(image_ids, pred_captions): + res.append({"image_id": image_id.item(), "caption": caption}) + + for image_id, captions in zip(image_ids, batch["captions"]): + for caption in captions: + gts.append({"image_id": image_id.item(), "caption": caption.item()}) + + # Write model outputs following: https://cocodataset.org/#format-results + # Use same format for gt although that is not the usual format for them. + res = c.multiprocess_write_json(self.preds_outfile, res) + gts = c.multiprocess_write_json(self.annot_outfile, gts) + + if jax.process_index(): # Host0 gets all preds and does eval. + return + + outs = self.evaluate(gts, res) + for key, score in outs.items(): + yield key, score + + def evaluate(self, gt_annotations, res_annotations): + """Creates scorers and run evaluation.""" + scorers = { + "rouge": rouge.Rouge, + "cider": cider.Cider, + "bleu-4": bleu.Bleu, + "spice": spice.Spice, + "meteor": meteor.Meteor, + } + + # Reformat gts and res from [{"image_id": int|str, "caption": str}] to + # {int_image_id: [{"caption": str}]} as expected by tokenizer and scorers. + # Note there are multiple reference captions for the ground truth but only + # one for the model predictions. + iid_map = collections.defaultdict(lambda: len(iid_map)) + res = {iid_map[x["image_id"]]: [x] for x in res_annotations} + gts = collections.defaultdict(list) + for x in gt_annotations: + gts[iid_map[x["image_id"]]].append(x) + assert sorted(gts.keys()) == sorted(res.keys()) + + # Tokenize captions and predictions using coco tokenizer. 
+ coco_tokenizer = ptbtokenizer.PTBTokenizer() + gts = coco_tokenizer.tokenize(gts) + res = coco_tokenizer.tokenize(res) + + scores = {} + for metric in self.metrics: + scorer = scorers[metric]() + scores[metric], _ = scorer.compute_score(gts, res) + return scores diff --git a/big_vision/evaluators/proj/paligemma/transfers/pope.py b/big_vision/evaluators/proj/paligemma/transfers/pope.py new file mode 100644 index 0000000000000000000000000000000000000000..7ef4df53f89f5929b644e6b95881df7c77c3cdbb --- /dev/null +++ b/big_vision/evaluators/proj/paligemma/transfers/pope.py @@ -0,0 +1,135 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluator for the POPE dataset (https://github.com/RUCAIBox/POPE). + +POPE is a binary classification dataset with ground-truth answers being either +'yes' or 'no'. +""" + +import functools + +import big_vision.datasets.core +import big_vision.evaluators.common as c +import big_vision.input_pipeline +import big_vision.pp.builder +import big_vision.pp.tokenizer +import big_vision.utils as u + + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = "jit" + + +class Evaluator: + """Evaluator for the POPE task. + + This evaluator expects the batch to contain a field `question_id` and a field + `answer` for single ground truth or `answers` for multiple ground truths. 
+ + The field names used when writting the json result can be controlled with + `out_question_key` and `out_answer_key`. + """ + + def __init__( + self, + predict_fn, + data, + pp_fn, + tokenizer, + batch_size, + *, + devices, + outfile="{workdir}/{split}.json", + out_question_key="question_id", + out_answer_key="answer" + ): + + self.outfile = c.resolve_outfile(outfile, split=data.get("split")) + self.out_question_key = out_question_key + self.out_answer_key = out_answer_key + # This will mostly look the same across all evaluators, preparing data: + data = big_vision.datasets.core.get(**data) + pp_fn = big_vision.pp.builder.get_preprocess_fn(pp_fn) + self.ds, self.steps = big_vision.input_pipeline.make_for_inference( + data.get_tfdata(ordered=True), + pp_fn, + batch_size, + num_ex_per_process=data.num_examples_per_process(), + ) + # The `keep_on_cpu=` argument lists the data keys that, if they exist, we + # do NOT want to ship to the TPUs and instead just keep in host memory. + # Typically ground-truth and metadata, that is often of string type. + self.data_iter = big_vision.input_pipeline.start_global( + self.ds, devices, keep_on_cpu={"answer", "question_id"} + ) + # We'll need the tokenizer to detokenize the model outputs later. + self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer) + self.decode = functools.partial( + predict_fn, devices=devices, eos_token=self.tok.eos_token + ) + + def run(self, train_state): + """Does one evaluation run, yields metrics.""" + + accuracies = [] + valid = [] + json_out = [] + for _, batch in zip(range(self.steps), self.data_iter): + # (batch, seqlen) array of decoded generated tokens. + tokens = self.decode(train_state, batch) + + # (local_batch,) that indicates padding examples (0) vs real examples (1). + tokens = u.get_local_slice_from_fsarray(tokens) + ex_masks = u.get_local_slice_from_fsarray(batch["_mask"]) + + # Turn predictions into texts and then scores, one by one. 
+ for i in range(len(tokens)): + if ex_masks[i] == 0: # Skip last-batch padding examples + continue + + answer = self.tok.to_str(tokens[i], stop_at_eos=True).lower() + gt = batch["answer"][i] + accuracies.append(float(answer == gt)) + valid.append(float(answer in ("yes", "no"))) + + json_out.append( + { + self.out_question_key: batch["question_id"][i].item(), + self.out_answer_key: answer, + } + ) + + # At this point `accuracies` is a list of per-example scores. However, + # remember that each host holds a different subset of the examples! So if + # we were to just return the mean accuracy here, we would effectively only + # have evaluated on the main host's (who writes metrics) subset! + # So now, we need to compute global means. + # There is one more caveat: `process_sum` needs the summands on each host + # to have the same size. So we either need to include dummy values for + # the padding examples (last batch, annoying), or we only sum scalars as in + # sufficient statistics, which we do here. + sum_accs, sum_valid, num = c.process_sum([ + sum(accuracies), + sum(valid), + len(accuracies), + ]) + + if num: + yield "acc", sum_accs / num + yield "valid_percent", sum_valid / num + yield "num", num + + c.multiprocess_write_json(self.outfile, json_out) diff --git a/big_vision/evaluators/proj/paligemma/transfers/rsvqa.py b/big_vision/evaluators/proj/paligemma/transfers/rsvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..575bd12d1f53b20b1393bba264f5c38ae08f7c6b --- /dev/null +++ b/big_vision/evaluators/proj/paligemma/transfers/rsvqa.py @@ -0,0 +1,173 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = "jit"

# Question types for which a separate accuracy is accumulated in `run`.
QUESTION_TYPES = ("comp", "count", "presence", "rural_urban", "area")

# (postfix, question types) pairs. For every postfix, the first entry whose
# question types were all observed is reported as `acc_avg_{postfix}`.
ACC_SUBSETS = (
    ("nonum", ("comp", "presence", "rural_urban")),  # rsvqa_lr
    ("nonum", ("comp", "presence")),  # rsvqa_hr
)


class Evaluator:
  """Evaluator for simple VQA tasks with per question-type accuracies."""

  def __init__(
      self, predict_fn, tokenizer, to_lower=False,
      outfile="{workdir}/{split}.json",
      *, data, devices, **kw):
    """Sets up the input pipeline, the tokenizer and the decoding function.

    Args:
      predict_fn: Decoding function; called as
        `predict_fn(train_state, batch, devices=..., eos_token=...)`.
      tokenizer: Tokenizer spec handed to `big_vision.pp.tokenizer`.
      to_lower: Whether predictions and ground truths are lower-cased before
        they are compared.
      outfile: Template of the json file the predictions are written to.
      data: Dataset config; its `split` entry is used to resolve `outfile`.
      devices: Devices forwarded to the input pipeline and to `predict_fn`.
      **kw: Forwarded to `c.eval_input_pipeline`.
    """
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"answers", "answer", "question_id", "question_type"},
        data=data, devices=devices, **kw)

    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))

    # We'll need the tokenizer to detokenize the model outputs later.
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.postproc = (lambda s: s.lower()) if to_lower else lambda s: s
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  def run(self, train_state):
    """Does one evaluation run, yields `(metric_name, value)` pairs."""

    accuracies = []
    accuracies_any = []
    counts_per_type = {t: 0 for t in QUESTION_TYPES}
    accuracies_per_type = {t: [] for t in QUESTION_TYPES}
    anls_values = []
    json_out = []
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # (batch, seqlen) array of decoded generated tokens.
      tokens = self.decode(train_state, batch)

      # (local_batch,) that indicates padding examples (0) vs real examples (1).
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      # Turn predictions into texts and then scores, one by one.
      for i in range(len(tokens)):
        if ex_masks[i] == 0:  # Skip last-batch padding examples
          continue

        answer = self.postproc(self.tok.to_str(tokens[i], stop_at_eos=True))

        # Now we have two commonly used VQA evaluation modes:
        if "answer" in batch:
          # single GT (eg ocrvqa): just compare to that answer, done.
          gt = self.postproc(batch["answer"][i])
          gts = [gt]
          accuracies.append(float(answer == gt))
          accuracies_any.append(float(answer == gt))
          anls_values.append(anls_metric(gt, answer))
        elif "answers" in batch and (gt_answers := batch["answers"][i]).size:
          # multiple GTs (eg okvqa): introduced by VQA, compare to each of them
          # with a threshold, see also: https://visualqa.org/evaluation.html
          gts = [self.postproc(a) for a in gt_answers]
          num_match = sum([answer == gt for gt in gts])
          accuracies.append(min(1.0, num_match / 3.0))
          accuracies_any.append(min(1.0, float(num_match)))
          anls_values.append(max(anls_metric(gt, answer) for gt in gts))
          # Per question-type statistics are only collected in this mode.
          accuracies_per_type[batch["question_type"][i]].append(
              accuracies_any[-1]
          )
          counts_per_type[batch["question_type"][i]] += 1
        else:
          gts = []

        json_out.append({
            "question_id": batch["question_id"][i].item(),
            "answer": answer} | ({"gts": gts} if gts else {}))

    # At this point `accuracies` is a list of per-example scores. However,
    # remember that each host holds a different subset of the examples! So if
    # we were to just return the mean accuracy here, we would effectively only
    # have evaluated on the main host's (who writes metrics) subset!
    # So now, we need to compute global means.
    # There is one more caveat: `process_sum` needs the summands on each host
    # to have the same size. So we either need to include dummy values for
    # the padding examples (last batch, annoying), or we only sum scalars as in
    # sufficient statistics, which we do here.
    sum_accs, sum_accs_any, sum_anls, num_accs, num = c.process_sum(
        [sum(accuracies), sum(accuracies_any), sum(anls_values),
         len(accuracies), len(json_out)])

    sum_accs_per_type, sum_cnts_per_type = c.process_sum(
        [{k: sum(v) for k, v in accuracies_per_type.items()}, counts_per_type]
    )

    # Yielding metric_name, value means logging the metric.
    if num_accs:
      yield "acc", sum_accs / num_accs
      yield "acc_any", sum_accs_any / num_accs  # Overall Accuracy (OA).
      yield "anls", sum_anls / num_accs

    acc_types = {}
    for k, v in sum_accs_per_type.items():
      if sum_cnts_per_type[k]:
        acc_types[k] = v / sum_cnts_per_type[k]
        yield f"acc_{k}", acc_types[k]

    # No per-type counts exist when only single-GT examples were seen; skip
    # the averages then instead of dividing by zero.
    if acc_types:
      yield "acc_avg", sum(acc_types.values()) / len(acc_types)  # Avg acc (AA).

    # Several subsets share a postfix; report only the first one that is fully
    # covered so that a metric name is never yielded twice.
    reported = set()
    for postfix, types in ACC_SUBSETS:
      if postfix in reported:
        continue
      if all(t in acc_types for t in types):
        reported.add(postfix)
        # Average accuracy over the question types of this subset.
        yield f"acc_avg_{postfix}", sum(
            [v for k, v in acc_types.items() if k in types]
        ) / len(types)

    yield "num", num  # Just for sanity checks.
    c.multiprocess_write_json(self.outfile, json_out)


def anls_metric(target: str, prediction: str, theta: float = 0.5):
  """Calculates ANLS for DocVQA.

  There does not seem to be an official evaluation script.
  Public implementation on which this implementation is based:
  https://github.com/herobd/layoutlmv2/blob/main/eval_docvqa.py#L92

  Original paper (see Eq 1): https://arxiv.org/pdf/1907.00490.pdf

  Args:
    target: Target string.
    prediction: Predicted string.
    theta: Filter threshold set to 0.5 for DocVQA.

  Returns:
    ANLS score: `1 - normalized edit distance` if that distance is below
    `theta`, else 0. For an empty target the score is 1.0 if the prediction is
    empty as well and 0.0 otherwise.
  """
  if target:
    edit_distance = editdistance.eval(target, prediction)
    # `target` is non-empty here, so the denominator is at least 1.
    normalized_ld = edit_distance / max(len(target), len(prediction))
    return 1 - normalized_ld if normalized_ld < theta else 0
  else:
    return float(prediction == "")
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = "jit"
# Sentinel returned by `Evaluator.postproc` when no answer letter is found.
FAILURE = "failed"


class Evaluator:
  """Evaluator for ScienceQA answers of the form "the answer is X."."""

  def __init__(
      self, predict_fn, tokenizer,
      outfile="{workdir}/{split}.json",
      out_question_key="question_id",
      *, data, devices, **kw):
    """Sets up the input pipeline, the tokenizer and the decoding function.

    Args:
      predict_fn: Decoding function; called as
        `predict_fn(train_state, batch, devices=..., eos_token=...)`.
      tokenizer: Tokenizer spec handed to `big_vision.pp.tokenizer`.
      outfile: Template of the json file the predictions are written to.
      out_question_key: Key under which the question id is stored in the json.
      data: Dataset config; its `split` entry is used to resolve `outfile`.
      devices: Devices forwarded to the input pipeline and to `predict_fn`.
      **kw: Forwarded to `c.eval_input_pipeline`.
    """
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"answer", "question_id"}, data=data, devices=devices, **kw)

    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))
    self.out_question_key = out_question_key

    # We'll need the tokenizer to detokenize the model outputs later.
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token
    )

  def postproc(self, raw_answer):
    """Extracts the answer letter from a string starting "the answer is X.".

    Args:
      raw_answer: Text to parse; matching is case-insensitive and anchored at
        the start of the string.

    Returns:
      The lower-case answer letter ('a', 'b', ...) or `FAILURE` if the text
      does not start with the expected pattern.
    """
    match = re.match(
        pattern=r"the answer is ([a-z])\.", string=raw_answer.lower()
    )
    if match:
      return match.groups()[0]  # 'a', 'b', ...
    else:
      return FAILURE

  def run(self, train_state):
    """Does one evaluation run, yields `(metric_name, value)` pairs."""

    accuracies = []
    fail_parse = []
    json_out = []
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # (batch, seqlen) array of decoded generated tokens.
      tokens = self.decode(train_state, batch)

      # (local_batch,) that indicates padding examples (0) vs real examples (1).
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      # Turn predictions into texts and then scores, one by one.
      for i in range(len(tokens)):
        if ex_masks[i] == 0:  # Skip last-batch padding examples
          continue

        raw_answer = self.tok.to_str(tokens[i], stop_at_eos=True)
        answer = self.postproc(raw_answer)
        if "answer" in batch:
          gt = self.postproc(batch["answer"][i])
          gts = [gt]
          # An unparsable prediction is never correct, even if the ground
          # truth happens to be unparsable too (both would equal `FAILURE`).
          accuracies.append(float(answer != FAILURE and answer == gt))
          fail_parse.append(float(answer == FAILURE))
        else:
          gts = []

        json_out.append(
            {
                self.out_question_key: batch["question_id"][i].item(),
                "raw_answer": raw_answer,
                "answer": answer,
            }
            | ({"gts": gts} if gts else {})
        )

    # At this point `accuracies` is a list of per-example scores. However,
    # remember that each host holds a different subset of the examples! So if
    # we were to just return the mean accuracy here, we would effectively only
    # have evaluated on the main host's (who writes metrics) subset!
    # So now, we need to compute global means.
    # There is one more caveat: `process_sum` needs the summands on each host
    # to have the same size. So we either need to include dummy values for
    # the padding examples (last batch, annoying), or we only sum scalars as in
    # sufficient statistics, which we do here.
    sum_accs, num_parsefail, num_accs, num = c.process_sum(
        [sum(accuracies), sum(fail_parse), len(accuracies), len(json_out)]
    )

    # Yielding metric_name, value means logging the metric.
    if num_accs > 0:
      yield "acc", sum_accs / num_accs
      yield "parsefail", num_parsefail / num_accs

    yield "num", num  # Just for sanity checks.
    c.multiprocess_write_json(self.outfile, json_out)
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = 'jit'


def _inrange(a, min_value, max_value):
  """Returns whether all elements of `a` lie in [min_value, max_value]."""
  return (np.clip(a, min_value, max_value) == a).all()


def _area(y1, x1, y2, x2):
  """Returns the area of a box; empty or inverted boxes have area 0."""
  return max(x2 - x1, 0.0) * max(y2 - y1, 0.0)


class Evaluator:
  """Evaluator for instance segmentation."""

  def __init__(self, predict_fn, tokenizer,
               model='oi', det_ious=(0.5, 0.75),
               *, devices, **kw):
    """Sets up the input pipeline, the tokenizer and the mask decoder.

    Args:
      predict_fn: Decoding function; called as
        `predict_fn(train_state, batch, devices=..., eos_token=...)`.
      tokenizer: Tokenizer spec handed to `big_vision.pp.tokenizer`.
      model: Key of `_KNOWN_MODELS` or a path to a mask decoder checkpoint.
      det_ious: Box IoU thresholds for which `boxacc/{iou}` is reported.
      devices: Devices forwarded to the input pipeline and to `predict_fn`.
      **kw: Forwarded to `c.eval_input_pipeline`.
    """
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={'prefix', 'suffix', 'objects/mask', 'objects/bbox'},
        devices=devices, **kw)

    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)
    # Ids of the first location / segmentation token. `run` subtracts them to
    # map decoded tokens to 0..1023 box coordinates and 0..127 codebook ids.
    self.loc0 = np.array(self.tok.to_int('<loc0000>'))
    self.seg0 = np.array(self.tok.to_int('<seg000>'))
    # Verify tokenizer has `tokensets=("loc", "seg")`
    assert self.loc0.shape == (1,), self.loc0
    assert self.seg0.shape == (1,), self.seg0
    self.reconstruct_masks = get_reconstruct_masks(model)
    self.det_ious = det_ious

  def run(self, train_state):
    """Does one evaluation run, yields `(metric_name, value)` pairs."""
    ious = []  # NOTE: no point to split in s/m/l: all objects are L (>96px²)
    det_by_iou = {iou: [] for iou in self.det_ious}
    invalid = total = 0
    for _, batch in zip(range(self.steps), self.get_data_iter()):

      decoded = self.decode(train_state, batch)

      not_padding = u.get_local_slice_from_fsarray(batch['_mask'])
      decoded = u.get_local_slice_from_fsarray(decoded)[not_padding]

      # Note, gt masks are in full original image resolution.
      gt_masks = [gt[:, :, 0] > 0 for gt in batch['objects/mask'][not_padding]]
      gt_bbs = list(batch['objects/bbox'][not_padding])

      valid = []
      # Rows of invalid predictions stay all-zero.
      tokens = np.zeros([decoded.shape[0], 4 + 16], np.int32)
      for i, dec in enumerate(decoded):
        # TODO: b/andstein - do we need to optimize this loop?
        t = np.r_[dec[:4] - self.loc0, dec[4:4 + 16] - self.seg0]  # Ignore rest
        if (
            len(t) == 4 + 16  # Full prediction
            and _inrange(t[:4], 0, 1023)  # Valid box tokens
            and _inrange(t[4:], 0, 127)  # Valid seg tokens
            and t[2] > t[0] and t[3] > t[1]  # Valid box
        ):
          valid.append(True)
          tokens[i] = t
        else:
          valid.append(False)

      tocpu = lambda x: jax.device_put(x, jax.local_devices(backend='cpu')[0])
      seg_indices = np.array(tokens[:, 4:])
      mask64 = jax.device_get(self.reconstruct_masks(tocpu(seg_indices)))
      mask64 = mask64[..., 0]
      bbox = tokens[:, :4] / 1023  # Back to [0.0 ... 1.0]

      for v, m64, gtm, bb, gtbb in zip(valid, mask64, gt_masks, bbox, gt_bbs):
        # TODO: b/andstein - do we need to optimize this loop?
        total += 1
        h, w = gtm.shape  # gt is full/original image resolution mask.

        # First, compute detection iou, in [0.0 ... 1.0] coordinate space.
        y1, x1, y2, x2 = bb
        gty1, gtx1, gty2, gtx2 = gtbb
        ibb = max(y1, gty1), max(x1, gtx1), min(y2, gty2), min(x2, gtx2)
        box_iou = _area(*ibb) / (_area(*bb) + _area(*gtbb) - _area(*ibb))
        for iou_thresh in det_by_iou:
          det_by_iou[iou_thresh].append(iou_thresh <= box_iou)

        # Next, we convert to pixel coordinates and compute mask iou.
        gt_area = gtm.sum()
        y1, x1, y2, x2 = map(int, (y1 * h, x1 * w, y2 * h, x2 * w))

        # Avoid compute-intensive mask stuff for invalid preds:
        if not v or x2 <= x1 or y2 <= y1:  # Can still happen after int().
          iou = 0.0
          invalid += 1
        else:
          mi = np.asarray(PIL.Image.fromarray(m64).resize(
              [x2 - x1, y2 - y1], resample=PIL.Image.BILINEAR  # pytype: disable=module-attr
          ))  # Predicted mask in box-sized image.
          mi = mi > 0.0  # Mask decoder output in [-1.0 ... 1.0]
          iarea = (gtm[y1:y2, x1:x2] & mi).sum()  # Intersection pixels.
          iou = iarea / (gt_area + mi.sum() - iarea)
        ious.append(iou)

    # Done going over all batches, now collect results from all processes.
    sum_ious, num_ious, sum_dets, num_dets, num_invalid, num = c.process_sum([
        sum(ious), len(ious),
        {k: sum(v) for k, v in det_by_iou.items()},
        {k: len(v) for k, v in det_by_iou.items()},
        invalid, total
    ])

    # Ratios are only defined if at least one example was evaluated.
    if num_ious:
      yield 'miou', sum_ious / num_ious
    for k in sum_dets:
      if num_dets[k]:
        yield f'boxacc/{k}', sum_dets[k] / num_dets[k]
    yield 'invalid', num_invalid
    yield 'total', num


_KNOWN_MODELS = {
    # Trained on open images.
    'oi': 'gs://big_vision/paligemma/vae-oid.npz',
}


def _get_params(checkpoint):
  """Converts PyTorch checkpoint to Flax params.

  Args:
    checkpoint: Mapping from PyTorch parameter names to arrays.

  Returns:
    Nested dict of parameters for `Decoder`, plus the codebook stored under
    the `_embeddings` key.
  """

  def transp(kernel):
    # Moves the first two kernel axes to the end, in swapped order.
    return np.transpose(kernel, (2, 3, 1, 0))

  def conv(name):
    return {
        'bias': checkpoint[name + '.bias'],
        'kernel': transp(checkpoint[name + '.weight']),
    }

  def resblock(name):
    return {
        'Conv_0': conv(name + '.0'),
        'Conv_1': conv(name + '.2'),
        'Conv_2': conv(name + '.4'),
    }

  return {
      '_embeddings': checkpoint['_vq_vae._embedding'],
      'Conv_0': conv('decoder.0'),
      'ResBlock_0': resblock('decoder.2.net'),
      'ResBlock_1': resblock('decoder.3.net'),
      'ConvTranspose_0': conv('decoder.4'),
      'ConvTranspose_1': conv('decoder.6'),
      'ConvTranspose_2': conv('decoder.8'),
      'ConvTranspose_3': conv('decoder.10'),
      'Conv_1': conv('decoder.12'),
  }


def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
  """Looks up codebook vectors for `[B, 16]` indices as a `[B, 4, 4, D]` grid."""
  batch_size, num_tokens = codebook_indices.shape
  assert num_tokens == 16, codebook_indices.shape
  unused_num_embeddings, embedding_dim = embeddings.shape

  encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
  encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
  return encodings


class ResBlock(nn.Module):
  """Residual block of two 3x3 convolutions and one 1x1 convolution."""
  features: int

  @nn.compact
  def __call__(self, x):
    original_x = x
    x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
    x = nn.relu(x)
    x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
    x = nn.relu(x)
    x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
    return x + original_x


class Decoder(nn.Module):
  """Upscales quantized vectors to mask."""

  @nn.compact
  def __call__(self, x):
    num_res_blocks = 2
    dim = 128
    num_upsample_layers = 4

    x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
    x = nn.relu(x)

    for _ in range(num_res_blocks):
      x = ResBlock(features=dim)(x)

    for _ in range(num_upsample_layers):
      x = nn.ConvTranspose(
          features=dim,
          kernel_size=(4, 4),
          strides=(2, 2),
          padding=2,
          transpose_kernel=True,
      )(x)
      x = nn.relu(x)
      dim //= 2  # Each following upsampling layer has half the features.

    x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)

    return x


@functools.cache
def get_reconstruct_masks(model):
  """Reconstructs masks from codebook indices.

  Based on code from https://arxiv.org/abs/2301.02229

  Verified in
  https://colab.research.google.com/drive/1AOr0cokOpM6-N9Z5HmxoeGxGj6jS37Vl

  Args:
    model: Model to use for conversion.

  Returns:
    A function that expects indices shaped `[B, 16]` of dtype int32, each
    ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
    `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
  """
  def reconstruct_masks(codebook_indices):
    # `params` is bound below, before this function is first called.
    quantized = _quantized_values_from_codebook_indices(
        codebook_indices, params['_embeddings']
    )
    return Decoder().apply({'params': params}, quantized)

  # `model` is either a key of `_KNOWN_MODELS` or directly a checkpoint path.
  with gfile.GFile(_KNOWN_MODELS.get(model, model), 'rb') as f:
    params = _get_params(dict(np.load(f)))

  return jax.jit(reconstruct_masks, backend='cpu')
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = "jit"


class Evaluator:
  """Runs the prediction function over a dataset and stores the outputs."""

  def __init__(
      self, predict_fn, tokenizer=None,
      preds_outfile="{workdir}/{name}_{split}_preds.json",
      annot_outfile="{workdir}/{name}_{split}_annotations.json",
      id_key="id",
      *, data, devices, **kw
  ):
    """Sets up the input pipeline, output paths, tokenizer and decoder.

    Args:
      predict_fn: Decoding function; called as
        `predict_fn(train_state, batch, devices=..., eos_token=...)`.
      tokenizer: Tokenizer spec handed to `big_vision.pp.tokenizer`.
      preds_outfile: Template of the json file the predictions go to.
      annot_outfile: Template of the annotations file; only resolved here.
      id_key: Batch field holding the example id; kept in host memory and
        also used as the key of the id in the written records.
      data: Dataset config; `name` and `split` fill the file templates.
      devices: Devices forwarded to the input pipeline and to `predict_fn`.
      **kw: Forwarded to `c.eval_input_pipeline`.
    """
    self.id_key = id_key
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={id_key}, data=data, devices=devices, **kw)

    outfile_fields = dict(name=data.get("name"), split=data.get("split", ""))
    self.preds_outfile = c.resolve_outfile(preds_outfile, **outfile_fields)
    self.annot_outfile = c.resolve_outfile(annot_outfile, **outfile_fields)

    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  def run(self, train_state):
    """Decodes all examples, writes them out, yields the example count."""
    records = []

    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # (batch, seqlen) array of decoded generated tokens, reduced to the
      # slice held by this process.
      decoded = u.get_local_slice_from_fsarray(self.decode(train_state, batch))
      # (local_batch,) mask used to drop padding examples.
      keep = u.get_local_slice_from_fsarray(batch["_mask"])

      example_ids = batch[self.id_key][keep]
      texts = self.tok.to_str(decoded[keep])

      records.extend(
          {self.id_key: str(example_id), "caption": text}
          for example_id, text in zip(example_ids, texts)
      )

    records = c.multiprocess_write_json(self.preds_outfile, records)

    # Only the first process reports the metric.
    if jax.process_index() == 0:
      yield "num_examples", len(records)
# Ground-truth counts above this value get no per-count accuracy bucket.
_LARGEST_COUNT = 15


class Evaluator:
  """TallyQA evaluator."""

  def __init__(self, predict_fn, tokenizer, *, devices, **kw):
    """Sets up the input pipeline, the tokenizer and the decoding function.

    Args:
      predict_fn: Decoding function; called as
        `predict_fn(train_state, batch, devices=..., eos_token=...)`.
      tokenizer: Tokenizer spec handed to `big_vision.pp.tokenizer`.
      devices: Devices forwarded to the input pipeline and to `predict_fn`.
      **kw: Forwarded to `c.eval_input_pipeline`.
    """
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"answer", "issimple"}, devices=devices, **kw)

    # We'll need the tokenizer to detokenize the model outputs later.
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token
    )

  def run(self, train_state):
    """Does one evaluation run, yields `(metric_name, value)` pairs."""

    accuracies_by_type = {"all": [], "simple": [], "complex": []}
    # Add per-count entries. Cannot use a `defaultdict` as we need to `tree_map`
    # over keys later in `c.process_sum`.
    accuracies_by_type.update(
        {f"count_{i}": [] for i in range(_LARGEST_COUNT + 1)}
    )

    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # (batch, seqlen) array of decoded (generated) token sequences suffixes.
      tokens = self.decode(train_state, batch)

      # (local_batch,) that indicates padding examples (0) vs real examples (1).
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      # Turn predictions into texts and then scores, one by one.
      # We always compare the gt (string digit, e.g. "1") to the answer by the
      # model (e.g. "1").
      for i in range(len(tokens)):
        if ex_masks[i] == 0:  # Skip last-batch padding examples
          continue

        # Extract the suffix/answer from the generated string, skip bos.
        answer = self.tok.to_str(tokens[i], stop_at_eos=True)
        # Standardize the response, i.e., convert number words ("one") to
        # numerals ("1").
        answer = _number_word_to_numeral(answer)

        # The ground truth goes through the same normalization.
        gt = _number_word_to_numeral(batch["answer"][i])
        is_correct = float(answer == gt)
        accuracies_by_type["all"].append(is_correct)

        if "issimple" in batch:
          # Simple/complex split.
          if batch["issimple"][i] == 1:
            accuracies_by_type["simple"].append(is_correct)
          elif batch["issimple"][i] == 0:
            accuracies_by_type["complex"].append(is_correct)
          else:
            # Train set is not annotated with simple/complex (but has dummy
            # value of `-1` in this field).
            pass

        # Store accuracies per count. Ground truths without a bucket (e.g.
        # counts above `_LARGEST_COUNT`) still count towards "all" but must
        # not raise a KeyError here.
        count_key = f"count_{gt}"
        if count_key in accuracies_by_type:
          accuracies_by_type[count_key].append(is_correct)

    # At this point `accuracies` is a list of per-example scores. However,
    # remember that each host holds a different subset of the examples! So if
    # we were to just return the mean accuracy here, we would effectively only
    # have evaluated on the main host's (who writes metrics) subset!
    # So now, we need to compute global means.
    # There is one more caveat: `process_sum` needs the summands on each host
    # to have the same size. So we either need to include dummy values for
    # the padding examples (last batch, annoying), or we only sum scalars as in
    # sufficient statistics, which we do here.
    sum_accs = c.process_sum({k: sum(v) for k, v in accuracies_by_type.items()})
    num_accs = c.process_sum({k: len(v) for k, v in accuracies_by_type.items()})

    if n := num_accs["all"]:
      yield "acc", sum_accs["all"] / n
      yield "num", n  # Just for sanity checks.
    for key in sum_accs.keys():
      if (key != "all") and (num_accs[key]):
        yield f"acc/{key}", sum_accs[key] / num_accs[key]
        yield f"num/{key}", num_accs[key]  # Just for sanity checks.


def _number_word_to_numeral(s: str) -> str:
  """Maps a number word to its numeral, e.g. "one" -> "1" (up to 20).

  The lookup is case-insensitive; anything that is not a known number word is
  returned unchanged.
  """
  key = s.lower()
  return REPLACEMENTS.get(key, s)


REPLACEMENTS = {
    "none": "0",
    "zero": "0",
    "one": "1",
    "two": "2",
    "three": "3",
    "four": "4",
    "five": "5",
    "six": "6",
    "seven": "7",
    "eight": "8",
    "nine": "9",
    "ten": "10",
    "eleven": "11",
    "twelve": "12",
    "thirteen": "13",
    "fourteen": "14",
    "fifteen": "15",
    "sixteen": "16",
    "seventeen": "17",
    "eighteen": "18",
    "nineteen": "19",
    "twenty": "20",
}

# diff --git a/big_vision/evaluators/proj/paligemma/transfers/vqa.py b/big_vision/evaluators/proj/paligemma/transfers/vqa.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..9f6ddccaf3640306ec9bbd9fe62e3a34ca78b78b
# --- /dev/null
# +++ b/big_vision/evaluators/proj/paligemma/transfers/vqa.py
# @@ -0,0 +1,163 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluator for simple VQA variants (OCR-VQA, OKVQA, A-OKVQA).

According to the (A-)OKVQA papers, the eval for these datasets should follow
VQAv2. But here we don't track different answer-types, and don't do any
leave-one-out averaging, as this isn't done in the official implementation at
https://github.com/allenai/aokvqa/blob/main/evaluation/eval_predictions.py
either.

Please read the description of how evaluators work at (internal link).
This evaluator follows the pattern of also parallelizing the CPU computations
(ie postprocessing, score computation) across hosts for more scalability.

For now, simple decoding is implemented as part of the evaluator. We'll soon
unify and move to a library of decoding functions, including fancier and more
efficient ones.
"""
import functools

import big_vision.evaluators.common as c
import big_vision.pp.tokenizer
import big_vision.utils as u
import editdistance


# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = "jit"


class Evaluator:
  """Evaluator for simple VQA tasks.

  This evaluator expects the batch to contain a field `question_id` and a field
  `answer` for single ground truth or `answers` for multiple ground truths.

  The field names used when writing the json result can be controlled with
  `out_question_key` and `out_answer_key`.
  """

  def __init__(
      self, predict_fn, tokenizer, to_lower=False,
      outfile="{workdir}/{split}.json",
      out_question_key="question_id", out_answer_key="answer",
      *, data, devices, **kw):
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"answers", "answer", "question_id"},
        data=data, devices=devices, **kw)

    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))
    self.out_question_key = out_question_key
    self.out_answer_key = out_answer_key

    # The tokenizer is needed later to turn generated tokens back into text.
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    # Optional lower-casing, applied to predictions and ground truths alike.
    if to_lower:
      self.postproc = lambda s: s.lower()
    else:
      self.postproc = lambda s: s
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  def run(self, train_state):
    """Does one evaluation run, yields metrics."""

    accuracies, accuracies_any, anls_values, json_out = [], [], [], []

    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # (batch, seqlen) array of decoded generated tokens, local slice only.
      tokens = u.get_local_slice_from_fsarray(self.decode(train_state, batch))
      # (local_batch,) that indicates padding examples (0) vs real examples (1).
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      # Turn predictions into texts and then scores, one by one.
      for i in range(len(tokens)):
        if ex_masks[i] == 0:  # Skip last-batch padding examples
          continue

        answer = self.postproc(self.tok.to_str(tokens[i], stop_at_eos=True))

        # Two commonly used VQA evaluation modes:
        if "answer" in batch:
          # Single GT (eg ocrvqa): just compare to that answer, done.
          gt = self.postproc(batch["answer"][i])
          gts = [gt]
          exact = float(answer == gt)
          accuracies.append(exact)
          accuracies_any.append(exact)
          anls_values.append(anls_metric(gt, answer))
        elif "answers" in batch and (gt_answers := batch["answers"][i]).size:
          # Multiple GTs (eg okvqa): introduced by VQA, compare to each of them
          # with a threshold, see also: https://visualqa.org/evaluation.html
          gts = [self.postproc(a) for a in gt_answers]
          num_match = sum(answer == gt for gt in gts)
          accuracies.append(min(1.0, num_match / 3.0))
          accuracies_any.append(min(1.0, float(num_match)))
          anls_values.append(max(anls_metric(gt, answer) for gt in gts))
        else:
          gts = []

        entry = {
            self.out_question_key: batch["question_id"][i].item(),
            self.out_answer_key: answer,
        }
        if gts:
          entry["gts"] = gts
        json_out.append(entry)

    # Each host only holds the scores of its own subset of examples, so the
    # global mean has to be computed from values summed across hosts.
    # `process_sum` needs summands of identical size on every host; summing
    # scalars (sufficient statistics) avoids having to pad per-example lists.
    local_stats = [
        sum(accuracies), sum(accuracies_any), sum(anls_values),
        len(accuracies), len(json_out),
    ]
    sum_accs, sum_accs_any, sum_anls, num_accs, num = c.process_sum(
        local_stats)

    # Yielding metric_name, value means logging the metric.
    if num_accs:
      yield "acc", sum_accs / num_accs
      yield "acc_any", sum_accs_any / num_accs
      yield "anls", sum_anls / num_accs

    yield "num", num  # Just for sanity checks.
    c.multiprocess_write_json(self.outfile, json_out)


def anls_metric(target: str, prediction: str, theta: float = 0.5):
  """Calculates ANLS for DocVQA.

  There does not seem to be an official evaluation script.
  Public implementation on which this implementation is based:
  https://github.com/herobd/layoutlmv2/blob/main/eval_docvqa.py#L92

  Original paper (see Eq 1): https://arxiv.org/pdf/1907.00490.pdf

  Args:
    target: Target string.
    prediction: Predicted string.
    theta: Filter threshold set to 0.5 for DocVQA.

  Returns:
    ANLS score.
  """
  if not target:
    # Empty target: only an empty prediction counts as correct.
    return float(prediction == "")

  edit_distance = editdistance.eval(target, prediction)
  normalized_ld = edit_distance / max(len(target), len(prediction))
  if normalized_ld < theta:
    return 1 - normalized_ld
  return 0

# diff --git a/big_vision/evaluators/proj/paligemma/transfers/vqav2.py b/big_vision/evaluators/proj/paligemma/transfers/vqav2.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..5f6f2a1839a6f9644674ca20de880edce22ddd11
# --- /dev/null
# +++ b/big_vision/evaluators/proj/paligemma/transfers/vqav2.py
# @@ -0,0 +1,197 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluator for VQAV2 dataset.
"""
import functools
import re

import big_vision.evaluators.common as c
import big_vision.pp.tokenizer
import big_vision.utils as u
import numpy as np


# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = "jit"


class Evaluator:
  """VQAv2 evaluator.

  Decodes an answer per example, scores it with the VQA accuracy (average over
  leave-one-out subsets of the ground-truth answers) and writes all
  predictions to a json file.
  """

  def __init__(
      self, predict_fn, tokenizer, outfile="{workdir}/{split}.json",
      *, data, devices, **kw):
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"answers", "answer_type", "question_type", "question_id"},
        data=data, devices=devices, **kw)

    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))

    # We'll need the tokenizer to detokenize the model outputs later.
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  def run(self, train_state):
    """Does one evaluation run, yields metrics."""
    accuracies_by_type = {"yes/no": [], "number": [], "other": []}
    json_out = []

    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # (batch, seqlen) array of decoded (generated) token sequences suffixes.
      tokens = self.decode(train_state, batch)
      tokens = u.get_local_slice_from_fsarray(tokens)

      # (local_batch,) that indicates padding examples (0) vs real examples (1).
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      # Turn predictions into texts and then scores, one by one.
      for i in range(len(tokens)):
        if ex_masks[i] == 0:  # Skip last-batch padding examples
          continue

        # Extract the suffix/answer from the generated string, skip bos.
        answer = self.tok.to_str(tokens[i], stop_at_eos=True)
        record = {
            "question_id": batch["question_id"][i].item(),
            "answer": answer,
        }

        # The rest is computation of VQA-score which compares to multiple GTs.
        # This is described better here: https://visualqa.org/evaluation.html
        if (gt_answers := batch["answers"][i]).size:
          # Always need to do light space-processing:
          gt_answers = [stripspace_vqav2(a) for a in gt_answers]
          answer = stripspace_vqav2(answer)

          # Only post-process if not all agree. Supposedly avoids postproc OCR:
          # https://github.com/GT-Vision-Lab/VQA/issues/14#issuecomment-1334695361
          if len(set(gt_answers)) > 1:
            answer = postprocess_vqav2_text(answer)
            gt_answers = [postprocess_vqav2_text(a) for a in gt_answers]

          # Accuracy is avg over all leave-one-out GT's (ten for VQAv2).
          # https://github.com/GT-Vision-Lab/VQA/issues/1#issuecomment-199921352
          # An answer is counted 100% correct as soon as 3 GT's agree with it.
          # Iterating over the actual number of GTs (instead of a hard-coded
          # 10) gives the same value for VQAv2 and does not index out of
          # bounds for examples that carry fewer answers.
          matches = answer == np.array(gt_answers)
          acc = np.mean([
              np.clip(np.sum(np.delete(matches, i_leave_out)) / 3, 0, 1)
              for i_leave_out in range(len(gt_answers))
          ])

          accuracies_by_type[batch["answer_type"][i]].append(acc)

          # Update json with fully post-processed answer and gt:
          record["answer_raw"] = record["answer"]
          record["answer"] = answer
          record["gts"] = gt_answers

        json_out.append(record)

    # At this point `accuracies` is a list of per-example scores. However,
    # remember that each host holds a different subset of the examples! So if
    # we were to just return the mean accuracy here, we would effectively only
    # have evaluated on the main host's (who writes metrics) subset!
    # So now, we need to compute global means.
    # There is one more caveat: `process_sum` needs the summands on each host
    # to have the same size. So we either need to include dummy values for
    # the padding examples (last batch, annoying), or we only sum scalars as in
    # sufficient statistics, which we do here.
    sum_accs = c.process_sum({k: sum(v) for k, v in accuracies_by_type.items()})
    num_accs = c.process_sum({k: len(v) for k, v in accuracies_by_type.items()})
    num = c.process_sum(len(json_out))

    # Yielding metric_name, value means logging the metric.
    if n := sum(num_accs.values()):
      yield "acc", sum(sum_accs.values()) / n
    if n := num_accs["yes/no"]:
      yield "acc/yesno", sum_accs["yes/no"] / n
      yield "num/yesno", n
    if n := num_accs["number"]:
      yield "acc/number", sum_accs["number"] / n
      yield "num/number", n
    if n := num_accs["other"]:
      yield "acc/other", sum_accs["other"] / n
      yield "num/other", n

    yield "num", num  # Just for sanity checks.
    c.multiprocess_write_json(self.outfile, json_out)


# Post-processing required is described at https://visualqa.org/evaluation.html


def stripspace_vqav2(txt):
  """Replaces newlines and tabs by spaces and strips surrounding whitespace."""
  return txt.replace("\n", " ").replace("\t", " ").strip()


def postprocess_vqav2_text(txt):
  """Cleanup string according to VQA.

  Removes or blanks the characters in `PUNCT`, drops full stops, lower-cases,
  removes the words in `ARTICLES` and applies the `REPLACEMENTS` table word by
  word.

  Args:
    txt: answer string (prediction or ground truth).

  Returns:
    The normalized string, words joined by single spaces.
  """
  has_digit_comma = re.search(r"(\d)(\,)(\d)", txt) is not None

  out = txt
  for p in PUNCT:
    # NOTE: digit_comma here looks like a bug in official code, so we follow it.
    if has_digit_comma or f"{p} " in txt or f" {p}" in txt:
      out = out.replace(p, "")
    else:
      out = out.replace(p, " ")

  # Remove full-stops that aren't part of a number.
  # NOTE: `(?!<=\d)` is a look-ahead for the literal text "<=" followed by a
  # digit, not a look-behind; the pattern is kept verbatim so that scores stay
  # comparable with the reference pattern it was taken from.
  out = re.sub(r"(?!<=\d)(\.)(?!\d)", "", out, flags=re.UNICODE)

  words = []
  for word in out.lower().split():
    if word not in ARTICLES:
      words.append(REPLACEMENTS.get(word, word))
  return " ".join(words)


# pylint: disable=line-too-long
REPLACEMENTS = {
    # CONTRACTIONS
    "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't",
    "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
    "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've",
    "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've",
    "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
    "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
    "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
    "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've",
    "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll",
    "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've",
    "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've",
    "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've",
    "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
    "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't",
    "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're",
    "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've",
    "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
    "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've",
    "youll": "you'll", "youre": "you're", "youve": "you've",
    # NUMBERS
    "none": "0", "zero": "0", "one": "1", "two": "2",
    "three": "3", "four": "4", "five": "5", "six": "6",
    "seven": "7", "eight": "8", "nine": "9", "ten": "10",
}
# pylint: enable=line-too-long

PUNCT = [
    ";", "/", "[", "]", "\"", "{", "}",
    "(", ")", "=", "+", "\\", "_", "-",
    ">", "<", "@", "`", ",", "?", "!"
]
ARTICLES = {"a", "an", "the"}

# diff --git a/big_vision/evaluators/proj/uvim/coco_panoptic.py b/big_vision/evaluators/proj/uvim/coco_panoptic.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..b20d4458826dfe568b5a8f85246e1635c91a3f3b
# --- /dev/null
# +++ b/big_vision/evaluators/proj/uvim/coco_panoptic.py
# @@ -0,0 +1,324 @@
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""COCO17 panoptic evaluation."""
import functools
from functools import partial
import json
import os
import tempfile
import time
import zipfile

from absl import logging
from big_vision.evaluators.proj.uvim import common
import big_vision.pp.builder as pp_builder
import jax
import numpy as np
import panopticapi_converters.twochannels2panoptic_coco_format as converter
from panopticapi.evaluation import pq_compute
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.io import gfile


ROOT = os.environ.get('COCO_DATA_DIR', '.')

PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json'
PANOPTIC_2017 = {
    'train': f'{ROOT}/panoptic_train2017.json',
    'validation': f'{ROOT}/panoptic_val2017.json',
}

PANOPTIC_GT_ZIP = {
    'train': f'{ROOT}/panoptic_train2017.zip',
    'validation': f'{ROOT}/panoptic_val2017.zip',
}


class Evaluator:
  """Panoptic segmentation evaluator: calls official COCO API.

  `predict_fn` accepts arbitrary dictionaries of parameters and data, where
  the data dictionary is produced by the `pp` op. It is expected to output a
  2-channel mask, where the first channel encodes semantics, and the second
  channel encodes instance ids.
  """

  def __init__(self,
               predict_fn,
               pp_fn,
               batch_size,
               dataset='coco/2017_panoptic',
               dataset_dir=None,
               split='validation',
               predict_kwargs=None):
    # Prepare to run predict on all processes and gather predictions on all
    # devices. Note: if needed consider only gather across processes.
    def predict(params, batch):
      res = {
          'image/id': batch['image/id'],
          'mask': batch['mask'],
          'y': predict_fn(params, batch['input'], **(predict_kwargs or {})),
      }
      return jax.lax.all_gather(res, axis_name='data', axis=0)

    self.predict_fn = jax.pmap(predict, axis_name='data')

    # Prepare data for each process and pad with zeros so all processes have the
    # same number of batches.
    def preprocess(example):
      return {
          'image/id': example['image/id'],
          'mask': tf.constant(1),
          'input': pp_builder.get_preprocess_fn(pp_fn)(example),
      }

    self.data = common.get_jax_process_dataset(
        dataset, split, dataset_dir=dataset_dir,
        global_batch_size=batch_size,
        pp_fn=preprocess)

    # Only process 0 runs conversion to png and calls into coco api.
    if jax.process_index() == 0:
      self.result_dir = tempfile.TemporaryDirectory()
      (self.gt_folder, self.gt_json, self.categories_json,
       self.remap, self.size_map) = _prepare_ground_truth(
           dataset, split, dataset_dir)

  def _compute_png_predictions(self, params):
    """Computes predictions and converts then to png to optimize memory use.

    Every process runs inference (the gather is collective), but only
    process 0 post-processes and writes the png files.

    Returns:
      The `TemporaryDirectory` holding the pngs on process 0, `None` elsewhere.
    """
    count = 0
    logging.info('Panoptic eval: running inference.')
    for batch in self.data.as_numpy_iterator():
      out = self.predict_fn(params, batch)

      if jax.process_index():
        continue

      # After the all_gather every device holds the full result; take the
      # first one. `jax.tree_map` is a deprecated alias of this function.
      out = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], out))
      mask = out['mask']
      pan_recs = out['y'][mask != 0]
      ids = out['image/id'][mask != 0]

      for pan_rec, image_id in zip(pan_recs, ids):
        sem = pan_rec[..., 0]
        ins = pan_rec[..., 1]

        # Map tfds class ids to COCO class ids.
        sem_remapped = np.array(sem)
        for v in np.unique(sem):
          sem_remapped[sem == v] = self.remap[v]
        sem = sem_remapped

        pan_mask = np.stack([sem, ins, np.zeros_like(sem)], axis=-1)
        pan_mask = _resize_nearest(pan_mask, self.size_map[image_id])
        pan_mask_png = tf.io.encode_png(pan_mask.astype('uint8')).numpy()

        fname = f'{self.result_dir.name}/{image_id:012d}.png'
        with open(fname, 'wb') as f:
          f.write(pan_mask_png)
        count += 1

      logging.log_every_n_seconds(
          logging.INFO, 'Panoptic eval: processed %i examples so far.', 30,
          count)

    if jax.process_index():
      return None

    logging.info('Panoptic eval: inference done. Processed %d examples.', count)
    return self.result_dir

  def run(self, params):
    """Run eval; yields `{All,Stuff,Things}_{pq,rq,sq}` on process 0 only."""
    # Note result_dir is constant, but files inside are mutated.
    result_dir = self._compute_png_predictions(params)

    if not result_dir:
      return

    with tempfile.TemporaryDirectory() as pred_folder, \
         tempfile.NamedTemporaryFile(mode='w') as pred_json:

      logging.info('Panoptic eval: running conversion.')
      converter.converter(
          source_folder=result_dir.name,
          images_json_file=self.gt_json,
          categories_json_file=self.categories_json,
          segmentations_folder=pred_folder,
          predictions_json_file=pred_json.name)
      logging.info('Panoptic eval: conversion done.')

      logging.info('Panoptic eval: running metrics computation.')
      res = pq_compute(gt_json_file=self.gt_json,
                       gt_folder=self.gt_folder,
                       pred_json_file=pred_json.name,
                       pred_folder=pred_folder)
      logging.info('Panoptic eval: metrics computation done.')

    for k in ['All', 'Stuff', 'Things']:
      for m in ['pq', 'rq', 'sq']:
        yield f'{k}_{m}', res[k][m]


def _prepare_ground_truth(dataset, split, data_dir):
  """Prepare ground truth from tf.data.Dataset."""
  if dataset == 'coco/2017_panoptic' and data_dir is None:
    return _prepare_ground_truth_from_zipfiles(split)
  else:
    return _prepare_ground_truth_from_dataset(dataset, split, data_dir)


@functools.lru_cache(maxsize=None)
def _prepare_ground_truth_from_dataset(dataset, split, data_dir):
  """Prepare ground truth from a tf.data.Dataset.

  Writes one png per example plus an `annotations.json` into a fresh temporary
  folder.

  Returns:
    Tuple `(gt_folder, gt_json, categories_json, remap, size_map)`.
  """
  dataset = tfds.builder(dataset, data_dir=data_dir).as_dataset(split=split)

  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  with gfile.GFile(categories_json, 'rb') as f:
    categories = json.loads(f.read())

  # Build map from tfds class ids to COCO class ids.
  remap = {0: 0}
  remap = {**remap, **{(i + 1): x['id'] for i, x in enumerate(categories)}}

  gt_folder = tempfile.mkdtemp()
  gfile.makedirs(gt_folder)
  size_map = {}
  annotations = []
  images = []
  for example in dataset:
    image_id = int(example['image/id'])
    panoptic_image = example['panoptic_image']
    ann_ids = example['panoptic_objects']['id']
    ann_labels = example['panoptic_objects']['label']
    ann_iscrowd = example['panoptic_objects']['is_crowd']
    ann_area = example['panoptic_objects']['area']

    fname = f'{image_id:012d}.png'
    with gfile.GFile(os.path.join(gt_folder, fname), 'wb') as f:
      f.write(tf.io.encode_png(panoptic_image).numpy())

    size_map[image_id] = (panoptic_image.shape[0], panoptic_image.shape[1])

    segments_info = []
    for i in range(len(ann_ids)):
      segments_info.append({
          'id': int(ann_ids[i]),
          'category_id': remap[int(ann_labels[i] + 1)],
          'iscrowd': int(ann_iscrowd[i]),
          'area': int(ann_area[i]),
      })

    annotations.append({
        'file_name': str(fname),
        'image_id': int(image_id),
        'segments_info': segments_info
    })
    images.append({
        'id': image_id,
        'file_name': f'{image_id:012d}.jpg',
    })

  # Write annotations.json needed for pq_compute.
  gt_json = os.path.join(gt_folder, 'annotations.json')
  with gfile.GFile(gt_json, 'wb') as f:
    f.write(json.dumps({
        'images': images,
        'annotations': annotations,
        'categories': categories,
    }))

  return gt_folder, gt_json, categories_json, remap, size_map


def _prepare_ground_truth_from_zipfiles(split):
  """Prepare ground truth from coco zip files.

  Args:
    split: tfds split string; the part before `[` must be `train` or
      `validation`.

  Returns:
    Tuple `(gt_folder, filtered_gt_json, categories_json, remap, size_map)`.

  Raises:
    ValueError: if the split is not supported.
  """
  split_prefix = split.split('[')[0]
  if split_prefix not in ('train', 'validation'):
    raise ValueError(f'Split {split} not supported')

  # The following 4 calls are cached. This allows to save significant time
  # in use cases like sweeping predict_fn hparams on the same run.
  gt_json = _make_local_copy(PANOPTIC_2017[split_prefix])
  gt_folder = _make_local_unzip_copy(PANOPTIC_GT_ZIP[split_prefix])
  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  image_ids = _list_image_ids('coco/2017_panoptic', split)

  gt_folder = os.path.join(
      gt_folder, 'panoptic_val2017'
      if split_prefix == 'validation' else 'panoptic_train2017')

  # Build map from tfds class ids to COCO class ids.
  remap = {0: 0}
  with gfile.GFile(categories_json, 'r') as f:
    remap = {**remap, **{(i + 1): x['id'] for i, x in enumerate(json.load(f))}}

  # Filters gt_json to contain only annotations for images in dataset.
  with gfile.GFile(gt_json) as f:
    data = json.load(f)
  logging.info(
      'Panoptic eval: pre-filter %d annotations.',
      len(data['annotations'])
  )
  data['images'] = [x for x in data['images'] if x['id'] in image_ids]
  data['annotations'] = [
      x for x in data['annotations'] if x['image_id'] in image_ids
  ]
  logging.info(
      'Panoptic eval: post-filter %d annotations.',
      len(data['annotations'])
  )
  # Create and write the file through a single handle so that no open
  # descriptor is left behind; the file itself is kept (delete=False).
  with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
    json.dump(data, f)
    filtered_gt_json = f.name

  # Precompute images sizes.
  size_map = {x['id']: (x['height'], x['width']) for x in data['images']}

  return gt_folder, filtered_gt_json, categories_json, remap, size_map


@functools.lru_cache(maxsize=None)
def _list_image_ids(dataset, split):
  """Returns the frozenset of `image/id` values in the given tfds split."""
  d = tfds.load(dataset, split=split).map(lambda x: x['image/id'])
  return frozenset(d.as_numpy_iterator())


@functools.lru_cache(maxsize=None)
def _make_local_copy(fname) -> str:
  """Copies `fname` to a fresh local temporary file and returns its path."""
  start = time.monotonic()
  # Only the unique name is needed; close the handle right away so that it is
  # not leaked for the lifetime of the process.
  with tempfile.NamedTemporaryFile(delete=False) as local_file:
    local_name = local_file.name
  gfile.copy(fname, local_name, overwrite=True)
  logging.info('Copy %s in %d seconds.', fname, time.monotonic() - start)
  return local_name


@functools.lru_cache(maxsize=None)
def _make_local_unzip_copy(fname) -> str:
  """Copies the zip `fname` locally, extracts it and returns the folder."""
  start = time.monotonic()
  folder = tempfile.mkdtemp()
  with tempfile.NamedTemporaryFile() as tmp_zip_file:
    gfile.copy(fname, tmp_zip_file.name, overwrite=True)
    with zipfile.ZipFile(tmp_zip_file.name, 'r') as f:
      f.extractall(folder)
  logging.info('Copy %s in %d seconds.', fname, time.monotonic() - start)
  return folder


@partial(jax.jit, static_argnums=(1,), backend='cpu')
def _resize_nearest(image, shape):
  """Nearest-neighbour resize of `image` to `shape`, keeping the last axis."""
  return jax.image.resize(image, shape + image.shape[-1:], 'nearest')

# diff --git a/big_vision/evaluators/proj/uvim/coltran_fid.py b/big_vision/evaluators/proj/uvim/coltran_fid.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..555a7a2b008e6c4a09e172eed250cef92cd04c1b
# --- /dev/null
# +++ b/big_vision/evaluators/proj/uvim/coltran_fid.py
# @@ -0,0 +1,242 @@
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation producing ColTran FID-5K metric."""

import functools
import os

from absl import logging
import einops
import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_gan as tfgan
import tensorflow_hub as tfhub

from tensorflow.io import gfile


ROOT = os.environ.get("FID_DATA_DIR", ".")


def _preprocess(image, resolution=512):
  """ColTran dataset preprocessing.

  See,
  github.com/google-research/google-research/blob/master/coltran/datasets.py#L44

  Args:
    image: ImageNet example from TFDS.
    resolution: Integer representing output size.

  Returns:
    An int32 image of size (resolution, resolution, 3).
  """
  shape = tf.shape(image)
  # Square crop/pad to the shorter side, then area-resize to the target size.
  side_size = tf.minimum(shape[0], shape[1])
  squared = tf.image.resize_with_crop_or_pad(
      image, target_height=side_size, target_width=side_size)
  resized = tf.image.resize(squared, method="area", antialias=True,
                            size=(resolution, resolution))
  return tf.cast(tf.round(resized), dtype=tf.int32)


def _normalize(x):
  """Coltran normalization to expected range for Inception module.

  Args:
    x: Image with values in [0,255].

  Returns:
    Image with values in [-1,1].
  """
  as_float = tf.cast(x, tf.float32)
  # note: 128.0 is the value used in ColTran.
  return (as_float / 128.0) - 1.0


class Evaluator:
  """ColTran FID-5K Evaluator.

  This Evaluator aims to mirror the evaluation pipeline used by Kumar et.al.
  in Colorization Transformer (https://arxiv.org/abs/2102.04432).

  To be clear: much of this code is direct snippets from ColTran code.

  See,
  github.com/google-research/google-research/blob/master/coltran/datasets.py#L44

  The ColTran pipeline has numerous stages, where serialized data is passed
  between binaries via file, etc... While we don't physically write the same
  files, we simulate the effects of the serialization (e.g., quantization).
  """

  def __init__(self,
               predict_fn,
               batch_size,  # ignored
               device_batch_size=5,
               coltran_seed=1,
               predict_kwargs=None):
    """Create Evaluator.

    Args:
      predict_fn: Colorization prediction function. Expects grayscale images
        of size (512, 512, 3) in keys `image` and `image_ctx` with values in
        the range [-1,1]. Outputs `color` image in range [-1,1].
      batch_size: ignored.
      device_batch_size: number of images per batch, per device.
      coltran_seed: used to specify the block of 5_000 images used to generate
        the reference pool. Value of `1` matches default ColTran code.
      predict_kwargs: arguments passed to `predict_fn`.
    """
    del batch_size

    self.num_devices = jax.local_device_count()
    self.device_batch_size = device_batch_size
    logging.log(logging.INFO, "Colorizing with batch size %i on %i devices.",
                self.device_batch_size, self.num_devices)
    assert 5_000 % (self.device_batch_size * self.num_devices) == 0

    self.predict_fn = jax.pmap(
        functools.partial(predict_fn, **(predict_kwargs or {})))

    inception = tfhub.load(tfgan.eval.INCEPTION_TFHUB)

    def _pools(x):
      features = inception(x)[tfgan.eval.INCEPTION_FINAL_POOL]
      return np.squeeze(features.numpy())

    self.inception_pool = _pools

    # Setup the colorization dataset.
    # TRICKY: ColTran FID-5k uses the first 5_000 images returned as read by
    # default from tensorflow_datasets (that is: with shard interleaving).
    # In particular note that it is different than the set of images returned
    # by "validation[:5000]".
    def _eval_data_preprocess(example):
      # Colorization happens at 512x512 resolution.
      image = _normalize(_preprocess(example["image"], resolution=512))
      grayscale = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
      return {
          "image": image,
          "grayscale": grayscale,
          "file_name": example["file_name"]
      }

    eval_ds = tfds.load("imagenet2012", split="validation")
    eval_ds = eval_ds.map(_eval_data_preprocess).take(5_000)
    # Two-level batching: (num_devices, device_batch_size, ...) for pmap.
    eval_ds = eval_ds.batch(self.device_batch_size).batch(self.num_devices)
    self.eval_data = eval_ds.cache().prefetch(tf.data.AUTOTUNE)

    # Setup the reference dataset.
    def _reference_data_preprocess(example):
      # ColTran eval operates on 256x256.
      image = _normalize(_preprocess(example["image"], resolution=256))
      return {"image": image, "file_name": example["file_name"]}

    ref_ds = tfds.load("imagenet2012", split="validation")
    ref_ds = ref_ds.map(_reference_data_preprocess)
    # Skip the images used in colorization.
    ref_ds = ref_ds.skip(5_000)
    # ColTran eval w/ seed=1 effectively uses 10_000:15_000 to
    # calculate reference.
    ref_ds = ref_ds.skip(coltran_seed * 5_000)
    ref_ds = ref_ds.take(5_000).batch(device_batch_size)
    self.reference_data = ref_ds.cache().prefetch(tf.data.AUTOTUNE)

    def _get_file(name):
      return os.path.join(ROOT, name)

    with gfile.GFile(_get_file("eval_file_names.txt")) as f:
      self.eval_file_names = frozenset(f.read().splitlines())

    with gfile.GFile(_get_file("reference_file_names.txt")) as f:
      self.reference_file_names = frozenset(f.read().splitlines())

  def run(self, params):
    """Run eval."""

    if jax.process_index():  # Host0 does all work.
      return

    images_per_step = self.device_batch_size * self.num_devices

    color_pools = []
    color_file_names = set()
    for step, batch in enumerate(self.eval_data.as_numpy_iterator()):
      gray = batch["grayscale"]
      prediction = self.predict_fn(params, {
          "labels": batch["image"],
          "image": gray,
          "image_ctx": gray,
      })
      y = einops.rearrange(prediction["color"], "d b h w c -> (d b) h w c")

      # Return to the ColTran eval size of 256x256.
      y = tf.image.resize(y, (256, 256), "area")

      # Mimic effect of serializing image as integers and map back to [-1, 1].
      y = np.clip(np.round((y + 1.) * 128.), 0, 255)
      y = _normalize(y)

      color_pools.append(self.inception_pool(y))

      file_names = einops.rearrange(batch["file_name"], "d b -> (d b)")
      color_file_names.update([f.decode() for f in file_names])

      logging.log_every_n_seconds(
          logging.INFO,
          "ColTran FID eval: processed %i colorized examples so far.", 30,
          (step + 1) * images_per_step)

    reference_pools = []
    reference_file_names = set()
    for step, batch in enumerate(self.reference_data.as_numpy_iterator()):
      image = batch["image"]
      assert np.array_equal(image.shape, (self.device_batch_size, 256, 256, 3))
      reference_pools.append(self.inception_pool(image))
      reference_file_names.update([f.decode() for f in batch["file_name"]])

      logging.log_every_n_seconds(
          logging.INFO,
          "ColTran FID eval: processed %i reference examples so far.", 30,
          (step + 1) * self.device_batch_size)

    # Both image sets must match the expected file lists exactly.
    if color_file_names != self.eval_file_names:
      raise ValueError("unknown: {}\nmissing: {}".format(
          color_file_names - self.eval_file_names,
          self.eval_file_names - color_file_names))

    if reference_file_names != self.reference_file_names:
      raise ValueError("unknown: {}\nmissing: {}".format(
          reference_file_names - self.reference_file_names,
          self.reference_file_names - reference_file_names))

    color = np.concatenate(color_pools, axis=0)
    reference = np.concatenate(reference_pools, axis=0)

    if color.shape[0] != 5_000:
      raise ValueError(color.shape)

    if reference.shape[0] != 5_000:
      raise ValueError(reference.shape)

    yield "FID_5k", tfgan.eval.frechet_classifier_distance_from_activations(
        color, reference)

# diff --git a/big_vision/evaluators/proj/uvim/coltran_fid_data/eval_file_names.txt b/big_vision/evaluators/proj/uvim/coltran_fid_data/eval_file_names.txt
# new file mode 100644
# index
0000000000000000000000000000000000000000..d3a08b9dd2cc6bc41e07ea92039e49b564312153 --- /dev/null +++ b/big_vision/evaluators/proj/uvim/coltran_fid_data/eval_file_names.txt @@ -0,0 +1,5000 @@ +ILSVRC2012_val_00009670.JPEG +ILSVRC2012_val_00036705.JPEG +ILSVRC2012_val_00003545.JPEG +ILSVRC2012_val_00047963.JPEG +ILSVRC2012_val_00023277.JPEG +ILSVRC2012_val_00004014.JPEG +ILSVRC2012_val_00001121.JPEG +ILSVRC2012_val_00011754.JPEG +ILSVRC2012_val_00025035.JPEG +ILSVRC2012_val_00043797.JPEG +ILSVRC2012_val_00046096.JPEG +ILSVRC2012_val_00024105.JPEG +ILSVRC2012_val_00031747.JPEG +ILSVRC2012_val_00024113.JPEG +ILSVRC2012_val_00025971.JPEG +ILSVRC2012_val_00032467.JPEG +ILSVRC2012_val_00021106.JPEG +ILSVRC2012_val_00048369.JPEG +ILSVRC2012_val_00034488.JPEG +ILSVRC2012_val_00037380.JPEG +ILSVRC2012_val_00016846.JPEG +ILSVRC2012_val_00042664.JPEG +ILSVRC2012_val_00003230.JPEG +ILSVRC2012_val_00045510.JPEG +ILSVRC2012_val_00024667.JPEG +ILSVRC2012_val_00033383.JPEG +ILSVRC2012_val_00016559.JPEG +ILSVRC2012_val_00012460.JPEG +ILSVRC2012_val_00015028.JPEG +ILSVRC2012_val_00036694.JPEG +ILSVRC2012_val_00048401.JPEG +ILSVRC2012_val_00008599.JPEG +ILSVRC2012_val_00024724.JPEG +ILSVRC2012_val_00016613.JPEG +ILSVRC2012_val_00012872.JPEG +ILSVRC2012_val_00029464.JPEG +ILSVRC2012_val_00006791.JPEG +ILSVRC2012_val_00040933.JPEG +ILSVRC2012_val_00020980.JPEG +ILSVRC2012_val_00020933.JPEG +ILSVRC2012_val_00041264.JPEG +ILSVRC2012_val_00009098.JPEG +ILSVRC2012_val_00018611.JPEG +ILSVRC2012_val_00034921.JPEG +ILSVRC2012_val_00010025.JPEG +ILSVRC2012_val_00021580.JPEG +ILSVRC2012_val_00037940.JPEG +ILSVRC2012_val_00012585.JPEG +ILSVRC2012_val_00040631.JPEG +ILSVRC2012_val_00043473.JPEG +ILSVRC2012_val_00035336.JPEG +ILSVRC2012_val_00023147.JPEG +ILSVRC2012_val_00013833.JPEG +ILSVRC2012_val_00016418.JPEG +ILSVRC2012_val_00033390.JPEG +ILSVRC2012_val_00047840.JPEG +ILSVRC2012_val_00016048.JPEG +ILSVRC2012_val_00045736.JPEG +ILSVRC2012_val_00008967.JPEG +ILSVRC2012_val_00020593.JPEG 
+ILSVRC2012_val_00023548.JPEG +ILSVRC2012_val_00026589.JPEG +ILSVRC2012_val_00047233.JPEG +ILSVRC2012_val_00001523.JPEG +ILSVRC2012_val_00028686.JPEG +ILSVRC2012_val_00006172.JPEG +ILSVRC2012_val_00014856.JPEG +ILSVRC2012_val_00023450.JPEG +ILSVRC2012_val_00012349.JPEG +ILSVRC2012_val_00038876.JPEG +ILSVRC2012_val_00029340.JPEG +ILSVRC2012_val_00009986.JPEG +ILSVRC2012_val_00028311.JPEG +ILSVRC2012_val_00048337.JPEG +ILSVRC2012_val_00020541.JPEG +ILSVRC2012_val_00044507.JPEG +ILSVRC2012_val_00024092.JPEG +ILSVRC2012_val_00049283.JPEG +ILSVRC2012_val_00036416.JPEG +ILSVRC2012_val_00016583.JPEG +ILSVRC2012_val_00045446.JPEG +ILSVRC2012_val_00038567.JPEG +ILSVRC2012_val_00033717.JPEG +ILSVRC2012_val_00038475.JPEG +ILSVRC2012_val_00026422.JPEG +ILSVRC2012_val_00044982.JPEG +ILSVRC2012_val_00044802.JPEG +ILSVRC2012_val_00014315.JPEG +ILSVRC2012_val_00009085.JPEG +ILSVRC2012_val_00020613.JPEG +ILSVRC2012_val_00017637.JPEG +ILSVRC2012_val_00030081.JPEG +ILSVRC2012_val_00022925.JPEG +ILSVRC2012_val_00014615.JPEG +ILSVRC2012_val_00033931.JPEG +ILSVRC2012_val_00038691.JPEG +ILSVRC2012_val_00012242.JPEG +ILSVRC2012_val_00004258.JPEG +ILSVRC2012_val_00016300.JPEG +ILSVRC2012_val_00044764.JPEG +ILSVRC2012_val_00048033.JPEG +ILSVRC2012_val_00042120.JPEG +ILSVRC2012_val_00004634.JPEG +ILSVRC2012_val_00030833.JPEG +ILSVRC2012_val_00040507.JPEG +ILSVRC2012_val_00046915.JPEG +ILSVRC2012_val_00000762.JPEG +ILSVRC2012_val_00044717.JPEG +ILSVRC2012_val_00041480.JPEG +ILSVRC2012_val_00021262.JPEG +ILSVRC2012_val_00025761.JPEG +ILSVRC2012_val_00004455.JPEG +ILSVRC2012_val_00035741.JPEG +ILSVRC2012_val_00040331.JPEG +ILSVRC2012_val_00027190.JPEG +ILSVRC2012_val_00031247.JPEG +ILSVRC2012_val_00022364.JPEG +ILSVRC2012_val_00020647.JPEG +ILSVRC2012_val_00045455.JPEG +ILSVRC2012_val_00036387.JPEG +ILSVRC2012_val_00002584.JPEG +ILSVRC2012_val_00028727.JPEG +ILSVRC2012_val_00033509.JPEG +ILSVRC2012_val_00045960.JPEG +ILSVRC2012_val_00018035.JPEG +ILSVRC2012_val_00042294.JPEG 
+ILSVRC2012_val_00037068.JPEG +ILSVRC2012_val_00026772.JPEG +ILSVRC2012_val_00043286.JPEG +ILSVRC2012_val_00040479.JPEG +ILSVRC2012_val_00029969.JPEG +ILSVRC2012_val_00027908.JPEG +ILSVRC2012_val_00046152.JPEG +ILSVRC2012_val_00024557.JPEG +ILSVRC2012_val_00025049.JPEG +ILSVRC2012_val_00015498.JPEG +ILSVRC2012_val_00005882.JPEG +ILSVRC2012_val_00014682.JPEG +ILSVRC2012_val_00044484.JPEG +ILSVRC2012_val_00004990.JPEG +ILSVRC2012_val_00016420.JPEG +ILSVRC2012_val_00021289.JPEG +ILSVRC2012_val_00013360.JPEG +ILSVRC2012_val_00038600.JPEG +ILSVRC2012_val_00041932.JPEG +ILSVRC2012_val_00021338.JPEG +ILSVRC2012_val_00002296.JPEG +ILSVRC2012_val_00032757.JPEG +ILSVRC2012_val_00010804.JPEG +ILSVRC2012_val_00035707.JPEG +ILSVRC2012_val_00049995.JPEG +ILSVRC2012_val_00013871.JPEG +ILSVRC2012_val_00036383.JPEG +ILSVRC2012_val_00011718.JPEG +ILSVRC2012_val_00012518.JPEG +ILSVRC2012_val_00018637.JPEG +ILSVRC2012_val_00037722.JPEG +ILSVRC2012_val_00000040.JPEG +ILSVRC2012_val_00021312.JPEG +ILSVRC2012_val_00034539.JPEG +ILSVRC2012_val_00049523.JPEG +ILSVRC2012_val_00000347.JPEG +ILSVRC2012_val_00000428.JPEG +ILSVRC2012_val_00027218.JPEG +ILSVRC2012_val_00004346.JPEG +ILSVRC2012_val_00004030.JPEG +ILSVRC2012_val_00035334.JPEG +ILSVRC2012_val_00004368.JPEG +ILSVRC2012_val_00034466.JPEG +ILSVRC2012_val_00045203.JPEG +ILSVRC2012_val_00017630.JPEG +ILSVRC2012_val_00037067.JPEG +ILSVRC2012_val_00006656.JPEG +ILSVRC2012_val_00029811.JPEG +ILSVRC2012_val_00034522.JPEG +ILSVRC2012_val_00003583.JPEG +ILSVRC2012_val_00008683.JPEG +ILSVRC2012_val_00030886.JPEG +ILSVRC2012_val_00004097.JPEG +ILSVRC2012_val_00032031.JPEG +ILSVRC2012_val_00039652.JPEG +ILSVRC2012_val_00005729.JPEG +ILSVRC2012_val_00013205.JPEG +ILSVRC2012_val_00011653.JPEG +ILSVRC2012_val_00038287.JPEG +ILSVRC2012_val_00046447.JPEG +ILSVRC2012_val_00025152.JPEG +ILSVRC2012_val_00035740.JPEG +ILSVRC2012_val_00017399.JPEG +ILSVRC2012_val_00032779.JPEG +ILSVRC2012_val_00034375.JPEG +ILSVRC2012_val_00030992.JPEG 
+ILSVRC2012_val_00041114.JPEG +ILSVRC2012_val_00023050.JPEG +ILSVRC2012_val_00030819.JPEG +ILSVRC2012_val_00040316.JPEG +ILSVRC2012_val_00019266.JPEG +ILSVRC2012_val_00017289.JPEG +ILSVRC2012_val_00034924.JPEG +ILSVRC2012_val_00031452.JPEG +ILSVRC2012_val_00032174.JPEG +ILSVRC2012_val_00048817.JPEG +ILSVRC2012_val_00012131.JPEG +ILSVRC2012_val_00009882.JPEG +ILSVRC2012_val_00025737.JPEG +ILSVRC2012_val_00010951.JPEG +ILSVRC2012_val_00024919.JPEG +ILSVRC2012_val_00049774.JPEG +ILSVRC2012_val_00041614.JPEG +ILSVRC2012_val_00023960.JPEG +ILSVRC2012_val_00036801.JPEG +ILSVRC2012_val_00006146.JPEG +ILSVRC2012_val_00002008.JPEG +ILSVRC2012_val_00031210.JPEG +ILSVRC2012_val_00030460.JPEG +ILSVRC2012_val_00020539.JPEG +ILSVRC2012_val_00007465.JPEG +ILSVRC2012_val_00001126.JPEG +ILSVRC2012_val_00039167.JPEG +ILSVRC2012_val_00043920.JPEG +ILSVRC2012_val_00029976.JPEG +ILSVRC2012_val_00013459.JPEG +ILSVRC2012_val_00021410.JPEG +ILSVRC2012_val_00039878.JPEG +ILSVRC2012_val_00012414.JPEG +ILSVRC2012_val_00000749.JPEG +ILSVRC2012_val_00011678.JPEG +ILSVRC2012_val_00007644.JPEG +ILSVRC2012_val_00029391.JPEG +ILSVRC2012_val_00009891.JPEG +ILSVRC2012_val_00003378.JPEG +ILSVRC2012_val_00017374.JPEG +ILSVRC2012_val_00031863.JPEG +ILSVRC2012_val_00042312.JPEG +ILSVRC2012_val_00026452.JPEG +ILSVRC2012_val_00030014.JPEG +ILSVRC2012_val_00048330.JPEG +ILSVRC2012_val_00001343.JPEG +ILSVRC2012_val_00043156.JPEG +ILSVRC2012_val_00029457.JPEG +ILSVRC2012_val_00011316.JPEG +ILSVRC2012_val_00020040.JPEG +ILSVRC2012_val_00025723.JPEG +ILSVRC2012_val_00048364.JPEG +ILSVRC2012_val_00016653.JPEG +ILSVRC2012_val_00032241.JPEG +ILSVRC2012_val_00040938.JPEG +ILSVRC2012_val_00003520.JPEG +ILSVRC2012_val_00003617.JPEG +ILSVRC2012_val_00031484.JPEG +ILSVRC2012_val_00019398.JPEG +ILSVRC2012_val_00020670.JPEG +ILSVRC2012_val_00015257.JPEG +ILSVRC2012_val_00032691.JPEG +ILSVRC2012_val_00045452.JPEG +ILSVRC2012_val_00026480.JPEG +ILSVRC2012_val_00042441.JPEG +ILSVRC2012_val_00039575.JPEG 
+ILSVRC2012_val_00042619.JPEG +ILSVRC2012_val_00043399.JPEG +ILSVRC2012_val_00044429.JPEG +ILSVRC2012_val_00033520.JPEG +ILSVRC2012_val_00046423.JPEG +ILSVRC2012_val_00042039.JPEG +ILSVRC2012_val_00003944.JPEG +ILSVRC2012_val_00047258.JPEG +ILSVRC2012_val_00037770.JPEG +ILSVRC2012_val_00011256.JPEG +ILSVRC2012_val_00010495.JPEG +ILSVRC2012_val_00010520.JPEG +ILSVRC2012_val_00027160.JPEG +ILSVRC2012_val_00032256.JPEG +ILSVRC2012_val_00001610.JPEG +ILSVRC2012_val_00008326.JPEG +ILSVRC2012_val_00041685.JPEG +ILSVRC2012_val_00021443.JPEG +ILSVRC2012_val_00043981.JPEG +ILSVRC2012_val_00029578.JPEG +ILSVRC2012_val_00030561.JPEG +ILSVRC2012_val_00019953.JPEG +ILSVRC2012_val_00049574.JPEG +ILSVRC2012_val_00026189.JPEG +ILSVRC2012_val_00013869.JPEG +ILSVRC2012_val_00048582.JPEG +ILSVRC2012_val_00044818.JPEG +ILSVRC2012_val_00016735.JPEG +ILSVRC2012_val_00036027.JPEG +ILSVRC2012_val_00024036.JPEG +ILSVRC2012_val_00040658.JPEG +ILSVRC2012_val_00010908.JPEG +ILSVRC2012_val_00012616.JPEG +ILSVRC2012_val_00003581.JPEG +ILSVRC2012_val_00028865.JPEG +ILSVRC2012_val_00042945.JPEG +ILSVRC2012_val_00042482.JPEG +ILSVRC2012_val_00047910.JPEG +ILSVRC2012_val_00025495.JPEG +ILSVRC2012_val_00023440.JPEG +ILSVRC2012_val_00020693.JPEG +ILSVRC2012_val_00021337.JPEG +ILSVRC2012_val_00048664.JPEG +ILSVRC2012_val_00038325.JPEG +ILSVRC2012_val_00039572.JPEG +ILSVRC2012_val_00017513.JPEG +ILSVRC2012_val_00042217.JPEG +ILSVRC2012_val_00028507.JPEG +ILSVRC2012_val_00007673.JPEG +ILSVRC2012_val_00037549.JPEG +ILSVRC2012_val_00021631.JPEG +ILSVRC2012_val_00003637.JPEG +ILSVRC2012_val_00013239.JPEG +ILSVRC2012_val_00030430.JPEG +ILSVRC2012_val_00018513.JPEG +ILSVRC2012_val_00000120.JPEG +ILSVRC2012_val_00046658.JPEG +ILSVRC2012_val_00021739.JPEG +ILSVRC2012_val_00040608.JPEG +ILSVRC2012_val_00031069.JPEG +ILSVRC2012_val_00022141.JPEG +ILSVRC2012_val_00015474.JPEG +ILSVRC2012_val_00046918.JPEG +ILSVRC2012_val_00020371.JPEG +ILSVRC2012_val_00011159.JPEG +ILSVRC2012_val_00030799.JPEG 
+ILSVRC2012_val_00039534.JPEG +ILSVRC2012_val_00015607.JPEG +ILSVRC2012_val_00023513.JPEG +ILSVRC2012_val_00046763.JPEG +ILSVRC2012_val_00016175.JPEG +ILSVRC2012_val_00017190.JPEG +ILSVRC2012_val_00048277.JPEG +ILSVRC2012_val_00015356.JPEG +ILSVRC2012_val_00023226.JPEG +ILSVRC2012_val_00026849.JPEG +ILSVRC2012_val_00011960.JPEG +ILSVRC2012_val_00046460.JPEG +ILSVRC2012_val_00042976.JPEG +ILSVRC2012_val_00029830.JPEG +ILSVRC2012_val_00037081.JPEG +ILSVRC2012_val_00021119.JPEG +ILSVRC2012_val_00002453.JPEG +ILSVRC2012_val_00001721.JPEG +ILSVRC2012_val_00005983.JPEG +ILSVRC2012_val_00031460.JPEG +ILSVRC2012_val_00010029.JPEG +ILSVRC2012_val_00009830.JPEG +ILSVRC2012_val_00004357.JPEG +ILSVRC2012_val_00038664.JPEG +ILSVRC2012_val_00040415.JPEG +ILSVRC2012_val_00046818.JPEG +ILSVRC2012_val_00047026.JPEG +ILSVRC2012_val_00034616.JPEG +ILSVRC2012_val_00004899.JPEG +ILSVRC2012_val_00033706.JPEG +ILSVRC2012_val_00047344.JPEG +ILSVRC2012_val_00038725.JPEG +ILSVRC2012_val_00031925.JPEG +ILSVRC2012_val_00010633.JPEG +ILSVRC2012_val_00020304.JPEG +ILSVRC2012_val_00036520.JPEG +ILSVRC2012_val_00004818.JPEG +ILSVRC2012_val_00035061.JPEG +ILSVRC2012_val_00018945.JPEG +ILSVRC2012_val_00029504.JPEG +ILSVRC2012_val_00015954.JPEG +ILSVRC2012_val_00009697.JPEG +ILSVRC2012_val_00015848.JPEG +ILSVRC2012_val_00013155.JPEG +ILSVRC2012_val_00008563.JPEG +ILSVRC2012_val_00025830.JPEG +ILSVRC2012_val_00014980.JPEG +ILSVRC2012_val_00006878.JPEG +ILSVRC2012_val_00024270.JPEG +ILSVRC2012_val_00000997.JPEG +ILSVRC2012_val_00031141.JPEG +ILSVRC2012_val_00032404.JPEG +ILSVRC2012_val_00024769.JPEG +ILSVRC2012_val_00037682.JPEG +ILSVRC2012_val_00012718.JPEG +ILSVRC2012_val_00047668.JPEG +ILSVRC2012_val_00021383.JPEG +ILSVRC2012_val_00037072.JPEG +ILSVRC2012_val_00001250.JPEG +ILSVRC2012_val_00017418.JPEG +ILSVRC2012_val_00017824.JPEG +ILSVRC2012_val_00045601.JPEG +ILSVRC2012_val_00025044.JPEG +ILSVRC2012_val_00001379.JPEG +ILSVRC2012_val_00029317.JPEG +ILSVRC2012_val_00029827.JPEG 
+ILSVRC2012_val_00031128.JPEG +ILSVRC2012_val_00005367.JPEG +ILSVRC2012_val_00046985.JPEG +ILSVRC2012_val_00021191.JPEG +ILSVRC2012_val_00009034.JPEG +ILSVRC2012_val_00048819.JPEG +ILSVRC2012_val_00035806.JPEG +ILSVRC2012_val_00048861.JPEG +ILSVRC2012_val_00043602.JPEG +ILSVRC2012_val_00022356.JPEG +ILSVRC2012_val_00025507.JPEG +ILSVRC2012_val_00031636.JPEG +ILSVRC2012_val_00045799.JPEG +ILSVRC2012_val_00015190.JPEG +ILSVRC2012_val_00008358.JPEG +ILSVRC2012_val_00041364.JPEG +ILSVRC2012_val_00039365.JPEG +ILSVRC2012_val_00041226.JPEG +ILSVRC2012_val_00000970.JPEG +ILSVRC2012_val_00046655.JPEG +ILSVRC2012_val_00040114.JPEG +ILSVRC2012_val_00027236.JPEG +ILSVRC2012_val_00020741.JPEG +ILSVRC2012_val_00011499.JPEG +ILSVRC2012_val_00024154.JPEG +ILSVRC2012_val_00031104.JPEG +ILSVRC2012_val_00009162.JPEG +ILSVRC2012_val_00008631.JPEG +ILSVRC2012_val_00031238.JPEG +ILSVRC2012_val_00024195.JPEG +ILSVRC2012_val_00023134.JPEG +ILSVRC2012_val_00022358.JPEG +ILSVRC2012_val_00029017.JPEG +ILSVRC2012_val_00027568.JPEG +ILSVRC2012_val_00013586.JPEG +ILSVRC2012_val_00014427.JPEG +ILSVRC2012_val_00048022.JPEG +ILSVRC2012_val_00015479.JPEG +ILSVRC2012_val_00027975.JPEG +ILSVRC2012_val_00021307.JPEG +ILSVRC2012_val_00010749.JPEG +ILSVRC2012_val_00025668.JPEG +ILSVRC2012_val_00042487.JPEG +ILSVRC2012_val_00003801.JPEG +ILSVRC2012_val_00036559.JPEG +ILSVRC2012_val_00042887.JPEG +ILSVRC2012_val_00020031.JPEG +ILSVRC2012_val_00030935.JPEG +ILSVRC2012_val_00019987.JPEG +ILSVRC2012_val_00005176.JPEG +ILSVRC2012_val_00046180.JPEG +ILSVRC2012_val_00018344.JPEG +ILSVRC2012_val_00009415.JPEG +ILSVRC2012_val_00006726.JPEG +ILSVRC2012_val_00028534.JPEG +ILSVRC2012_val_00022125.JPEG +ILSVRC2012_val_00037831.JPEG +ILSVRC2012_val_00036219.JPEG +ILSVRC2012_val_00038842.JPEG +ILSVRC2012_val_00047945.JPEG +ILSVRC2012_val_00021740.JPEG +ILSVRC2012_val_00011030.JPEG +ILSVRC2012_val_00034726.JPEG +ILSVRC2012_val_00006179.JPEG +ILSVRC2012_val_00010184.JPEG +ILSVRC2012_val_00010484.JPEG 
+ILSVRC2012_val_00042439.JPEG +ILSVRC2012_val_00002311.JPEG +ILSVRC2012_val_00043871.JPEG +ILSVRC2012_val_00043371.JPEG +ILSVRC2012_val_00046565.JPEG +ILSVRC2012_val_00030975.JPEG +ILSVRC2012_val_00004729.JPEG +ILSVRC2012_val_00015271.JPEG +ILSVRC2012_val_00043420.JPEG +ILSVRC2012_val_00028268.JPEG +ILSVRC2012_val_00000593.JPEG +ILSVRC2012_val_00018261.JPEG +ILSVRC2012_val_00031540.JPEG +ILSVRC2012_val_00047070.JPEG +ILSVRC2012_val_00042394.JPEG +ILSVRC2012_val_00003789.JPEG +ILSVRC2012_val_00038820.JPEG +ILSVRC2012_val_00036316.JPEG +ILSVRC2012_val_00008978.JPEG +ILSVRC2012_val_00031733.JPEG +ILSVRC2012_val_00017741.JPEG +ILSVRC2012_val_00042005.JPEG +ILSVRC2012_val_00007830.JPEG +ILSVRC2012_val_00030950.JPEG +ILSVRC2012_val_00013929.JPEG +ILSVRC2012_val_00019289.JPEG +ILSVRC2012_val_00002957.JPEG +ILSVRC2012_val_00004270.JPEG +ILSVRC2012_val_00041190.JPEG +ILSVRC2012_val_00025684.JPEG +ILSVRC2012_val_00004780.JPEG +ILSVRC2012_val_00040886.JPEG +ILSVRC2012_val_00029638.JPEG +ILSVRC2012_val_00004267.JPEG +ILSVRC2012_val_00004079.JPEG +ILSVRC2012_val_00032281.JPEG +ILSVRC2012_val_00047480.JPEG +ILSVRC2012_val_00042580.JPEG +ILSVRC2012_val_00022829.JPEG +ILSVRC2012_val_00040378.JPEG +ILSVRC2012_val_00026702.JPEG +ILSVRC2012_val_00008544.JPEG +ILSVRC2012_val_00021700.JPEG +ILSVRC2012_val_00012635.JPEG +ILSVRC2012_val_00042632.JPEG +ILSVRC2012_val_00029979.JPEG +ILSVRC2012_val_00016138.JPEG +ILSVRC2012_val_00045738.JPEG +ILSVRC2012_val_00042282.JPEG +ILSVRC2012_val_00001491.JPEG +ILSVRC2012_val_00004524.JPEG +ILSVRC2012_val_00034589.JPEG +ILSVRC2012_val_00040516.JPEG +ILSVRC2012_val_00006792.JPEG +ILSVRC2012_val_00035627.JPEG +ILSVRC2012_val_00015667.JPEG +ILSVRC2012_val_00048559.JPEG +ILSVRC2012_val_00030235.JPEG +ILSVRC2012_val_00045303.JPEG +ILSVRC2012_val_00030447.JPEG +ILSVRC2012_val_00003650.JPEG +ILSVRC2012_val_00022050.JPEG +ILSVRC2012_val_00005320.JPEG +ILSVRC2012_val_00042326.JPEG +ILSVRC2012_val_00009056.JPEG +ILSVRC2012_val_00017185.JPEG 
+ILSVRC2012_val_00016667.JPEG +ILSVRC2012_val_00043080.JPEG +ILSVRC2012_val_00039706.JPEG +ILSVRC2012_val_00035939.JPEG +ILSVRC2012_val_00037826.JPEG +ILSVRC2012_val_00039492.JPEG +ILSVRC2012_val_00008439.JPEG +ILSVRC2012_val_00045236.JPEG +ILSVRC2012_val_00000447.JPEG +ILSVRC2012_val_00030547.JPEG +ILSVRC2012_val_00034158.JPEG +ILSVRC2012_val_00005860.JPEG +ILSVRC2012_val_00039716.JPEG +ILSVRC2012_val_00005917.JPEG +ILSVRC2012_val_00045222.JPEG +ILSVRC2012_val_00000909.JPEG +ILSVRC2012_val_00016608.JPEG +ILSVRC2012_val_00044721.JPEG +ILSVRC2012_val_00030604.JPEG +ILSVRC2012_val_00041494.JPEG +ILSVRC2012_val_00043208.JPEG +ILSVRC2012_val_00026190.JPEG +ILSVRC2012_val_00029654.JPEG +ILSVRC2012_val_00007899.JPEG +ILSVRC2012_val_00046908.JPEG +ILSVRC2012_val_00017670.JPEG +ILSVRC2012_val_00002322.JPEG +ILSVRC2012_val_00002785.JPEG +ILSVRC2012_val_00037454.JPEG +ILSVRC2012_val_00045836.JPEG +ILSVRC2012_val_00020676.JPEG +ILSVRC2012_val_00049382.JPEG +ILSVRC2012_val_00029772.JPEG +ILSVRC2012_val_00011720.JPEG +ILSVRC2012_val_00028956.JPEG +ILSVRC2012_val_00038182.JPEG +ILSVRC2012_val_00013411.JPEG +ILSVRC2012_val_00046185.JPEG +ILSVRC2012_val_00049014.JPEG +ILSVRC2012_val_00021642.JPEG +ILSVRC2012_val_00046670.JPEG +ILSVRC2012_val_00030910.JPEG +ILSVRC2012_val_00035971.JPEG +ILSVRC2012_val_00045690.JPEG +ILSVRC2012_val_00039432.JPEG +ILSVRC2012_val_00012133.JPEG +ILSVRC2012_val_00046713.JPEG +ILSVRC2012_val_00031823.JPEG +ILSVRC2012_val_00001943.JPEG +ILSVRC2012_val_00024065.JPEG +ILSVRC2012_val_00018502.JPEG +ILSVRC2012_val_00042610.JPEG +ILSVRC2012_val_00003767.JPEG +ILSVRC2012_val_00000393.JPEG +ILSVRC2012_val_00046280.JPEG +ILSVRC2012_val_00046663.JPEG +ILSVRC2012_val_00036336.JPEG +ILSVRC2012_val_00000979.JPEG +ILSVRC2012_val_00026432.JPEG +ILSVRC2012_val_00017613.JPEG +ILSVRC2012_val_00025885.JPEG +ILSVRC2012_val_00007436.JPEG +ILSVRC2012_val_00027102.JPEG +ILSVRC2012_val_00024286.JPEG +ILSVRC2012_val_00046749.JPEG +ILSVRC2012_val_00046151.JPEG 
+ILSVRC2012_val_00011151.JPEG +ILSVRC2012_val_00001741.JPEG +ILSVRC2012_val_00034881.JPEG +ILSVRC2012_val_00018527.JPEG +ILSVRC2012_val_00048102.JPEG +ILSVRC2012_val_00020819.JPEG +ILSVRC2012_val_00017642.JPEG +ILSVRC2012_val_00026680.JPEG +ILSVRC2012_val_00039142.JPEG +ILSVRC2012_val_00041587.JPEG +ILSVRC2012_val_00013856.JPEG +ILSVRC2012_val_00005611.JPEG +ILSVRC2012_val_00027603.JPEG +ILSVRC2012_val_00048982.JPEG +ILSVRC2012_val_00039851.JPEG +ILSVRC2012_val_00019631.JPEG +ILSVRC2012_val_00036405.JPEG +ILSVRC2012_val_00020916.JPEG +ILSVRC2012_val_00021716.JPEG +ILSVRC2012_val_00048373.JPEG +ILSVRC2012_val_00000244.JPEG +ILSVRC2012_val_00037089.JPEG +ILSVRC2012_val_00040530.JPEG +ILSVRC2012_val_00036487.JPEG +ILSVRC2012_val_00015440.JPEG +ILSVRC2012_val_00008791.JPEG +ILSVRC2012_val_00020410.JPEG +ILSVRC2012_val_00016186.JPEG +ILSVRC2012_val_00021326.JPEG +ILSVRC2012_val_00044095.JPEG +ILSVRC2012_val_00012615.JPEG +ILSVRC2012_val_00002191.JPEG +ILSVRC2012_val_00016885.JPEG +ILSVRC2012_val_00015676.JPEG +ILSVRC2012_val_00027342.JPEG +ILSVRC2012_val_00005590.JPEG +ILSVRC2012_val_00023216.JPEG +ILSVRC2012_val_00004117.JPEG +ILSVRC2012_val_00039457.JPEG +ILSVRC2012_val_00033268.JPEG +ILSVRC2012_val_00020397.JPEG +ILSVRC2012_val_00010419.JPEG +ILSVRC2012_val_00001813.JPEG +ILSVRC2012_val_00037279.JPEG +ILSVRC2012_val_00040026.JPEG +ILSVRC2012_val_00000830.JPEG +ILSVRC2012_val_00022765.JPEG +ILSVRC2012_val_00009740.JPEG +ILSVRC2012_val_00042032.JPEG +ILSVRC2012_val_00033972.JPEG +ILSVRC2012_val_00033314.JPEG +ILSVRC2012_val_00024704.JPEG +ILSVRC2012_val_00021353.JPEG +ILSVRC2012_val_00005989.JPEG +ILSVRC2012_val_00033953.JPEG +ILSVRC2012_val_00006250.JPEG +ILSVRC2012_val_00042862.JPEG +ILSVRC2012_val_00049804.JPEG +ILSVRC2012_val_00037028.JPEG +ILSVRC2012_val_00011245.JPEG +ILSVRC2012_val_00022488.JPEG +ILSVRC2012_val_00049099.JPEG +ILSVRC2012_val_00038906.JPEG +ILSVRC2012_val_00045665.JPEG +ILSVRC2012_val_00049548.JPEG +ILSVRC2012_val_00030884.JPEG 
+ILSVRC2012_val_00030607.JPEG +ILSVRC2012_val_00002379.JPEG +ILSVRC2012_val_00044441.JPEG +ILSVRC2012_val_00011964.JPEG +ILSVRC2012_val_00005784.JPEG +ILSVRC2012_val_00018498.JPEG +ILSVRC2012_val_00048229.JPEG +ILSVRC2012_val_00016394.JPEG +ILSVRC2012_val_00010374.JPEG +ILSVRC2012_val_00000565.JPEG +ILSVRC2012_val_00013657.JPEG +ILSVRC2012_val_00021903.JPEG +ILSVRC2012_val_00039676.JPEG +ILSVRC2012_val_00018570.JPEG +ILSVRC2012_val_00041762.JPEG +ILSVRC2012_val_00015314.JPEG +ILSVRC2012_val_00023494.JPEG +ILSVRC2012_val_00015060.JPEG +ILSVRC2012_val_00039543.JPEG +ILSVRC2012_val_00030742.JPEG +ILSVRC2012_val_00024456.JPEG +ILSVRC2012_val_00006026.JPEG +ILSVRC2012_val_00022754.JPEG +ILSVRC2012_val_00019574.JPEG +ILSVRC2012_val_00012194.JPEG +ILSVRC2012_val_00031471.JPEG +ILSVRC2012_val_00003205.JPEG +ILSVRC2012_val_00041311.JPEG +ILSVRC2012_val_00049647.JPEG +ILSVRC2012_val_00023757.JPEG +ILSVRC2012_val_00023546.JPEG +ILSVRC2012_val_00030601.JPEG +ILSVRC2012_val_00012923.JPEG +ILSVRC2012_val_00000686.JPEG +ILSVRC2012_val_00006643.JPEG +ILSVRC2012_val_00008587.JPEG +ILSVRC2012_val_00003375.JPEG +ILSVRC2012_val_00036876.JPEG +ILSVRC2012_val_00016907.JPEG +ILSVRC2012_val_00002793.JPEG +ILSVRC2012_val_00040972.JPEG +ILSVRC2012_val_00033005.JPEG +ILSVRC2012_val_00036578.JPEG +ILSVRC2012_val_00015483.JPEG +ILSVRC2012_val_00047713.JPEG +ILSVRC2012_val_00032235.JPEG +ILSVRC2012_val_00016825.JPEG +ILSVRC2012_val_00028952.JPEG +ILSVRC2012_val_00028046.JPEG +ILSVRC2012_val_00026316.JPEG +ILSVRC2012_val_00011268.JPEG +ILSVRC2012_val_00037878.JPEG +ILSVRC2012_val_00007218.JPEG +ILSVRC2012_val_00048006.JPEG +ILSVRC2012_val_00023586.JPEG +ILSVRC2012_val_00037743.JPEG +ILSVRC2012_val_00008175.JPEG +ILSVRC2012_val_00006467.JPEG +ILSVRC2012_val_00040510.JPEG +ILSVRC2012_val_00022978.JPEG +ILSVRC2012_val_00034599.JPEG +ILSVRC2012_val_00040291.JPEG +ILSVRC2012_val_00002668.JPEG +ILSVRC2012_val_00038661.JPEG +ILSVRC2012_val_00015420.JPEG +ILSVRC2012_val_00024723.JPEG 
+ILSVRC2012_val_00024255.JPEG +ILSVRC2012_val_00012039.JPEG +ILSVRC2012_val_00011522.JPEG +ILSVRC2012_val_00007093.JPEG +ILSVRC2012_val_00012070.JPEG +ILSVRC2012_val_00005579.JPEG +ILSVRC2012_val_00006032.JPEG +ILSVRC2012_val_00006677.JPEG +ILSVRC2012_val_00006448.JPEG +ILSVRC2012_val_00036734.JPEG +ILSVRC2012_val_00021412.JPEG +ILSVRC2012_val_00001170.JPEG +ILSVRC2012_val_00040690.JPEG +ILSVRC2012_val_00007065.JPEG +ILSVRC2012_val_00027621.JPEG +ILSVRC2012_val_00038562.JPEG +ILSVRC2012_val_00028129.JPEG +ILSVRC2012_val_00004292.JPEG +ILSVRC2012_val_00025653.JPEG +ILSVRC2012_val_00029426.JPEG +ILSVRC2012_val_00036764.JPEG +ILSVRC2012_val_00005391.JPEG +ILSVRC2012_val_00043795.JPEG +ILSVRC2012_val_00025315.JPEG +ILSVRC2012_val_00015040.JPEG +ILSVRC2012_val_00016080.JPEG +ILSVRC2012_val_00022589.JPEG +ILSVRC2012_val_00022597.JPEG +ILSVRC2012_val_00021101.JPEG +ILSVRC2012_val_00002776.JPEG +ILSVRC2012_val_00002544.JPEG +ILSVRC2012_val_00030738.JPEG +ILSVRC2012_val_00034745.JPEG +ILSVRC2012_val_00000355.JPEG +ILSVRC2012_val_00024371.JPEG +ILSVRC2012_val_00017001.JPEG +ILSVRC2012_val_00020376.JPEG +ILSVRC2012_val_00047965.JPEG +ILSVRC2012_val_00046081.JPEG +ILSVRC2012_val_00026656.JPEG +ILSVRC2012_val_00000533.JPEG +ILSVRC2012_val_00016148.JPEG +ILSVRC2012_val_00040222.JPEG +ILSVRC2012_val_00027567.JPEG +ILSVRC2012_val_00003168.JPEG +ILSVRC2012_val_00003428.JPEG +ILSVRC2012_val_00011861.JPEG +ILSVRC2012_val_00017930.JPEG +ILSVRC2012_val_00029399.JPEG +ILSVRC2012_val_00000350.JPEG +ILSVRC2012_val_00032129.JPEG +ILSVRC2012_val_00031047.JPEG +ILSVRC2012_val_00027354.JPEG +ILSVRC2012_val_00002201.JPEG +ILSVRC2012_val_00040174.JPEG +ILSVRC2012_val_00037836.JPEG +ILSVRC2012_val_00037101.JPEG +ILSVRC2012_val_00046725.JPEG +ILSVRC2012_val_00021810.JPEG +ILSVRC2012_val_00022231.JPEG +ILSVRC2012_val_00044520.JPEG +ILSVRC2012_val_00046332.JPEG +ILSVRC2012_val_00001050.JPEG +ILSVRC2012_val_00036945.JPEG +ILSVRC2012_val_00032105.JPEG +ILSVRC2012_val_00006924.JPEG 
+ILSVRC2012_val_00025564.JPEG +ILSVRC2012_val_00012463.JPEG +ILSVRC2012_val_00035272.JPEG +ILSVRC2012_val_00043478.JPEG +ILSVRC2012_val_00032315.JPEG +ILSVRC2012_val_00022895.JPEG +ILSVRC2012_val_00012099.JPEG +ILSVRC2012_val_00038220.JPEG +ILSVRC2012_val_00037819.JPEG +ILSVRC2012_val_00010846.JPEG +ILSVRC2012_val_00024609.JPEG +ILSVRC2012_val_00004395.JPEG +ILSVRC2012_val_00019227.JPEG +ILSVRC2012_val_00047317.JPEG +ILSVRC2012_val_00004345.JPEG +ILSVRC2012_val_00038402.JPEG +ILSVRC2012_val_00048308.JPEG +ILSVRC2012_val_00043282.JPEG +ILSVRC2012_val_00008913.JPEG +ILSVRC2012_val_00019601.JPEG +ILSVRC2012_val_00037964.JPEG +ILSVRC2012_val_00028937.JPEG +ILSVRC2012_val_00045248.JPEG +ILSVRC2012_val_00019028.JPEG +ILSVRC2012_val_00032004.JPEG +ILSVRC2012_val_00030605.JPEG +ILSVRC2012_val_00022992.JPEG +ILSVRC2012_val_00037659.JPEG +ILSVRC2012_val_00000156.JPEG +ILSVRC2012_val_00025741.JPEG +ILSVRC2012_val_00028333.JPEG +ILSVRC2012_val_00011590.JPEG +ILSVRC2012_val_00010809.JPEG +ILSVRC2012_val_00024455.JPEG +ILSVRC2012_val_00010687.JPEG +ILSVRC2012_val_00034189.JPEG +ILSVRC2012_val_00001857.JPEG +ILSVRC2012_val_00005541.JPEG +ILSVRC2012_val_00005749.JPEG +ILSVRC2012_val_00029658.JPEG +ILSVRC2012_val_00000927.JPEG +ILSVRC2012_val_00017486.JPEG +ILSVRC2012_val_00022250.JPEG +ILSVRC2012_val_00002519.JPEG +ILSVRC2012_val_00016240.JPEG +ILSVRC2012_val_00037746.JPEG +ILSVRC2012_val_00023109.JPEG +ILSVRC2012_val_00021537.JPEG +ILSVRC2012_val_00033137.JPEG +ILSVRC2012_val_00047547.JPEG +ILSVRC2012_val_00029004.JPEG +ILSVRC2012_val_00039427.JPEG +ILSVRC2012_val_00022208.JPEG +ILSVRC2012_val_00005460.JPEG +ILSVRC2012_val_00002653.JPEG +ILSVRC2012_val_00042866.JPEG +ILSVRC2012_val_00028155.JPEG +ILSVRC2012_val_00005570.JPEG +ILSVRC2012_val_00047017.JPEG +ILSVRC2012_val_00020689.JPEG +ILSVRC2012_val_00026170.JPEG +ILSVRC2012_val_00002632.JPEG +ILSVRC2012_val_00017362.JPEG +ILSVRC2012_val_00022859.JPEG +ILSVRC2012_val_00023540.JPEG +ILSVRC2012_val_00012148.JPEG 
+ILSVRC2012_val_00005630.JPEG +ILSVRC2012_val_00031468.JPEG +ILSVRC2012_val_00026852.JPEG +ILSVRC2012_val_00035849.JPEG +ILSVRC2012_val_00040890.JPEG +ILSVRC2012_val_00038960.JPEG +ILSVRC2012_val_00038107.JPEG +ILSVRC2012_val_00032862.JPEG +ILSVRC2012_val_00045724.JPEG +ILSVRC2012_val_00000756.JPEG +ILSVRC2012_val_00006218.JPEG +ILSVRC2012_val_00032762.JPEG +ILSVRC2012_val_00005937.JPEG +ILSVRC2012_val_00023972.JPEG +ILSVRC2012_val_00036225.JPEG +ILSVRC2012_val_00048892.JPEG +ILSVRC2012_val_00000475.JPEG +ILSVRC2012_val_00042930.JPEG +ILSVRC2012_val_00007759.JPEG +ILSVRC2012_val_00033653.JPEG +ILSVRC2012_val_00001839.JPEG +ILSVRC2012_val_00035020.JPEG +ILSVRC2012_val_00047514.JPEG +ILSVRC2012_val_00042320.JPEG +ILSVRC2012_val_00025587.JPEG +ILSVRC2012_val_00030308.JPEG +ILSVRC2012_val_00046153.JPEG +ILSVRC2012_val_00008152.JPEG +ILSVRC2012_val_00030094.JPEG +ILSVRC2012_val_00005840.JPEG +ILSVRC2012_val_00022277.JPEG +ILSVRC2012_val_00017178.JPEG +ILSVRC2012_val_00008213.JPEG +ILSVRC2012_val_00028163.JPEG +ILSVRC2012_val_00035247.JPEG +ILSVRC2012_val_00013425.JPEG +ILSVRC2012_val_00045463.JPEG +ILSVRC2012_val_00044181.JPEG +ILSVRC2012_val_00045763.JPEG +ILSVRC2012_val_00024670.JPEG +ILSVRC2012_val_00002919.JPEG +ILSVRC2012_val_00046524.JPEG +ILSVRC2012_val_00002308.JPEG +ILSVRC2012_val_00032485.JPEG +ILSVRC2012_val_00003587.JPEG +ILSVRC2012_val_00038132.JPEG +ILSVRC2012_val_00008744.JPEG +ILSVRC2012_val_00040162.JPEG +ILSVRC2012_val_00042851.JPEG +ILSVRC2012_val_00010131.JPEG +ILSVRC2012_val_00005634.JPEG +ILSVRC2012_val_00010350.JPEG +ILSVRC2012_val_00003839.JPEG +ILSVRC2012_val_00042553.JPEG +ILSVRC2012_val_00030256.JPEG +ILSVRC2012_val_00020412.JPEG +ILSVRC2012_val_00013918.JPEG +ILSVRC2012_val_00030064.JPEG +ILSVRC2012_val_00020126.JPEG +ILSVRC2012_val_00048228.JPEG +ILSVRC2012_val_00001057.JPEG +ILSVRC2012_val_00043465.JPEG +ILSVRC2012_val_00011865.JPEG +ILSVRC2012_val_00007322.JPEG +ILSVRC2012_val_00032095.JPEG +ILSVRC2012_val_00005695.JPEG 
+ILSVRC2012_val_00024544.JPEG +ILSVRC2012_val_00041681.JPEG +ILSVRC2012_val_00018953.JPEG +ILSVRC2012_val_00013998.JPEG +ILSVRC2012_val_00043713.JPEG +ILSVRC2012_val_00046611.JPEG +ILSVRC2012_val_00032514.JPEG +ILSVRC2012_val_00036258.JPEG +ILSVRC2012_val_00029579.JPEG +ILSVRC2012_val_00001465.JPEG +ILSVRC2012_val_00025657.JPEG +ILSVRC2012_val_00039262.JPEG +ILSVRC2012_val_00049591.JPEG +ILSVRC2012_val_00040134.JPEG +ILSVRC2012_val_00032625.JPEG +ILSVRC2012_val_00040958.JPEG +ILSVRC2012_val_00015465.JPEG +ILSVRC2012_val_00015520.JPEG +ILSVRC2012_val_00028644.JPEG +ILSVRC2012_val_00030875.JPEG +ILSVRC2012_val_00037694.JPEG +ILSVRC2012_val_00016264.JPEG +ILSVRC2012_val_00045430.JPEG +ILSVRC2012_val_00014490.JPEG +ILSVRC2012_val_00000743.JPEG +ILSVRC2012_val_00002307.JPEG +ILSVRC2012_val_00047825.JPEG +ILSVRC2012_val_00019072.JPEG +ILSVRC2012_val_00002306.JPEG +ILSVRC2012_val_00038462.JPEG +ILSVRC2012_val_00002780.JPEG +ILSVRC2012_val_00042997.JPEG +ILSVRC2012_val_00016922.JPEG +ILSVRC2012_val_00002883.JPEG +ILSVRC2012_val_00004050.JPEG +ILSVRC2012_val_00047691.JPEG +ILSVRC2012_val_00049395.JPEG +ILSVRC2012_val_00036118.JPEG +ILSVRC2012_val_00037136.JPEG +ILSVRC2012_val_00019575.JPEG +ILSVRC2012_val_00041213.JPEG +ILSVRC2012_val_00041843.JPEG +ILSVRC2012_val_00009905.JPEG +ILSVRC2012_val_00015772.JPEG +ILSVRC2012_val_00031533.JPEG +ILSVRC2012_val_00041091.JPEG +ILSVRC2012_val_00014135.JPEG +ILSVRC2012_val_00003309.JPEG +ILSVRC2012_val_00037958.JPEG +ILSVRC2012_val_00016467.JPEG +ILSVRC2012_val_00045053.JPEG +ILSVRC2012_val_00038758.JPEG +ILSVRC2012_val_00036907.JPEG +ILSVRC2012_val_00022420.JPEG +ILSVRC2012_val_00006488.JPEG +ILSVRC2012_val_00007346.JPEG +ILSVRC2012_val_00016392.JPEG +ILSVRC2012_val_00011360.JPEG +ILSVRC2012_val_00011222.JPEG +ILSVRC2012_val_00004152.JPEG +ILSVRC2012_val_00011655.JPEG +ILSVRC2012_val_00009875.JPEG +ILSVRC2012_val_00028367.JPEG +ILSVRC2012_val_00028252.JPEG +ILSVRC2012_val_00027476.JPEG +ILSVRC2012_val_00033846.JPEG 
+ILSVRC2012_val_00013033.JPEG +ILSVRC2012_val_00038197.JPEG +ILSVRC2012_val_00048852.JPEG +ILSVRC2012_val_00041002.JPEG +ILSVRC2012_val_00013393.JPEG +ILSVRC2012_val_00035931.JPEG +ILSVRC2012_val_00036174.JPEG +ILSVRC2012_val_00048952.JPEG +ILSVRC2012_val_00030911.JPEG +ILSVRC2012_val_00031274.JPEG +ILSVRC2012_val_00028117.JPEG +ILSVRC2012_val_00047606.JPEG +ILSVRC2012_val_00018830.JPEG +ILSVRC2012_val_00029400.JPEG +ILSVRC2012_val_00045088.JPEG +ILSVRC2012_val_00049979.JPEG +ILSVRC2012_val_00046285.JPEG +ILSVRC2012_val_00008480.JPEG +ILSVRC2012_val_00000991.JPEG +ILSVRC2012_val_00004703.JPEG +ILSVRC2012_val_00019156.JPEG +ILSVRC2012_val_00035242.JPEG +ILSVRC2012_val_00012477.JPEG +ILSVRC2012_val_00008155.JPEG +ILSVRC2012_val_00029221.JPEG +ILSVRC2012_val_00009967.JPEG +ILSVRC2012_val_00003061.JPEG +ILSVRC2012_val_00013749.JPEG +ILSVRC2012_val_00031630.JPEG +ILSVRC2012_val_00037538.JPEG +ILSVRC2012_val_00028070.JPEG +ILSVRC2012_val_00047309.JPEG +ILSVRC2012_val_00019380.JPEG +ILSVRC2012_val_00021557.JPEG +ILSVRC2012_val_00007851.JPEG +ILSVRC2012_val_00006060.JPEG +ILSVRC2012_val_00044137.JPEG +ILSVRC2012_val_00036483.JPEG +ILSVRC2012_val_00016764.JPEG +ILSVRC2012_val_00033415.JPEG +ILSVRC2012_val_00013931.JPEG +ILSVRC2012_val_00004293.JPEG +ILSVRC2012_val_00005162.JPEG +ILSVRC2012_val_00012551.JPEG +ILSVRC2012_val_00047951.JPEG +ILSVRC2012_val_00020049.JPEG +ILSVRC2012_val_00008988.JPEG +ILSVRC2012_val_00019163.JPEG +ILSVRC2012_val_00002103.JPEG +ILSVRC2012_val_00045725.JPEG +ILSVRC2012_val_00013505.JPEG +ILSVRC2012_val_00024425.JPEG +ILSVRC2012_val_00006138.JPEG +ILSVRC2012_val_00043650.JPEG +ILSVRC2012_val_00035683.JPEG +ILSVRC2012_val_00023250.JPEG +ILSVRC2012_val_00024823.JPEG +ILSVRC2012_val_00002089.JPEG +ILSVRC2012_val_00045908.JPEG +ILSVRC2012_val_00031659.JPEG +ILSVRC2012_val_00019345.JPEG +ILSVRC2012_val_00003995.JPEG +ILSVRC2012_val_00034957.JPEG +ILSVRC2012_val_00002386.JPEG +ILSVRC2012_val_00026607.JPEG +ILSVRC2012_val_00008639.JPEG 
+ILSVRC2012_val_00007973.JPEG +ILSVRC2012_val_00032983.JPEG +ILSVRC2012_val_00020700.JPEG +ILSVRC2012_val_00009241.JPEG +ILSVRC2012_val_00032700.JPEG +ILSVRC2012_val_00005507.JPEG +ILSVRC2012_val_00047855.JPEG +ILSVRC2012_val_00047565.JPEG +ILSVRC2012_val_00012868.JPEG +ILSVRC2012_val_00043074.JPEG +ILSVRC2012_val_00020772.JPEG +ILSVRC2012_val_00048977.JPEG +ILSVRC2012_val_00040853.JPEG +ILSVRC2012_val_00008856.JPEG +ILSVRC2012_val_00038137.JPEG +ILSVRC2012_val_00008577.JPEG +ILSVRC2012_val_00025063.JPEG +ILSVRC2012_val_00020776.JPEG +ILSVRC2012_val_00048703.JPEG +ILSVRC2012_val_00019222.JPEG +ILSVRC2012_val_00011391.JPEG +ILSVRC2012_val_00008659.JPEG +ILSVRC2012_val_00014054.JPEG +ILSVRC2012_val_00024640.JPEG +ILSVRC2012_val_00039121.JPEG +ILSVRC2012_val_00029263.JPEG +ILSVRC2012_val_00049198.JPEG +ILSVRC2012_val_00039055.JPEG +ILSVRC2012_val_00026956.JPEG +ILSVRC2012_val_00033099.JPEG +ILSVRC2012_val_00008719.JPEG +ILSVRC2012_val_00028123.JPEG +ILSVRC2012_val_00011843.JPEG +ILSVRC2012_val_00041862.JPEG +ILSVRC2012_val_00024149.JPEG +ILSVRC2012_val_00034611.JPEG +ILSVRC2012_val_00047471.JPEG +ILSVRC2012_val_00011981.JPEG +ILSVRC2012_val_00020906.JPEG +ILSVRC2012_val_00032103.JPEG +ILSVRC2012_val_00002536.JPEG +ILSVRC2012_val_00042972.JPEG +ILSVRC2012_val_00030881.JPEG +ILSVRC2012_val_00007324.JPEG +ILSVRC2012_val_00035794.JPEG +ILSVRC2012_val_00015771.JPEG +ILSVRC2012_val_00015277.JPEG +ILSVRC2012_val_00031183.JPEG +ILSVRC2012_val_00028776.JPEG +ILSVRC2012_val_00008479.JPEG +ILSVRC2012_val_00025897.JPEG +ILSVRC2012_val_00015059.JPEG +ILSVRC2012_val_00040884.JPEG +ILSVRC2012_val_00026043.JPEG +ILSVRC2012_val_00048130.JPEG +ILSVRC2012_val_00000529.JPEG +ILSVRC2012_val_00038747.JPEG +ILSVRC2012_val_00045152.JPEG +ILSVRC2012_val_00022592.JPEG +ILSVRC2012_val_00022496.JPEG +ILSVRC2012_val_00044743.JPEG +ILSVRC2012_val_00034520.JPEG +ILSVRC2012_val_00029307.JPEG +ILSVRC2012_val_00003958.JPEG +ILSVRC2012_val_00019856.JPEG +ILSVRC2012_val_00041077.JPEG 
+ILSVRC2012_val_00033832.JPEG +ILSVRC2012_val_00002400.JPEG +ILSVRC2012_val_00011760.JPEG +ILSVRC2012_val_00032296.JPEG +ILSVRC2012_val_00038604.JPEG +ILSVRC2012_val_00046405.JPEG +ILSVRC2012_val_00015149.JPEG +ILSVRC2012_val_00042509.JPEG +ILSVRC2012_val_00049643.JPEG +ILSVRC2012_val_00017092.JPEG +ILSVRC2012_val_00028660.JPEG +ILSVRC2012_val_00024830.JPEG +ILSVRC2012_val_00043697.JPEG +ILSVRC2012_val_00035607.JPEG +ILSVRC2012_val_00033790.JPEG +ILSVRC2012_val_00019238.JPEG +ILSVRC2012_val_00033208.JPEG +ILSVRC2012_val_00042927.JPEG +ILSVRC2012_val_00048822.JPEG +ILSVRC2012_val_00023042.JPEG +ILSVRC2012_val_00029271.JPEG +ILSVRC2012_val_00033482.JPEG +ILSVRC2012_val_00009849.JPEG +ILSVRC2012_val_00042792.JPEG +ILSVRC2012_val_00049425.JPEG +ILSVRC2012_val_00039020.JPEG +ILSVRC2012_val_00010490.JPEG +ILSVRC2012_val_00016720.JPEG +ILSVRC2012_val_00007554.JPEG +ILSVRC2012_val_00048529.JPEG +ILSVRC2012_val_00017891.JPEG +ILSVRC2012_val_00016015.JPEG +ILSVRC2012_val_00013151.JPEG +ILSVRC2012_val_00048648.JPEG +ILSVRC2012_val_00004890.JPEG +ILSVRC2012_val_00040127.JPEG +ILSVRC2012_val_00049159.JPEG +ILSVRC2012_val_00019984.JPEG +ILSVRC2012_val_00025114.JPEG +ILSVRC2012_val_00046208.JPEG +ILSVRC2012_val_00011790.JPEG +ILSVRC2012_val_00007901.JPEG +ILSVRC2012_val_00041572.JPEG +ILSVRC2012_val_00045550.JPEG +ILSVRC2012_val_00011115.JPEG +ILSVRC2012_val_00025566.JPEG +ILSVRC2012_val_00009610.JPEG +ILSVRC2012_val_00026207.JPEG +ILSVRC2012_val_00032441.JPEG +ILSVRC2012_val_00018659.JPEG +ILSVRC2012_val_00018167.JPEG +ILSVRC2012_val_00001920.JPEG +ILSVRC2012_val_00041171.JPEG +ILSVRC2012_val_00017795.JPEG +ILSVRC2012_val_00008371.JPEG +ILSVRC2012_val_00048688.JPEG +ILSVRC2012_val_00034689.JPEG +ILSVRC2012_val_00020218.JPEG +ILSVRC2012_val_00009045.JPEG +ILSVRC2012_val_00015289.JPEG +ILSVRC2012_val_00041347.JPEG +ILSVRC2012_val_00026708.JPEG +ILSVRC2012_val_00025533.JPEG +ILSVRC2012_val_00038134.JPEG +ILSVRC2012_val_00043398.JPEG +ILSVRC2012_val_00014728.JPEG 
+ILSVRC2012_val_00036639.JPEG +ILSVRC2012_val_00022330.JPEG +ILSVRC2012_val_00034615.JPEG +ILSVRC2012_val_00019758.JPEG +ILSVRC2012_val_00032140.JPEG +ILSVRC2012_val_00006486.JPEG +ILSVRC2012_val_00045767.JPEG +ILSVRC2012_val_00019655.JPEG +ILSVRC2012_val_00048700.JPEG +ILSVRC2012_val_00000271.JPEG +ILSVRC2012_val_00022462.JPEG +ILSVRC2012_val_00032627.JPEG +ILSVRC2012_val_00035407.JPEG +ILSVRC2012_val_00039633.JPEG +ILSVRC2012_val_00004198.JPEG +ILSVRC2012_val_00038016.JPEG +ILSVRC2012_val_00020756.JPEG +ILSVRC2012_val_00024042.JPEG +ILSVRC2012_val_00014271.JPEG +ILSVRC2012_val_00042143.JPEG +ILSVRC2012_val_00004550.JPEG +ILSVRC2012_val_00034427.JPEG +ILSVRC2012_val_00027122.JPEG +ILSVRC2012_val_00019177.JPEG +ILSVRC2012_val_00039071.JPEG +ILSVRC2012_val_00030257.JPEG +ILSVRC2012_val_00048385.JPEG +ILSVRC2012_val_00035943.JPEG +ILSVRC2012_val_00035051.JPEG +ILSVRC2012_val_00045030.JPEG +ILSVRC2012_val_00036665.JPEG +ILSVRC2012_val_00031717.JPEG +ILSVRC2012_val_00048803.JPEG +ILSVRC2012_val_00037496.JPEG +ILSVRC2012_val_00042947.JPEG +ILSVRC2012_val_00004996.JPEG +ILSVRC2012_val_00015792.JPEG +ILSVRC2012_val_00020440.JPEG +ILSVRC2012_val_00032131.JPEG +ILSVRC2012_val_00014082.JPEG +ILSVRC2012_val_00008541.JPEG +ILSVRC2012_val_00010656.JPEG +ILSVRC2012_val_00042816.JPEG +ILSVRC2012_val_00044235.JPEG +ILSVRC2012_val_00001677.JPEG +ILSVRC2012_val_00049380.JPEG +ILSVRC2012_val_00011331.JPEG +ILSVRC2012_val_00003250.JPEG +ILSVRC2012_val_00002093.JPEG +ILSVRC2012_val_00039968.JPEG +ILSVRC2012_val_00036862.JPEG +ILSVRC2012_val_00003271.JPEG +ILSVRC2012_val_00010363.JPEG +ILSVRC2012_val_00011065.JPEG +ILSVRC2012_val_00044252.JPEG +ILSVRC2012_val_00027441.JPEG +ILSVRC2012_val_00041536.JPEG +ILSVRC2012_val_00038055.JPEG +ILSVRC2012_val_00016065.JPEG +ILSVRC2012_val_00014448.JPEG +ILSVRC2012_val_00018300.JPEG +ILSVRC2012_val_00012362.JPEG +ILSVRC2012_val_00014531.JPEG +ILSVRC2012_val_00045987.JPEG +ILSVRC2012_val_00008507.JPEG +ILSVRC2012_val_00035787.JPEG 
+ILSVRC2012_val_00047898.JPEG +ILSVRC2012_val_00034565.JPEG +ILSVRC2012_val_00013483.JPEG +ILSVRC2012_val_00011965.JPEG +ILSVRC2012_val_00016313.JPEG +ILSVRC2012_val_00049255.JPEG +ILSVRC2012_val_00035663.JPEG +ILSVRC2012_val_00005235.JPEG +ILSVRC2012_val_00017694.JPEG +ILSVRC2012_val_00029142.JPEG +ILSVRC2012_val_00046709.JPEG +ILSVRC2012_val_00035652.JPEG +ILSVRC2012_val_00000064.JPEG +ILSVRC2012_val_00014071.JPEG +ILSVRC2012_val_00040381.JPEG +ILSVRC2012_val_00037088.JPEG +ILSVRC2012_val_00003202.JPEG +ILSVRC2012_val_00012863.JPEG +ILSVRC2012_val_00024408.JPEG +ILSVRC2012_val_00030827.JPEG +ILSVRC2012_val_00020602.JPEG +ILSVRC2012_val_00028839.JPEG +ILSVRC2012_val_00011966.JPEG +ILSVRC2012_val_00010521.JPEG +ILSVRC2012_val_00020467.JPEG +ILSVRC2012_val_00048823.JPEG +ILSVRC2012_val_00039047.JPEG +ILSVRC2012_val_00018855.JPEG +ILSVRC2012_val_00039687.JPEG +ILSVRC2012_val_00038864.JPEG +ILSVRC2012_val_00021511.JPEG +ILSVRC2012_val_00013243.JPEG +ILSVRC2012_val_00008668.JPEG +ILSVRC2012_val_00023300.JPEG +ILSVRC2012_val_00002959.JPEG +ILSVRC2012_val_00045028.JPEG +ILSVRC2012_val_00039313.JPEG +ILSVRC2012_val_00021724.JPEG +ILSVRC2012_val_00035742.JPEG +ILSVRC2012_val_00018341.JPEG +ILSVRC2012_val_00002594.JPEG +ILSVRC2012_val_00012435.JPEG +ILSVRC2012_val_00045098.JPEG +ILSVRC2012_val_00028664.JPEG +ILSVRC2012_val_00016549.JPEG +ILSVRC2012_val_00042229.JPEG +ILSVRC2012_val_00024197.JPEG +ILSVRC2012_val_00005416.JPEG +ILSVRC2012_val_00022368.JPEG +ILSVRC2012_val_00002096.JPEG +ILSVRC2012_val_00049648.JPEG +ILSVRC2012_val_00020529.JPEG +ILSVRC2012_val_00017845.JPEG +ILSVRC2012_val_00047113.JPEG +ILSVRC2012_val_00009877.JPEG +ILSVRC2012_val_00002483.JPEG +ILSVRC2012_val_00002831.JPEG +ILSVRC2012_val_00006568.JPEG +ILSVRC2012_val_00002169.JPEG +ILSVRC2012_val_00027812.JPEG +ILSVRC2012_val_00046603.JPEG +ILSVRC2012_val_00018295.JPEG +ILSVRC2012_val_00009389.JPEG +ILSVRC2012_val_00048649.JPEG +ILSVRC2012_val_00021357.JPEG +ILSVRC2012_val_00040190.JPEG 
+ILSVRC2012_val_00006610.JPEG +ILSVRC2012_val_00003903.JPEG +ILSVRC2012_val_00036455.JPEG +ILSVRC2012_val_00028314.JPEG +ILSVRC2012_val_00000939.JPEG +ILSVRC2012_val_00026815.JPEG +ILSVRC2012_val_00043592.JPEG +ILSVRC2012_val_00034770.JPEG +ILSVRC2012_val_00025436.JPEG +ILSVRC2012_val_00001508.JPEG +ILSVRC2012_val_00017608.JPEG +ILSVRC2012_val_00033371.JPEG +ILSVRC2012_val_00006660.JPEG +ILSVRC2012_val_00002475.JPEG +ILSVRC2012_val_00003979.JPEG +ILSVRC2012_val_00007582.JPEG +ILSVRC2012_val_00015809.JPEG +ILSVRC2012_val_00016929.JPEG +ILSVRC2012_val_00012636.JPEG +ILSVRC2012_val_00006796.JPEG +ILSVRC2012_val_00014756.JPEG +ILSVRC2012_val_00045859.JPEG +ILSVRC2012_val_00008067.JPEG +ILSVRC2012_val_00043924.JPEG +ILSVRC2012_val_00043566.JPEG +ILSVRC2012_val_00035598.JPEG +ILSVRC2012_val_00011476.JPEG +ILSVRC2012_val_00030448.JPEG +ILSVRC2012_val_00033252.JPEG +ILSVRC2012_val_00023921.JPEG +ILSVRC2012_val_00043464.JPEG +ILSVRC2012_val_00017705.JPEG +ILSVRC2012_val_00002430.JPEG +ILSVRC2012_val_00013144.JPEG +ILSVRC2012_val_00046957.JPEG +ILSVRC2012_val_00003657.JPEG +ILSVRC2012_val_00009191.JPEG +ILSVRC2012_val_00041021.JPEG +ILSVRC2012_val_00010893.JPEG +ILSVRC2012_val_00041492.JPEG +ILSVRC2012_val_00037181.JPEG +ILSVRC2012_val_00018399.JPEG +ILSVRC2012_val_00002871.JPEG +ILSVRC2012_val_00017934.JPEG +ILSVRC2012_val_00044354.JPEG +ILSVRC2012_val_00048738.JPEG +ILSVRC2012_val_00032995.JPEG +ILSVRC2012_val_00026662.JPEG +ILSVRC2012_val_00028621.JPEG +ILSVRC2012_val_00022416.JPEG +ILSVRC2012_val_00001122.JPEG +ILSVRC2012_val_00047103.JPEG +ILSVRC2012_val_00013296.JPEG +ILSVRC2012_val_00048553.JPEG +ILSVRC2012_val_00036223.JPEG +ILSVRC2012_val_00040931.JPEG +ILSVRC2012_val_00040780.JPEG +ILSVRC2012_val_00011744.JPEG +ILSVRC2012_val_00018607.JPEG +ILSVRC2012_val_00045840.JPEG +ILSVRC2012_val_00035515.JPEG +ILSVRC2012_val_00015845.JPEG +ILSVRC2012_val_00012866.JPEG +ILSVRC2012_val_00009160.JPEG +ILSVRC2012_val_00036299.JPEG +ILSVRC2012_val_00016127.JPEG 
+ILSVRC2012_val_00035042.JPEG +ILSVRC2012_val_00013952.JPEG +ILSVRC2012_val_00035681.JPEG +ILSVRC2012_val_00018902.JPEG +ILSVRC2012_val_00001928.JPEG +ILSVRC2012_val_00020534.JPEG +ILSVRC2012_val_00026825.JPEG +ILSVRC2012_val_00024390.JPEG +ILSVRC2012_val_00035400.JPEG +ILSVRC2012_val_00005350.JPEG +ILSVRC2012_val_00042218.JPEG +ILSVRC2012_val_00028305.JPEG +ILSVRC2012_val_00048533.JPEG +ILSVRC2012_val_00017773.JPEG +ILSVRC2012_val_00002682.JPEG +ILSVRC2012_val_00033249.JPEG +ILSVRC2012_val_00011750.JPEG +ILSVRC2012_val_00034409.JPEG +ILSVRC2012_val_00025482.JPEG +ILSVRC2012_val_00006051.JPEG +ILSVRC2012_val_00003953.JPEG +ILSVRC2012_val_00003355.JPEG +ILSVRC2012_val_00039970.JPEG +ILSVRC2012_val_00020045.JPEG +ILSVRC2012_val_00027306.JPEG +ILSVRC2012_val_00019551.JPEG +ILSVRC2012_val_00041592.JPEG +ILSVRC2012_val_00041928.JPEG +ILSVRC2012_val_00041135.JPEG +ILSVRC2012_val_00030093.JPEG +ILSVRC2012_val_00007727.JPEG +ILSVRC2012_val_00013486.JPEG +ILSVRC2012_val_00027470.JPEG +ILSVRC2012_val_00020454.JPEG +ILSVRC2012_val_00036218.JPEG +ILSVRC2012_val_00019828.JPEG +ILSVRC2012_val_00033342.JPEG +ILSVRC2012_val_00041935.JPEG +ILSVRC2012_val_00038898.JPEG +ILSVRC2012_val_00014670.JPEG +ILSVRC2012_val_00012320.JPEG +ILSVRC2012_val_00021553.JPEG +ILSVRC2012_val_00047647.JPEG +ILSVRC2012_val_00026646.JPEG +ILSVRC2012_val_00021094.JPEG +ILSVRC2012_val_00010627.JPEG +ILSVRC2012_val_00002100.JPEG +ILSVRC2012_val_00048098.JPEG +ILSVRC2012_val_00014584.JPEG +ILSVRC2012_val_00021875.JPEG +ILSVRC2012_val_00009189.JPEG +ILSVRC2012_val_00049161.JPEG +ILSVRC2012_val_00042937.JPEG +ILSVRC2012_val_00026346.JPEG +ILSVRC2012_val_00018522.JPEG +ILSVRC2012_val_00031552.JPEG +ILSVRC2012_val_00040727.JPEG +ILSVRC2012_val_00048157.JPEG +ILSVRC2012_val_00009411.JPEG +ILSVRC2012_val_00036844.JPEG +ILSVRC2012_val_00025463.JPEG +ILSVRC2012_val_00006942.JPEG +ILSVRC2012_val_00015545.JPEG +ILSVRC2012_val_00022865.JPEG +ILSVRC2012_val_00031888.JPEG +ILSVRC2012_val_00023278.JPEG 
+ILSVRC2012_val_00027084.JPEG +ILSVRC2012_val_00007051.JPEG +ILSVRC2012_val_00047956.JPEG +ILSVRC2012_val_00010808.JPEG +ILSVRC2012_val_00015412.JPEG +ILSVRC2012_val_00035912.JPEG +ILSVRC2012_val_00048499.JPEG +ILSVRC2012_val_00007502.JPEG +ILSVRC2012_val_00030389.JPEG +ILSVRC2012_val_00015723.JPEG +ILSVRC2012_val_00045788.JPEG +ILSVRC2012_val_00006797.JPEG +ILSVRC2012_val_00031077.JPEG +ILSVRC2012_val_00012240.JPEG +ILSVRC2012_val_00018650.JPEG +ILSVRC2012_val_00002292.JPEG +ILSVRC2012_val_00006352.JPEG +ILSVRC2012_val_00028561.JPEG +ILSVRC2012_val_00007654.JPEG +ILSVRC2012_val_00024628.JPEG +ILSVRC2012_val_00014159.JPEG +ILSVRC2012_val_00009901.JPEG +ILSVRC2012_val_00011253.JPEG +ILSVRC2012_val_00009785.JPEG +ILSVRC2012_val_00015680.JPEG +ILSVRC2012_val_00012673.JPEG +ILSVRC2012_val_00017751.JPEG +ILSVRC2012_val_00017695.JPEG +ILSVRC2012_val_00045978.JPEG +ILSVRC2012_val_00003098.JPEG +ILSVRC2012_val_00038844.JPEG +ILSVRC2012_val_00007997.JPEG +ILSVRC2012_val_00010210.JPEG +ILSVRC2012_val_00013626.JPEG +ILSVRC2012_val_00036256.JPEG +ILSVRC2012_val_00019421.JPEG +ILSVRC2012_val_00042756.JPEG +ILSVRC2012_val_00036780.JPEG +ILSVRC2012_val_00049405.JPEG +ILSVRC2012_val_00004052.JPEG +ILSVRC2012_val_00032203.JPEG +ILSVRC2012_val_00029416.JPEG +ILSVRC2012_val_00042738.JPEG +ILSVRC2012_val_00037080.JPEG +ILSVRC2012_val_00034323.JPEG +ILSVRC2012_val_00025471.JPEG +ILSVRC2012_val_00020181.JPEG +ILSVRC2012_val_00009028.JPEG +ILSVRC2012_val_00014014.JPEG +ILSVRC2012_val_00024754.JPEG +ILSVRC2012_val_00042010.JPEG +ILSVRC2012_val_00002183.JPEG +ILSVRC2012_val_00043121.JPEG +ILSVRC2012_val_00008583.JPEG +ILSVRC2012_val_00048876.JPEG +ILSVRC2012_val_00022088.JPEG +ILSVRC2012_val_00022229.JPEG +ILSVRC2012_val_00045498.JPEG +ILSVRC2012_val_00040461.JPEG +ILSVRC2012_val_00005642.JPEG +ILSVRC2012_val_00040335.JPEG +ILSVRC2012_val_00049148.JPEG +ILSVRC2012_val_00022899.JPEG +ILSVRC2012_val_00019453.JPEG +ILSVRC2012_val_00036950.JPEG +ILSVRC2012_val_00012715.JPEG 
+ILSVRC2012_val_00035640.JPEG +ILSVRC2012_val_00030699.JPEG +ILSVRC2012_val_00023621.JPEG +ILSVRC2012_val_00014399.JPEG +ILSVRC2012_val_00043517.JPEG +ILSVRC2012_val_00022839.JPEG +ILSVRC2012_val_00016554.JPEG +ILSVRC2012_val_00030253.JPEG +ILSVRC2012_val_00006407.JPEG +ILSVRC2012_val_00038893.JPEG +ILSVRC2012_val_00017657.JPEG +ILSVRC2012_val_00020963.JPEG +ILSVRC2012_val_00000214.JPEG +ILSVRC2012_val_00016120.JPEG +ILSVRC2012_val_00013366.JPEG +ILSVRC2012_val_00004528.JPEG +ILSVRC2012_val_00031793.JPEG +ILSVRC2012_val_00036181.JPEG +ILSVRC2012_val_00022251.JPEG +ILSVRC2012_val_00001243.JPEG +ILSVRC2012_val_00049614.JPEG +ILSVRC2012_val_00025892.JPEG +ILSVRC2012_val_00049178.JPEG +ILSVRC2012_val_00015559.JPEG +ILSVRC2012_val_00004886.JPEG +ILSVRC2012_val_00032770.JPEG +ILSVRC2012_val_00002329.JPEG +ILSVRC2012_val_00014312.JPEG +ILSVRC2012_val_00009080.JPEG +ILSVRC2012_val_00023172.JPEG +ILSVRC2012_val_00020222.JPEG +ILSVRC2012_val_00004731.JPEG +ILSVRC2012_val_00022875.JPEG +ILSVRC2012_val_00049660.JPEG +ILSVRC2012_val_00047241.JPEG +ILSVRC2012_val_00035907.JPEG +ILSVRC2012_val_00048044.JPEG +ILSVRC2012_val_00020309.JPEG +ILSVRC2012_val_00038000.JPEG +ILSVRC2012_val_00019572.JPEG +ILSVRC2012_val_00017969.JPEG +ILSVRC2012_val_00010032.JPEG +ILSVRC2012_val_00004989.JPEG +ILSVRC2012_val_00031061.JPEG +ILSVRC2012_val_00046862.JPEG +ILSVRC2012_val_00016226.JPEG +ILSVRC2012_val_00010398.JPEG +ILSVRC2012_val_00049439.JPEG +ILSVRC2012_val_00038870.JPEG +ILSVRC2012_val_00015639.JPEG +ILSVRC2012_val_00022987.JPEG +ILSVRC2012_val_00044054.JPEG +ILSVRC2012_val_00037555.JPEG +ILSVRC2012_val_00007514.JPEG +ILSVRC2012_val_00016663.JPEG +ILSVRC2012_val_00036758.JPEG +ILSVRC2012_val_00032834.JPEG +ILSVRC2012_val_00033200.JPEG +ILSVRC2012_val_00029042.JPEG +ILSVRC2012_val_00037804.JPEG +ILSVRC2012_val_00048264.JPEG +ILSVRC2012_val_00045954.JPEG +ILSVRC2012_val_00002003.JPEG +ILSVRC2012_val_00030895.JPEG +ILSVRC2012_val_00006624.JPEG +ILSVRC2012_val_00025492.JPEG 
+ILSVRC2012_val_00007371.JPEG +ILSVRC2012_val_00005608.JPEG +ILSVRC2012_val_00020167.JPEG +ILSVRC2012_val_00026567.JPEG +ILSVRC2012_val_00037416.JPEG +ILSVRC2012_val_00034306.JPEG +ILSVRC2012_val_00034688.JPEG +ILSVRC2012_val_00041551.JPEG +ILSVRC2012_val_00041966.JPEG +ILSVRC2012_val_00026370.JPEG +ILSVRC2012_val_00020804.JPEG +ILSVRC2012_val_00018845.JPEG +ILSVRC2012_val_00012052.JPEG +ILSVRC2012_val_00026985.JPEG +ILSVRC2012_val_00007615.JPEG +ILSVRC2012_val_00027589.JPEG +ILSVRC2012_val_00031696.JPEG +ILSVRC2012_val_00008662.JPEG +ILSVRC2012_val_00041147.JPEG +ILSVRC2012_val_00043429.JPEG +ILSVRC2012_val_00001089.JPEG +ILSVRC2012_val_00041460.JPEG +ILSVRC2012_val_00041101.JPEG +ILSVRC2012_val_00043554.JPEG +ILSVRC2012_val_00014388.JPEG +ILSVRC2012_val_00016486.JPEG +ILSVRC2012_val_00015673.JPEG +ILSVRC2012_val_00033281.JPEG +ILSVRC2012_val_00045916.JPEG +ILSVRC2012_val_00023388.JPEG +ILSVRC2012_val_00032190.JPEG +ILSVRC2012_val_00010227.JPEG +ILSVRC2012_val_00021052.JPEG +ILSVRC2012_val_00000056.JPEG +ILSVRC2012_val_00031921.JPEG +ILSVRC2012_val_00016209.JPEG +ILSVRC2012_val_00001600.JPEG +ILSVRC2012_val_00049397.JPEG +ILSVRC2012_val_00042665.JPEG +ILSVRC2012_val_00045917.JPEG +ILSVRC2012_val_00006931.JPEG +ILSVRC2012_val_00021900.JPEG +ILSVRC2012_val_00004526.JPEG +ILSVRC2012_val_00010975.JPEG +ILSVRC2012_val_00006573.JPEG +ILSVRC2012_val_00034883.JPEG +ILSVRC2012_val_00032120.JPEG +ILSVRC2012_val_00009606.JPEG +ILSVRC2012_val_00039745.JPEG +ILSVRC2012_val_00036624.JPEG +ILSVRC2012_val_00034139.JPEG +ILSVRC2012_val_00026600.JPEG +ILSVRC2012_val_00035856.JPEG +ILSVRC2012_val_00039822.JPEG +ILSVRC2012_val_00025545.JPEG +ILSVRC2012_val_00011946.JPEG +ILSVRC2012_val_00028736.JPEG +ILSVRC2012_val_00030298.JPEG +ILSVRC2012_val_00024148.JPEG +ILSVRC2012_val_00011624.JPEG +ILSVRC2012_val_00035100.JPEG +ILSVRC2012_val_00019330.JPEG +ILSVRC2012_val_00045205.JPEG +ILSVRC2012_val_00024442.JPEG +ILSVRC2012_val_00049086.JPEG +ILSVRC2012_val_00017679.JPEG 
+ILSVRC2012_val_00039384.JPEG +ILSVRC2012_val_00032356.JPEG +ILSVRC2012_val_00002679.JPEG +ILSVRC2012_val_00044150.JPEG +ILSVRC2012_val_00044301.JPEG +ILSVRC2012_val_00044703.JPEG +ILSVRC2012_val_00040281.JPEG +ILSVRC2012_val_00035360.JPEG +ILSVRC2012_val_00002826.JPEG +ILSVRC2012_val_00011048.JPEG +ILSVRC2012_val_00044247.JPEG +ILSVRC2012_val_00033556.JPEG +ILSVRC2012_val_00019066.JPEG +ILSVRC2012_val_00027208.JPEG +ILSVRC2012_val_00041850.JPEG +ILSVRC2012_val_00028771.JPEG +ILSVRC2012_val_00037561.JPEG +ILSVRC2012_val_00018385.JPEG +ILSVRC2012_val_00000474.JPEG +ILSVRC2012_val_00025888.JPEG +ILSVRC2012_val_00006282.JPEG +ILSVRC2012_val_00018780.JPEG +ILSVRC2012_val_00003970.JPEG +ILSVRC2012_val_00015294.JPEG +ILSVRC2012_val_00036770.JPEG +ILSVRC2012_val_00039511.JPEG +ILSVRC2012_val_00042783.JPEG +ILSVRC2012_val_00021458.JPEG +ILSVRC2012_val_00005289.JPEG +ILSVRC2012_val_00037116.JPEG +ILSVRC2012_val_00039495.JPEG +ILSVRC2012_val_00046083.JPEG +ILSVRC2012_val_00034938.JPEG +ILSVRC2012_val_00012282.JPEG +ILSVRC2012_val_00032655.JPEG +ILSVRC2012_val_00005677.JPEG +ILSVRC2012_val_00014735.JPEG +ILSVRC2012_val_00024027.JPEG +ILSVRC2012_val_00034736.JPEG +ILSVRC2012_val_00040671.JPEG +ILSVRC2012_val_00045781.JPEG +ILSVRC2012_val_00019205.JPEG +ILSVRC2012_val_00047756.JPEG +ILSVRC2012_val_00034827.JPEG +ILSVRC2012_val_00042515.JPEG +ILSVRC2012_val_00015061.JPEG +ILSVRC2012_val_00007267.JPEG +ILSVRC2012_val_00021405.JPEG +ILSVRC2012_val_00034718.JPEG +ILSVRC2012_val_00003662.JPEG +ILSVRC2012_val_00036137.JPEG +ILSVRC2012_val_00002856.JPEG +ILSVRC2012_val_00021647.JPEG +ILSVRC2012_val_00028606.JPEG +ILSVRC2012_val_00029886.JPEG +ILSVRC2012_val_00028169.JPEG +ILSVRC2012_val_00047804.JPEG +ILSVRC2012_val_00046069.JPEG +ILSVRC2012_val_00038345.JPEG +ILSVRC2012_val_00025485.JPEG +ILSVRC2012_val_00009950.JPEG +ILSVRC2012_val_00036214.JPEG +ILSVRC2012_val_00042459.JPEG +ILSVRC2012_val_00017378.JPEG +ILSVRC2012_val_00016437.JPEG +ILSVRC2012_val_00041304.JPEG 
+ILSVRC2012_val_00026942.JPEG +ILSVRC2012_val_00019682.JPEG +ILSVRC2012_val_00023072.JPEG +ILSVRC2012_val_00038500.JPEG +ILSVRC2012_val_00036677.JPEG +ILSVRC2012_val_00042950.JPEG +ILSVRC2012_val_00038601.JPEG +ILSVRC2012_val_00044350.JPEG +ILSVRC2012_val_00012539.JPEG +ILSVRC2012_val_00038743.JPEG +ILSVRC2012_val_00000332.JPEG +ILSVRC2012_val_00018147.JPEG +ILSVRC2012_val_00025257.JPEG +ILSVRC2012_val_00020358.JPEG +ILSVRC2012_val_00028135.JPEG +ILSVRC2012_val_00032194.JPEG +ILSVRC2012_val_00038078.JPEG +ILSVRC2012_val_00043137.JPEG +ILSVRC2012_val_00014555.JPEG +ILSVRC2012_val_00015114.JPEG +ILSVRC2012_val_00033450.JPEG +ILSVRC2012_val_00030624.JPEG +ILSVRC2012_val_00033839.JPEG +ILSVRC2012_val_00014992.JPEG +ILSVRC2012_val_00013794.JPEG +ILSVRC2012_val_00029575.JPEG +ILSVRC2012_val_00005414.JPEG +ILSVRC2012_val_00036124.JPEG +ILSVRC2012_val_00019736.JPEG +ILSVRC2012_val_00013671.JPEG +ILSVRC2012_val_00017682.JPEG +ILSVRC2012_val_00000386.JPEG +ILSVRC2012_val_00009750.JPEG +ILSVRC2012_val_00007126.JPEG +ILSVRC2012_val_00043183.JPEG +ILSVRC2012_val_00000129.JPEG +ILSVRC2012_val_00029149.JPEG +ILSVRC2012_val_00030337.JPEG +ILSVRC2012_val_00003994.JPEG +ILSVRC2012_val_00010641.JPEG +ILSVRC2012_val_00003724.JPEG +ILSVRC2012_val_00005869.JPEG +ILSVRC2012_val_00026865.JPEG +ILSVRC2012_val_00003268.JPEG +ILSVRC2012_val_00012836.JPEG +ILSVRC2012_val_00007036.JPEG +ILSVRC2012_val_00017849.JPEG +ILSVRC2012_val_00040650.JPEG +ILSVRC2012_val_00001700.JPEG +ILSVRC2012_val_00028207.JPEG +ILSVRC2012_val_00047630.JPEG +ILSVRC2012_val_00009296.JPEG +ILSVRC2012_val_00016094.JPEG +ILSVRC2012_val_00013982.JPEG +ILSVRC2012_val_00039923.JPEG +ILSVRC2012_val_00008717.JPEG +ILSVRC2012_val_00045220.JPEG +ILSVRC2012_val_00021786.JPEG +ILSVRC2012_val_00012607.JPEG +ILSVRC2012_val_00004361.JPEG +ILSVRC2012_val_00038176.JPEG +ILSVRC2012_val_00012939.JPEG +ILSVRC2012_val_00008075.JPEG +ILSVRC2012_val_00018009.JPEG +ILSVRC2012_val_00049520.JPEG +ILSVRC2012_val_00011023.JPEG 
+ILSVRC2012_val_00014212.JPEG +ILSVRC2012_val_00026140.JPEG +ILSVRC2012_val_00022816.JPEG +ILSVRC2012_val_00007918.JPEG +ILSVRC2012_val_00002947.JPEG +ILSVRC2012_val_00040138.JPEG +ILSVRC2012_val_00015292.JPEG +ILSVRC2012_val_00031952.JPEG +ILSVRC2012_val_00047595.JPEG +ILSVRC2012_val_00000985.JPEG +ILSVRC2012_val_00000814.JPEG +ILSVRC2012_val_00047175.JPEG +ILSVRC2012_val_00034220.JPEG +ILSVRC2012_val_00032733.JPEG +ILSVRC2012_val_00012293.JPEG +ILSVRC2012_val_00005306.JPEG +ILSVRC2012_val_00037049.JPEG +ILSVRC2012_val_00043913.JPEG +ILSVRC2012_val_00038622.JPEG +ILSVRC2012_val_00033495.JPEG +ILSVRC2012_val_00008254.JPEG +ILSVRC2012_val_00045683.JPEG +ILSVRC2012_val_00006659.JPEG +ILSVRC2012_val_00022537.JPEG +ILSVRC2012_val_00013226.JPEG +ILSVRC2012_val_00028726.JPEG +ILSVRC2012_val_00014015.JPEG +ILSVRC2012_val_00029321.JPEG +ILSVRC2012_val_00000086.JPEG +ILSVRC2012_val_00021922.JPEG +ILSVRC2012_val_00006030.JPEG +ILSVRC2012_val_00016381.JPEG +ILSVRC2012_val_00033329.JPEG +ILSVRC2012_val_00013165.JPEG +ILSVRC2012_val_00007154.JPEG +ILSVRC2012_val_00020356.JPEG +ILSVRC2012_val_00036756.JPEG +ILSVRC2012_val_00009178.JPEG +ILSVRC2012_val_00010716.JPEG +ILSVRC2012_val_00029184.JPEG +ILSVRC2012_val_00026863.JPEG +ILSVRC2012_val_00013206.JPEG +ILSVRC2012_val_00027247.JPEG +ILSVRC2012_val_00003971.JPEG +ILSVRC2012_val_00018591.JPEG +ILSVRC2012_val_00021649.JPEG +ILSVRC2012_val_00029629.JPEG +ILSVRC2012_val_00033631.JPEG +ILSVRC2012_val_00037862.JPEG +ILSVRC2012_val_00019484.JPEG +ILSVRC2012_val_00025015.JPEG +ILSVRC2012_val_00013061.JPEG +ILSVRC2012_val_00035167.JPEG +ILSVRC2012_val_00000400.JPEG +ILSVRC2012_val_00015421.JPEG +ILSVRC2012_val_00039408.JPEG +ILSVRC2012_val_00048896.JPEG +ILSVRC2012_val_00004736.JPEG +ILSVRC2012_val_00047215.JPEG +ILSVRC2012_val_00038868.JPEG +ILSVRC2012_val_00031639.JPEG +ILSVRC2012_val_00023487.JPEG +ILSVRC2012_val_00019449.JPEG +ILSVRC2012_val_00018767.JPEG +ILSVRC2012_val_00046022.JPEG +ILSVRC2012_val_00026512.JPEG 
+ILSVRC2012_val_00046621.JPEG +ILSVRC2012_val_00022283.JPEG +ILSVRC2012_val_00007804.JPEG +ILSVRC2012_val_00009364.JPEG +ILSVRC2012_val_00022912.JPEG +ILSVRC2012_val_00000928.JPEG +ILSVRC2012_val_00024604.JPEG +ILSVRC2012_val_00030035.JPEG +ILSVRC2012_val_00037444.JPEG +ILSVRC2012_val_00022365.JPEG +ILSVRC2012_val_00022269.JPEG +ILSVRC2012_val_00013882.JPEG +ILSVRC2012_val_00016490.JPEG +ILSVRC2012_val_00011472.JPEG +ILSVRC2012_val_00049433.JPEG +ILSVRC2012_val_00015918.JPEG +ILSVRC2012_val_00005991.JPEG +ILSVRC2012_val_00031214.JPEG +ILSVRC2012_val_00006422.JPEG +ILSVRC2012_val_00045370.JPEG +ILSVRC2012_val_00045870.JPEG +ILSVRC2012_val_00000724.JPEG +ILSVRC2012_val_00039926.JPEG +ILSVRC2012_val_00024500.JPEG +ILSVRC2012_val_00002101.JPEG +ILSVRC2012_val_00029812.JPEG +ILSVRC2012_val_00015610.JPEG +ILSVRC2012_val_00049246.JPEG +ILSVRC2012_val_00047001.JPEG +ILSVRC2012_val_00030037.JPEG +ILSVRC2012_val_00007608.JPEG +ILSVRC2012_val_00021615.JPEG +ILSVRC2012_val_00019751.JPEG +ILSVRC2012_val_00032910.JPEG +ILSVRC2012_val_00044803.JPEG +ILSVRC2012_val_00036367.JPEG +ILSVRC2012_val_00023207.JPEG +ILSVRC2012_val_00009318.JPEG +ILSVRC2012_val_00031114.JPEG +ILSVRC2012_val_00047589.JPEG +ILSVRC2012_val_00004136.JPEG +ILSVRC2012_val_00043823.JPEG +ILSVRC2012_val_00027106.JPEG +ILSVRC2012_val_00033686.JPEG +ILSVRC2012_val_00045409.JPEG +ILSVRC2012_val_00027909.JPEG +ILSVRC2012_val_00040572.JPEG +ILSVRC2012_val_00034483.JPEG +ILSVRC2012_val_00046956.JPEG +ILSVRC2012_val_00039190.JPEG +ILSVRC2012_val_00018068.JPEG +ILSVRC2012_val_00037952.JPEG +ILSVRC2012_val_00026652.JPEG +ILSVRC2012_val_00034494.JPEG +ILSVRC2012_val_00002133.JPEG +ILSVRC2012_val_00000508.JPEG +ILSVRC2012_val_00000051.JPEG +ILSVRC2012_val_00005187.JPEG +ILSVRC2012_val_00033221.JPEG +ILSVRC2012_val_00005072.JPEG +ILSVRC2012_val_00030476.JPEG +ILSVRC2012_val_00047496.JPEG +ILSVRC2012_val_00039816.JPEG +ILSVRC2012_val_00031849.JPEG +ILSVRC2012_val_00030715.JPEG +ILSVRC2012_val_00036409.JPEG 
+ILSVRC2012_val_00026523.JPEG +ILSVRC2012_val_00046349.JPEG +ILSVRC2012_val_00039622.JPEG +ILSVRC2012_val_00025192.JPEG +ILSVRC2012_val_00036702.JPEG +ILSVRC2012_val_00012329.JPEG +ILSVRC2012_val_00037844.JPEG +ILSVRC2012_val_00005323.JPEG +ILSVRC2012_val_00020824.JPEG +ILSVRC2012_val_00042283.JPEG +ILSVRC2012_val_00037259.JPEG +ILSVRC2012_val_00012772.JPEG +ILSVRC2012_val_00048844.JPEG +ILSVRC2012_val_00017697.JPEG +ILSVRC2012_val_00012992.JPEG +ILSVRC2012_val_00010104.JPEG +ILSVRC2012_val_00029937.JPEG +ILSVRC2012_val_00022953.JPEG +ILSVRC2012_val_00002114.JPEG +ILSVRC2012_val_00037442.JPEG +ILSVRC2012_val_00023028.JPEG +ILSVRC2012_val_00036926.JPEG +ILSVRC2012_val_00030251.JPEG +ILSVRC2012_val_00003076.JPEG +ILSVRC2012_val_00015385.JPEG +ILSVRC2012_val_00001464.JPEG +ILSVRC2012_val_00011218.JPEG +ILSVRC2012_val_00016569.JPEG +ILSVRC2012_val_00043881.JPEG +ILSVRC2012_val_00008623.JPEG +ILSVRC2012_val_00031923.JPEG +ILSVRC2012_val_00028247.JPEG +ILSVRC2012_val_00021504.JPEG +ILSVRC2012_val_00018312.JPEG +ILSVRC2012_val_00013954.JPEG +ILSVRC2012_val_00012805.JPEG +ILSVRC2012_val_00007206.JPEG +ILSVRC2012_val_00043862.JPEG +ILSVRC2012_val_00026038.JPEG +ILSVRC2012_val_00041761.JPEG +ILSVRC2012_val_00013831.JPEG +ILSVRC2012_val_00024245.JPEG +ILSVRC2012_val_00020113.JPEG +ILSVRC2012_val_00007191.JPEG +ILSVRC2012_val_00042112.JPEG +ILSVRC2012_val_00037389.JPEG +ILSVRC2012_val_00009489.JPEG +ILSVRC2012_val_00045945.JPEG +ILSVRC2012_val_00002014.JPEG +ILSVRC2012_val_00000561.JPEG +ILSVRC2012_val_00015322.JPEG +ILSVRC2012_val_00037156.JPEG +ILSVRC2012_val_00023140.JPEG +ILSVRC2012_val_00033642.JPEG +ILSVRC2012_val_00017688.JPEG +ILSVRC2012_val_00021674.JPEG +ILSVRC2012_val_00006100.JPEG +ILSVRC2012_val_00006838.JPEG +ILSVRC2012_val_00040675.JPEG +ILSVRC2012_val_00040668.JPEG +ILSVRC2012_val_00014695.JPEG +ILSVRC2012_val_00012893.JPEG +ILSVRC2012_val_00009587.JPEG +ILSVRC2012_val_00030442.JPEG +ILSVRC2012_val_00048435.JPEG +ILSVRC2012_val_00035095.JPEG 
+ILSVRC2012_val_00010296.JPEG +ILSVRC2012_val_00028913.JPEG +ILSVRC2012_val_00030883.JPEG +ILSVRC2012_val_00048886.JPEG +ILSVRC2012_val_00017481.JPEG +ILSVRC2012_val_00015336.JPEG +ILSVRC2012_val_00042392.JPEG +ILSVRC2012_val_00035433.JPEG +ILSVRC2012_val_00021212.JPEG +ILSVRC2012_val_00005539.JPEG +ILSVRC2012_val_00028149.JPEG +ILSVRC2012_val_00006848.JPEG +ILSVRC2012_val_00001112.JPEG +ILSVRC2012_val_00025166.JPEG +ILSVRC2012_val_00018163.JPEG +ILSVRC2012_val_00013191.JPEG +ILSVRC2012_val_00014803.JPEG +ILSVRC2012_val_00016499.JPEG +ILSVRC2012_val_00016474.JPEG +ILSVRC2012_val_00010325.JPEG +ILSVRC2012_val_00025880.JPEG +ILSVRC2012_val_00047328.JPEG +ILSVRC2012_val_00032642.JPEG +ILSVRC2012_val_00015913.JPEG +ILSVRC2012_val_00023118.JPEG +ILSVRC2012_val_00049509.JPEG +ILSVRC2012_val_00008373.JPEG +ILSVRC2012_val_00028188.JPEG +ILSVRC2012_val_00007774.JPEG +ILSVRC2012_val_00022460.JPEG +ILSVRC2012_val_00021777.JPEG +ILSVRC2012_val_00038066.JPEG +ILSVRC2012_val_00009957.JPEG +ILSVRC2012_val_00044176.JPEG +ILSVRC2012_val_00041218.JPEG +ILSVRC2012_val_00010664.JPEG +ILSVRC2012_val_00000440.JPEG +ILSVRC2012_val_00036279.JPEG +ILSVRC2012_val_00011206.JPEG +ILSVRC2012_val_00021773.JPEG +ILSVRC2012_val_00011454.JPEG +ILSVRC2012_val_00037821.JPEG +ILSVRC2012_val_00003558.JPEG +ILSVRC2012_val_00019642.JPEG +ILSVRC2012_val_00047424.JPEG +ILSVRC2012_val_00024166.JPEG +ILSVRC2012_val_00036165.JPEG +ILSVRC2012_val_00022523.JPEG +ILSVRC2012_val_00030616.JPEG +ILSVRC2012_val_00014869.JPEG +ILSVRC2012_val_00016709.JPEG +ILSVRC2012_val_00001285.JPEG +ILSVRC2012_val_00045733.JPEG +ILSVRC2012_val_00039136.JPEG +ILSVRC2012_val_00024580.JPEG +ILSVRC2012_val_00042506.JPEG +ILSVRC2012_val_00020960.JPEG +ILSVRC2012_val_00048840.JPEG +ILSVRC2012_val_00028503.JPEG +ILSVRC2012_val_00042354.JPEG +ILSVRC2012_val_00030356.JPEG +ILSVRC2012_val_00048926.JPEG +ILSVRC2012_val_00006539.JPEG +ILSVRC2012_val_00007641.JPEG +ILSVRC2012_val_00019245.JPEG +ILSVRC2012_val_00008861.JPEG 
+ILSVRC2012_val_00020174.JPEG +ILSVRC2012_val_00030569.JPEG +ILSVRC2012_val_00019779.JPEG +ILSVRC2012_val_00019936.JPEG +ILSVRC2012_val_00019986.JPEG +ILSVRC2012_val_00011611.JPEG +ILSVRC2012_val_00025648.JPEG +ILSVRC2012_val_00043804.JPEG +ILSVRC2012_val_00030551.JPEG +ILSVRC2012_val_00036865.JPEG +ILSVRC2012_val_00019097.JPEG +ILSVRC2012_val_00006957.JPEG +ILSVRC2012_val_00023828.JPEG +ILSVRC2012_val_00047810.JPEG +ILSVRC2012_val_00028482.JPEG +ILSVRC2012_val_00030726.JPEG +ILSVRC2012_val_00005319.JPEG +ILSVRC2012_val_00017881.JPEG +ILSVRC2012_val_00020811.JPEG +ILSVRC2012_val_00008682.JPEG +ILSVRC2012_val_00033423.JPEG +ILSVRC2012_val_00039984.JPEG +ILSVRC2012_val_00022711.JPEG +ILSVRC2012_val_00044466.JPEG +ILSVRC2012_val_00001845.JPEG +ILSVRC2012_val_00007295.JPEG +ILSVRC2012_val_00016184.JPEG +ILSVRC2012_val_00017238.JPEG +ILSVRC2012_val_00033321.JPEG +ILSVRC2012_val_00008217.JPEG +ILSVRC2012_val_00009956.JPEG +ILSVRC2012_val_00035831.JPEG +ILSVRC2012_val_00044961.JPEG +ILSVRC2012_val_00025984.JPEG +ILSVRC2012_val_00004067.JPEG +ILSVRC2012_val_00038202.JPEG +ILSVRC2012_val_00049432.JPEG +ILSVRC2012_val_00023617.JPEG +ILSVRC2012_val_00013989.JPEG +ILSVRC2012_val_00014158.JPEG +ILSVRC2012_val_00036853.JPEG +ILSVRC2012_val_00024117.JPEG +ILSVRC2012_val_00042702.JPEG +ILSVRC2012_val_00030711.JPEG +ILSVRC2012_val_00036921.JPEG +ILSVRC2012_val_00028859.JPEG +ILSVRC2012_val_00005377.JPEG +ILSVRC2012_val_00045319.JPEG +ILSVRC2012_val_00001821.JPEG +ILSVRC2012_val_00016575.JPEG +ILSVRC2012_val_00020535.JPEG +ILSVRC2012_val_00009592.JPEG +ILSVRC2012_val_00020492.JPEG +ILSVRC2012_val_00019164.JPEG +ILSVRC2012_val_00034722.JPEG +ILSVRC2012_val_00037278.JPEG +ILSVRC2012_val_00046570.JPEG +ILSVRC2012_val_00013926.JPEG +ILSVRC2012_val_00031041.JPEG +ILSVRC2012_val_00042092.JPEG +ILSVRC2012_val_00012565.JPEG +ILSVRC2012_val_00018643.JPEG +ILSVRC2012_val_00005793.JPEG +ILSVRC2012_val_00015794.JPEG +ILSVRC2012_val_00001414.JPEG +ILSVRC2012_val_00046896.JPEG 
+ILSVRC2012_val_00015428.JPEG +ILSVRC2012_val_00004235.JPEG +ILSVRC2012_val_00043450.JPEG +ILSVRC2012_val_00025036.JPEG +ILSVRC2012_val_00038798.JPEG +ILSVRC2012_val_00048325.JPEG +ILSVRC2012_val_00034096.JPEG +ILSVRC2012_val_00023003.JPEG +ILSVRC2012_val_00002276.JPEG +ILSVRC2012_val_00034132.JPEG +ILSVRC2012_val_00045454.JPEG +ILSVRC2012_val_00008448.JPEG +ILSVRC2012_val_00011686.JPEG +ILSVRC2012_val_00046617.JPEG +ILSVRC2012_val_00032890.JPEG +ILSVRC2012_val_00042011.JPEG +ILSVRC2012_val_00015602.JPEG +ILSVRC2012_val_00046269.JPEG +ILSVRC2012_val_00047960.JPEG +ILSVRC2012_val_00033593.JPEG +ILSVRC2012_val_00022352.JPEG +ILSVRC2012_val_00042910.JPEG +ILSVRC2012_val_00014482.JPEG +ILSVRC2012_val_00029668.JPEG +ILSVRC2012_val_00014740.JPEG +ILSVRC2012_val_00019972.JPEG +ILSVRC2012_val_00011129.JPEG +ILSVRC2012_val_00001851.JPEG +ILSVRC2012_val_00026886.JPEG +ILSVRC2012_val_00011127.JPEG +ILSVRC2012_val_00017273.JPEG +ILSVRC2012_val_00028861.JPEG +ILSVRC2012_val_00018977.JPEG +ILSVRC2012_val_00024473.JPEG +ILSVRC2012_val_00023480.JPEG +ILSVRC2012_val_00047688.JPEG +ILSVRC2012_val_00014628.JPEG +ILSVRC2012_val_00022489.JPEG +ILSVRC2012_val_00030848.JPEG +ILSVRC2012_val_00029237.JPEG +ILSVRC2012_val_00047041.JPEG +ILSVRC2012_val_00013750.JPEG +ILSVRC2012_val_00011702.JPEG +ILSVRC2012_val_00016917.JPEG +ILSVRC2012_val_00029055.JPEG +ILSVRC2012_val_00010769.JPEG +ILSVRC2012_val_00039573.JPEG +ILSVRC2012_val_00039339.JPEG +ILSVRC2012_val_00021669.JPEG +ILSVRC2012_val_00002043.JPEG +ILSVRC2012_val_00008143.JPEG +ILSVRC2012_val_00012961.JPEG +ILSVRC2012_val_00037945.JPEG +ILSVRC2012_val_00019317.JPEG +ILSVRC2012_val_00033525.JPEG +ILSVRC2012_val_00009797.JPEG +ILSVRC2012_val_00006405.JPEG +ILSVRC2012_val_00011098.JPEG +ILSVRC2012_val_00034261.JPEG +ILSVRC2012_val_00009224.JPEG +ILSVRC2012_val_00023122.JPEG +ILSVRC2012_val_00047460.JPEG +ILSVRC2012_val_00013896.JPEG +ILSVRC2012_val_00007122.JPEG +ILSVRC2012_val_00039218.JPEG +ILSVRC2012_val_00037667.JPEG 
+ILSVRC2012_val_00033809.JPEG +ILSVRC2012_val_00042049.JPEG +ILSVRC2012_val_00009536.JPEG +ILSVRC2012_val_00010235.JPEG +ILSVRC2012_val_00034428.JPEG +ILSVRC2012_val_00024526.JPEG +ILSVRC2012_val_00000266.JPEG +ILSVRC2012_val_00003374.JPEG +ILSVRC2012_val_00003414.JPEG +ILSVRC2012_val_00004756.JPEG +ILSVRC2012_val_00001460.JPEG +ILSVRC2012_val_00005415.JPEG +ILSVRC2012_val_00015703.JPEG +ILSVRC2012_val_00046801.JPEG +ILSVRC2012_val_00047462.JPEG +ILSVRC2012_val_00018368.JPEG +ILSVRC2012_val_00048476.JPEG +ILSVRC2012_val_00029322.JPEG +ILSVRC2012_val_00033398.JPEG +ILSVRC2012_val_00033699.JPEG +ILSVRC2012_val_00000955.JPEG +ILSVRC2012_val_00001294.JPEG +ILSVRC2012_val_00013308.JPEG +ILSVRC2012_val_00000459.JPEG +ILSVRC2012_val_00028450.JPEG +ILSVRC2012_val_00045161.JPEG +ILSVRC2012_val_00029286.JPEG +ILSVRC2012_val_00008876.JPEG +ILSVRC2012_val_00001222.JPEG +ILSVRC2012_val_00006870.JPEG +ILSVRC2012_val_00018422.JPEG +ILSVRC2012_val_00014642.JPEG +ILSVRC2012_val_00008649.JPEG +ILSVRC2012_val_00026704.JPEG +ILSVRC2012_val_00049661.JPEG +ILSVRC2012_val_00011242.JPEG +ILSVRC2012_val_00032783.JPEG +ILSVRC2012_val_00029541.JPEG +ILSVRC2012_val_00035132.JPEG +ILSVRC2012_val_00042521.JPEG +ILSVRC2012_val_00010700.JPEG +ILSVRC2012_val_00007159.JPEG +ILSVRC2012_val_00032279.JPEG +ILSVRC2012_val_00040201.JPEG +ILSVRC2012_val_00013589.JPEG +ILSVRC2012_val_00009703.JPEG +ILSVRC2012_val_00000915.JPEG +ILSVRC2012_val_00015923.JPEG +ILSVRC2012_val_00032736.JPEG +ILSVRC2012_val_00012252.JPEG +ILSVRC2012_val_00020372.JPEG +ILSVRC2012_val_00003760.JPEG +ILSVRC2012_val_00017274.JPEG +ILSVRC2012_val_00033072.JPEG +ILSVRC2012_val_00006369.JPEG +ILSVRC2012_val_00023713.JPEG +ILSVRC2012_val_00003141.JPEG +ILSVRC2012_val_00002281.JPEG +ILSVRC2012_val_00040788.JPEG +ILSVRC2012_val_00011124.JPEG +ILSVRC2012_val_00003178.JPEG +ILSVRC2012_val_00024549.JPEG +ILSVRC2012_val_00045457.JPEG +ILSVRC2012_val_00007902.JPEG +ILSVRC2012_val_00036920.JPEG +ILSVRC2012_val_00047698.JPEG 
+ILSVRC2012_val_00004510.JPEG +ILSVRC2012_val_00021893.JPEG +ILSVRC2012_val_00032693.JPEG +ILSVRC2012_val_00044184.JPEG +ILSVRC2012_val_00013652.JPEG +ILSVRC2012_val_00045178.JPEG +ILSVRC2012_val_00016785.JPEG +ILSVRC2012_val_00022571.JPEG +ILSVRC2012_val_00019441.JPEG +ILSVRC2012_val_00002062.JPEG +ILSVRC2012_val_00007505.JPEG +ILSVRC2012_val_00045676.JPEG +ILSVRC2012_val_00015093.JPEG +ILSVRC2012_val_00013526.JPEG +ILSVRC2012_val_00024021.JPEG +ILSVRC2012_val_00011819.JPEG +ILSVRC2012_val_00010406.JPEG +ILSVRC2012_val_00017838.JPEG +ILSVRC2012_val_00031778.JPEG +ILSVRC2012_val_00032878.JPEG +ILSVRC2012_val_00005740.JPEG +ILSVRC2012_val_00039357.JPEG +ILSVRC2012_val_00001991.JPEG +ILSVRC2012_val_00015280.JPEG +ILSVRC2012_val_00025429.JPEG +ILSVRC2012_val_00024645.JPEG +ILSVRC2012_val_00016285.JPEG +ILSVRC2012_val_00036271.JPEG +ILSVRC2012_val_00000982.JPEG +ILSVRC2012_val_00034315.JPEG +ILSVRC2012_val_00045608.JPEG +ILSVRC2012_val_00047365.JPEG +ILSVRC2012_val_00007195.JPEG +ILSVRC2012_val_00017014.JPEG +ILSVRC2012_val_00030401.JPEG +ILSVRC2012_val_00017113.JPEG +ILSVRC2012_val_00047896.JPEG +ILSVRC2012_val_00004268.JPEG +ILSVRC2012_val_00004568.JPEG +ILSVRC2012_val_00045663.JPEG +ILSVRC2012_val_00010306.JPEG +ILSVRC2012_val_00022990.JPEG +ILSVRC2012_val_00006455.JPEG +ILSVRC2012_val_00020924.JPEG +ILSVRC2012_val_00035821.JPEG +ILSVRC2012_val_00022932.JPEG +ILSVRC2012_val_00022864.JPEG +ILSVRC2012_val_00005083.JPEG +ILSVRC2012_val_00007911.JPEG +ILSVRC2012_val_00036667.JPEG +ILSVRC2012_val_00024652.JPEG +ILSVRC2012_val_00042543.JPEG +ILSVRC2012_val_00009398.JPEG +ILSVRC2012_val_00035097.JPEG +ILSVRC2012_val_00049905.JPEG +ILSVRC2012_val_00026734.JPEG +ILSVRC2012_val_00030929.JPEG +ILSVRC2012_val_00030920.JPEG +ILSVRC2012_val_00005056.JPEG +ILSVRC2012_val_00029837.JPEG +ILSVRC2012_val_00039182.JPEG +ILSVRC2012_val_00017823.JPEG +ILSVRC2012_val_00008850.JPEG +ILSVRC2012_val_00006533.JPEG +ILSVRC2012_val_00011289.JPEG +ILSVRC2012_val_00018666.JPEG 
+ILSVRC2012_val_00026224.JPEG +ILSVRC2012_val_00033906.JPEG +ILSVRC2012_val_00028084.JPEG +ILSVRC2012_val_00004072.JPEG +ILSVRC2012_val_00035301.JPEG +ILSVRC2012_val_00046365.JPEG +ILSVRC2012_val_00034624.JPEG +ILSVRC2012_val_00003734.JPEG +ILSVRC2012_val_00004028.JPEG +ILSVRC2012_val_00030128.JPEG +ILSVRC2012_val_00036993.JPEG +ILSVRC2012_val_00007133.JPEG +ILSVRC2012_val_00007468.JPEG +ILSVRC2012_val_00008438.JPEG +ILSVRC2012_val_00021028.JPEG +ILSVRC2012_val_00023403.JPEG +ILSVRC2012_val_00034393.JPEG +ILSVRC2012_val_00019495.JPEG +ILSVRC2012_val_00031441.JPEG +ILSVRC2012_val_00006935.JPEG +ILSVRC2012_val_00029141.JPEG +ILSVRC2012_val_00016628.JPEG +ILSVRC2012_val_00008375.JPEG +ILSVRC2012_val_00036822.JPEG +ILSVRC2012_val_00023965.JPEG +ILSVRC2012_val_00006037.JPEG +ILSVRC2012_val_00014560.JPEG +ILSVRC2012_val_00009661.JPEG +ILSVRC2012_val_00044711.JPEG +ILSVRC2012_val_00015249.JPEG +ILSVRC2012_val_00042583.JPEG +ILSVRC2012_val_00032273.JPEG +ILSVRC2012_val_00028995.JPEG +ILSVRC2012_val_00043643.JPEG +ILSVRC2012_val_00008411.JPEG +ILSVRC2012_val_00036142.JPEG +ILSVRC2012_val_00023654.JPEG +ILSVRC2012_val_00001926.JPEG +ILSVRC2012_val_00014874.JPEG +ILSVRC2012_val_00016511.JPEG +ILSVRC2012_val_00042984.JPEG +ILSVRC2012_val_00019621.JPEG +ILSVRC2012_val_00008624.JPEG +ILSVRC2012_val_00038385.JPEG +ILSVRC2012_val_00004882.JPEG +ILSVRC2012_val_00001702.JPEG +ILSVRC2012_val_00045999.JPEG +ILSVRC2012_val_00025883.JPEG +ILSVRC2012_val_00016854.JPEG +ILSVRC2012_val_00033128.JPEG +ILSVRC2012_val_00006916.JPEG +ILSVRC2012_val_00031994.JPEG +ILSVRC2012_val_00041278.JPEG +ILSVRC2012_val_00049162.JPEG +ILSVRC2012_val_00030591.JPEG +ILSVRC2012_val_00000981.JPEG +ILSVRC2012_val_00009102.JPEG +ILSVRC2012_val_00003367.JPEG +ILSVRC2012_val_00012725.JPEG +ILSVRC2012_val_00046971.JPEG +ILSVRC2012_val_00018567.JPEG +ILSVRC2012_val_00047353.JPEG +ILSVRC2012_val_00037938.JPEG +ILSVRC2012_val_00016095.JPEG +ILSVRC2012_val_00010137.JPEG +ILSVRC2012_val_00013461.JPEG 
+ILSVRC2012_val_00042599.JPEG +ILSVRC2012_val_00022581.JPEG +ILSVRC2012_val_00026040.JPEG +ILSVRC2012_val_00007228.JPEG +ILSVRC2012_val_00034875.JPEG +ILSVRC2012_val_00016242.JPEG +ILSVRC2012_val_00048550.JPEG +ILSVRC2012_val_00035249.JPEG +ILSVRC2012_val_00045901.JPEG +ILSVRC2012_val_00008905.JPEG +ILSVRC2012_val_00036605.JPEG +ILSVRC2012_val_00030845.JPEG +ILSVRC2012_val_00025531.JPEG +ILSVRC2012_val_00017312.JPEG +ILSVRC2012_val_00041078.JPEG +ILSVRC2012_val_00022349.JPEG +ILSVRC2012_val_00019956.JPEG +ILSVRC2012_val_00033695.JPEG +ILSVRC2012_val_00047770.JPEG +ILSVRC2012_val_00023390.JPEG +ILSVRC2012_val_00011701.JPEG +ILSVRC2012_val_00027808.JPEG +ILSVRC2012_val_00014342.JPEG +ILSVRC2012_val_00002121.JPEG +ILSVRC2012_val_00023938.JPEG +ILSVRC2012_val_00023462.JPEG +ILSVRC2012_val_00038666.JPEG +ILSVRC2012_val_00035793.JPEG +ILSVRC2012_val_00028806.JPEG +ILSVRC2012_val_00002565.JPEG +ILSVRC2012_val_00031826.JPEG +ILSVRC2012_val_00043319.JPEG +ILSVRC2012_val_00017447.JPEG +ILSVRC2012_val_00039008.JPEG +ILSVRC2012_val_00019934.JPEG +ILSVRC2012_val_00006570.JPEG +ILSVRC2012_val_00040875.JPEG +ILSVRC2012_val_00032361.JPEG +ILSVRC2012_val_00007257.JPEG +ILSVRC2012_val_00008461.JPEG +ILSVRC2012_val_00049667.JPEG +ILSVRC2012_val_00038643.JPEG +ILSVRC2012_val_00041027.JPEG +ILSVRC2012_val_00022999.JPEG +ILSVRC2012_val_00025162.JPEG +ILSVRC2012_val_00043744.JPEG +ILSVRC2012_val_00028528.JPEG +ILSVRC2012_val_00032207.JPEG +ILSVRC2012_val_00027183.JPEG +ILSVRC2012_val_00011281.JPEG +ILSVRC2012_val_00046308.JPEG +ILSVRC2012_val_00009409.JPEG +ILSVRC2012_val_00040768.JPEG +ILSVRC2012_val_00038415.JPEG +ILSVRC2012_val_00045521.JPEG +ILSVRC2012_val_00019458.JPEG +ILSVRC2012_val_00011344.JPEG +ILSVRC2012_val_00002802.JPEG +ILSVRC2012_val_00009310.JPEG +ILSVRC2012_val_00033987.JPEG +ILSVRC2012_val_00018792.JPEG +ILSVRC2012_val_00044030.JPEG +ILSVRC2012_val_00041470.JPEG +ILSVRC2012_val_00008834.JPEG +ILSVRC2012_val_00032118.JPEG +ILSVRC2012_val_00022691.JPEG 
+ILSVRC2012_val_00029962.JPEG +ILSVRC2012_val_00042474.JPEG +ILSVRC2012_val_00024044.JPEG +ILSVRC2012_val_00038352.JPEG +ILSVRC2012_val_00048240.JPEG +ILSVRC2012_val_00005569.JPEG +ILSVRC2012_val_00027875.JPEG +ILSVRC2012_val_00049116.JPEG +ILSVRC2012_val_00044154.JPEG +ILSVRC2012_val_00035678.JPEG +ILSVRC2012_val_00042235.JPEG +ILSVRC2012_val_00011743.JPEG +ILSVRC2012_val_00035685.JPEG +ILSVRC2012_val_00041547.JPEG +ILSVRC2012_val_00043486.JPEG +ILSVRC2012_val_00027046.JPEG +ILSVRC2012_val_00016405.JPEG +ILSVRC2012_val_00010204.JPEG +ILSVRC2012_val_00004645.JPEG +ILSVRC2012_val_00037697.JPEG +ILSVRC2012_val_00002711.JPEG +ILSVRC2012_val_00043536.JPEG +ILSVRC2012_val_00004925.JPEG +ILSVRC2012_val_00048416.JPEG +ILSVRC2012_val_00009141.JPEG +ILSVRC2012_val_00029283.JPEG +ILSVRC2012_val_00017569.JPEG +ILSVRC2012_val_00042662.JPEG +ILSVRC2012_val_00049551.JPEG +ILSVRC2012_val_00029382.JPEG +ILSVRC2012_val_00030118.JPEG +ILSVRC2012_val_00048754.JPEG +ILSVRC2012_val_00001015.JPEG +ILSVRC2012_val_00038250.JPEG +ILSVRC2012_val_00018040.JPEG +ILSVRC2012_val_00043767.JPEG +ILSVRC2012_val_00034627.JPEG +ILSVRC2012_val_00040254.JPEG +ILSVRC2012_val_00027864.JPEG +ILSVRC2012_val_00024926.JPEG +ILSVRC2012_val_00007382.JPEG +ILSVRC2012_val_00001118.JPEG +ILSVRC2012_val_00014365.JPEG +ILSVRC2012_val_00026060.JPEG +ILSVRC2012_val_00045094.JPEG +ILSVRC2012_val_00029874.JPEG +ILSVRC2012_val_00049716.JPEG +ILSVRC2012_val_00019158.JPEG +ILSVRC2012_val_00031197.JPEG +ILSVRC2012_val_00037899.JPEG +ILSVRC2012_val_00014669.JPEG +ILSVRC2012_val_00036906.JPEG +ILSVRC2012_val_00043079.JPEG +ILSVRC2012_val_00017658.JPEG +ILSVRC2012_val_00008167.JPEG +ILSVRC2012_val_00045959.JPEG +ILSVRC2012_val_00024534.JPEG +ILSVRC2012_val_00015118.JPEG +ILSVRC2012_val_00022355.JPEG +ILSVRC2012_val_00048281.JPEG +ILSVRC2012_val_00048720.JPEG +ILSVRC2012_val_00009113.JPEG +ILSVRC2012_val_00027806.JPEG +ILSVRC2012_val_00009491.JPEG +ILSVRC2012_val_00046486.JPEG +ILSVRC2012_val_00041279.JPEG 
+ILSVRC2012_val_00045544.JPEG +ILSVRC2012_val_00001635.JPEG +ILSVRC2012_val_00001165.JPEG +ILSVRC2012_val_00033772.JPEG +ILSVRC2012_val_00023686.JPEG +ILSVRC2012_val_00024802.JPEG +ILSVRC2012_val_00004862.JPEG +ILSVRC2012_val_00028501.JPEG +ILSVRC2012_val_00019547.JPEG +ILSVRC2012_val_00002408.JPEG +ILSVRC2012_val_00041442.JPEG +ILSVRC2012_val_00010787.JPEG +ILSVRC2012_val_00040870.JPEG +ILSVRC2012_val_00021500.JPEG +ILSVRC2012_val_00048230.JPEG +ILSVRC2012_val_00005058.JPEG +ILSVRC2012_val_00023279.JPEG +ILSVRC2012_val_00026048.JPEG +ILSVRC2012_val_00024234.JPEG +ILSVRC2012_val_00017941.JPEG +ILSVRC2012_val_00002357.JPEG +ILSVRC2012_val_00032158.JPEG +ILSVRC2012_val_00024738.JPEG +ILSVRC2012_val_00002258.JPEG +ILSVRC2012_val_00034780.JPEG +ILSVRC2012_val_00002527.JPEG +ILSVRC2012_val_00046290.JPEG +ILSVRC2012_val_00006150.JPEG +ILSVRC2012_val_00045024.JPEG +ILSVRC2012_val_00005738.JPEG +ILSVRC2012_val_00001777.JPEG +ILSVRC2012_val_00024719.JPEG +ILSVRC2012_val_00027895.JPEG +ILSVRC2012_val_00012609.JPEG +ILSVRC2012_val_00014734.JPEG +ILSVRC2012_val_00033804.JPEG +ILSVRC2012_val_00032590.JPEG +ILSVRC2012_val_00020166.JPEG +ILSVRC2012_val_00031852.JPEG +ILSVRC2012_val_00018176.JPEG +ILSVRC2012_val_00017918.JPEG +ILSVRC2012_val_00046886.JPEG +ILSVRC2012_val_00030076.JPEG +ILSVRC2012_val_00039684.JPEG +ILSVRC2012_val_00008505.JPEG +ILSVRC2012_val_00032687.JPEG +ILSVRC2012_val_00032998.JPEG +ILSVRC2012_val_00028832.JPEG +ILSVRC2012_val_00037663.JPEG +ILSVRC2012_val_00003674.JPEG +ILSVRC2012_val_00024423.JPEG +ILSVRC2012_val_00029852.JPEG +ILSVRC2012_val_00024589.JPEG +ILSVRC2012_val_00021650.JPEG +ILSVRC2012_val_00040408.JPEG +ILSVRC2012_val_00004054.JPEG +ILSVRC2012_val_00000259.JPEG +ILSVRC2012_val_00026889.JPEG +ILSVRC2012_val_00002248.JPEG +ILSVRC2012_val_00023595.JPEG +ILSVRC2012_val_00030579.JPEG +ILSVRC2012_val_00017840.JPEG +ILSVRC2012_val_00006757.JPEG +ILSVRC2012_val_00048202.JPEG +ILSVRC2012_val_00035116.JPEG +ILSVRC2012_val_00047683.JPEG 
+ILSVRC2012_val_00007888.JPEG +ILSVRC2012_val_00038782.JPEG +ILSVRC2012_val_00048178.JPEG +ILSVRC2012_val_00038673.JPEG +ILSVRC2012_val_00047101.JPEG +ILSVRC2012_val_00034437.JPEG +ILSVRC2012_val_00047473.JPEG +ILSVRC2012_val_00033983.JPEG +ILSVRC2012_val_00034811.JPEG +ILSVRC2012_val_00002139.JPEG +ILSVRC2012_val_00017594.JPEG +ILSVRC2012_val_00025225.JPEG +ILSVRC2012_val_00048015.JPEG +ILSVRC2012_val_00036222.JPEG +ILSVRC2012_val_00035355.JPEG +ILSVRC2012_val_00023472.JPEG +ILSVRC2012_val_00041469.JPEG +ILSVRC2012_val_00013762.JPEG +ILSVRC2012_val_00004439.JPEG +ILSVRC2012_val_00047640.JPEG +ILSVRC2012_val_00046352.JPEG +ILSVRC2012_val_00011545.JPEG +ILSVRC2012_val_00044988.JPEG +ILSVRC2012_val_00016796.JPEG +ILSVRC2012_val_00044494.JPEG +ILSVRC2012_val_00048492.JPEG +ILSVRC2012_val_00040976.JPEG +ILSVRC2012_val_00048458.JPEG +ILSVRC2012_val_00033916.JPEG +ILSVRC2012_val_00047546.JPEG +ILSVRC2012_val_00005267.JPEG +ILSVRC2012_val_00032763.JPEG +ILSVRC2012_val_00012576.JPEG +ILSVRC2012_val_00036552.JPEG +ILSVRC2012_val_00034318.JPEG +ILSVRC2012_val_00009641.JPEG +ILSVRC2012_val_00010201.JPEG +ILSVRC2012_val_00026550.JPEG +ILSVRC2012_val_00042838.JPEG +ILSVRC2012_val_00026053.JPEG +ILSVRC2012_val_00015873.JPEG +ILSVRC2012_val_00045456.JPEG +ILSVRC2012_val_00043757.JPEG +ILSVRC2012_val_00042192.JPEG +ILSVRC2012_val_00015136.JPEG +ILSVRC2012_val_00048198.JPEG +ILSVRC2012_val_00019269.JPEG +ILSVRC2012_val_00004142.JPEG +ILSVRC2012_val_00007870.JPEG +ILSVRC2012_val_00013858.JPEG +ILSVRC2012_val_00017120.JPEG +ILSVRC2012_val_00003466.JPEG +ILSVRC2012_val_00018116.JPEG +ILSVRC2012_val_00022212.JPEG +ILSVRC2012_val_00002142.JPEG +ILSVRC2012_val_00029985.JPEG +ILSVRC2012_val_00046323.JPEG +ILSVRC2012_val_00033142.JPEG +ILSVRC2012_val_00015862.JPEG +ILSVRC2012_val_00048037.JPEG +ILSVRC2012_val_00037707.JPEG +ILSVRC2012_val_00001930.JPEG +ILSVRC2012_val_00012815.JPEG +ILSVRC2012_val_00018449.JPEG +ILSVRC2012_val_00040499.JPEG +ILSVRC2012_val_00039448.JPEG 
+ILSVRC2012_val_00042696.JPEG +ILSVRC2012_val_00030924.JPEG +ILSVRC2012_val_00032473.JPEG +ILSVRC2012_val_00003563.JPEG +ILSVRC2012_val_00024103.JPEG +ILSVRC2012_val_00025357.JPEG +ILSVRC2012_val_00001883.JPEG +ILSVRC2012_val_00024706.JPEG +ILSVRC2012_val_00033898.JPEG +ILSVRC2012_val_00000426.JPEG +ILSVRC2012_val_00035033.JPEG +ILSVRC2012_val_00031840.JPEG +ILSVRC2012_val_00037055.JPEG +ILSVRC2012_val_00046314.JPEG +ILSVRC2012_val_00036660.JPEG +ILSVRC2012_val_00007128.JPEG +ILSVRC2012_val_00041322.JPEG +ILSVRC2012_val_00036434.JPEG +ILSVRC2012_val_00046436.JPEG +ILSVRC2012_val_00021078.JPEG +ILSVRC2012_val_00026725.JPEG +ILSVRC2012_val_00034260.JPEG +ILSVRC2012_val_00022872.JPEG +ILSVRC2012_val_00022455.JPEG +ILSVRC2012_val_00038219.JPEG +ILSVRC2012_val_00047820.JPEG +ILSVRC2012_val_00025884.JPEG +ILSVRC2012_val_00047904.JPEG +ILSVRC2012_val_00029699.JPEG +ILSVRC2012_val_00003489.JPEG +ILSVRC2012_val_00018218.JPEG +ILSVRC2012_val_00021812.JPEG +ILSVRC2012_val_00047998.JPEG +ILSVRC2012_val_00046805.JPEG +ILSVRC2012_val_00045291.JPEG +ILSVRC2012_val_00025506.JPEG +ILSVRC2012_val_00026291.JPEG +ILSVRC2012_val_00004977.JPEG +ILSVRC2012_val_00027040.JPEG +ILSVRC2012_val_00011645.JPEG +ILSVRC2012_val_00033854.JPEG +ILSVRC2012_val_00004682.JPEG +ILSVRC2012_val_00048943.JPEG +ILSVRC2012_val_00015575.JPEG +ILSVRC2012_val_00032365.JPEG +ILSVRC2012_val_00031558.JPEG +ILSVRC2012_val_00016915.JPEG +ILSVRC2012_val_00036293.JPEG +ILSVRC2012_val_00019687.JPEG +ILSVRC2012_val_00013628.JPEG +ILSVRC2012_val_00004082.JPEG +ILSVRC2012_val_00020191.JPEG +ILSVRC2012_val_00017066.JPEG +ILSVRC2012_val_00043481.JPEG +ILSVRC2012_val_00011674.JPEG +ILSVRC2012_val_00048134.JPEG +ILSVRC2012_val_00009720.JPEG +ILSVRC2012_val_00019022.JPEG +ILSVRC2012_val_00020790.JPEG +ILSVRC2012_val_00015883.JPEG +ILSVRC2012_val_00027410.JPEG +ILSVRC2012_val_00013403.JPEG +ILSVRC2012_val_00031072.JPEG +ILSVRC2012_val_00016055.JPEG +ILSVRC2012_val_00045082.JPEG +ILSVRC2012_val_00014714.JPEG 
+ILSVRC2012_val_00034577.JPEG +ILSVRC2012_val_00016805.JPEG +ILSVRC2012_val_00020662.JPEG +ILSVRC2012_val_00022123.JPEG +ILSVRC2012_val_00035597.JPEG +ILSVRC2012_val_00038022.JPEG +ILSVRC2012_val_00020019.JPEG +ILSVRC2012_val_00027905.JPEG +ILSVRC2012_val_00002684.JPEG +ILSVRC2012_val_00001230.JPEG +ILSVRC2012_val_00029952.JPEG +ILSVRC2012_val_00038936.JPEG +ILSVRC2012_val_00027844.JPEG +ILSVRC2012_val_00038623.JPEG +ILSVRC2012_val_00048675.JPEG +ILSVRC2012_val_00007826.JPEG +ILSVRC2012_val_00021227.JPEG +ILSVRC2012_val_00012079.JPEG +ILSVRC2012_val_00020828.JPEG +ILSVRC2012_val_00037856.JPEG +ILSVRC2012_val_00009855.JPEG +ILSVRC2012_val_00019378.JPEG +ILSVRC2012_val_00023807.JPEG +ILSVRC2012_val_00002115.JPEG +ILSVRC2012_val_00046465.JPEG +ILSVRC2012_val_00031416.JPEG +ILSVRC2012_val_00046686.JPEG +ILSVRC2012_val_00028407.JPEG +ILSVRC2012_val_00024487.JPEG +ILSVRC2012_val_00014282.JPEG +ILSVRC2012_val_00025447.JPEG +ILSVRC2012_val_00034389.JPEG +ILSVRC2012_val_00017690.JPEG +ILSVRC2012_val_00020458.JPEG +ILSVRC2012_val_00032455.JPEG +ILSVRC2012_val_00032439.JPEG +ILSVRC2012_val_00009193.JPEG +ILSVRC2012_val_00048547.JPEG +ILSVRC2012_val_00003328.JPEG +ILSVRC2012_val_00040446.JPEG +ILSVRC2012_val_00002625.JPEG +ILSVRC2012_val_00024297.JPEG +ILSVRC2012_val_00048174.JPEG +ILSVRC2012_val_00035692.JPEG +ILSVRC2012_val_00018270.JPEG +ILSVRC2012_val_00044151.JPEG +ILSVRC2012_val_00039658.JPEG +ILSVRC2012_val_00026869.JPEG +ILSVRC2012_val_00014627.JPEG +ILSVRC2012_val_00016344.JPEG +ILSVRC2012_val_00035458.JPEG +ILSVRC2012_val_00021425.JPEG +ILSVRC2012_val_00046054.JPEG +ILSVRC2012_val_00008700.JPEG +ILSVRC2012_val_00020024.JPEG +ILSVRC2012_val_00005151.JPEG +ILSVRC2012_val_00030437.JPEG +ILSVRC2012_val_00003167.JPEG +ILSVRC2012_val_00024766.JPEG +ILSVRC2012_val_00026631.JPEG +ILSVRC2012_val_00038692.JPEG +ILSVRC2012_val_00003420.JPEG +ILSVRC2012_val_00018438.JPEG +ILSVRC2012_val_00044146.JPEG +ILSVRC2012_val_00046534.JPEG +ILSVRC2012_val_00013029.JPEG 
+ILSVRC2012_val_00041792.JPEG +ILSVRC2012_val_00038502.JPEG +ILSVRC2012_val_00020059.JPEG +ILSVRC2012_val_00013855.JPEG +ILSVRC2012_val_00021332.JPEG +ILSVRC2012_val_00015646.JPEG +ILSVRC2012_val_00013818.JPEG +ILSVRC2012_val_00033614.JPEG +ILSVRC2012_val_00031194.JPEG +ILSVRC2012_val_00029013.JPEG +ILSVRC2012_val_00004114.JPEG +ILSVRC2012_val_00039537.JPEG +ILSVRC2012_val_00016153.JPEG +ILSVRC2012_val_00010909.JPEG +ILSVRC2012_val_00030552.JPEG +ILSVRC2012_val_00036408.JPEG +ILSVRC2012_val_00027561.JPEG +ILSVRC2012_val_00038147.JPEG +ILSVRC2012_val_00002969.JPEG +ILSVRC2012_val_00045300.JPEG +ILSVRC2012_val_00045989.JPEG +ILSVRC2012_val_00011348.JPEG +ILSVRC2012_val_00042882.JPEG +ILSVRC2012_val_00047831.JPEG +ILSVRC2012_val_00041648.JPEG +ILSVRC2012_val_00024650.JPEG +ILSVRC2012_val_00016043.JPEG +ILSVRC2012_val_00028605.JPEG +ILSVRC2012_val_00013890.JPEG +ILSVRC2012_val_00011225.JPEG +ILSVRC2012_val_00005269.JPEG +ILSVRC2012_val_00016134.JPEG +ILSVRC2012_val_00029430.JPEG +ILSVRC2012_val_00012727.JPEG +ILSVRC2012_val_00013067.JPEG +ILSVRC2012_val_00032774.JPEG +ILSVRC2012_val_00028446.JPEG +ILSVRC2012_val_00045411.JPEG +ILSVRC2012_val_00049959.JPEG +ILSVRC2012_val_00024844.JPEG +ILSVRC2012_val_00029169.JPEG +ILSVRC2012_val_00001381.JPEG +ILSVRC2012_val_00038473.JPEG +ILSVRC2012_val_00034576.JPEG +ILSVRC2012_val_00009617.JPEG +ILSVRC2012_val_00023040.JPEG +ILSVRC2012_val_00005365.JPEG +ILSVRC2012_val_00019058.JPEG +ILSVRC2012_val_00019414.JPEG +ILSVRC2012_val_00008304.JPEG +ILSVRC2012_val_00022519.JPEG +ILSVRC2012_val_00018170.JPEG +ILSVRC2012_val_00013259.JPEG +ILSVRC2012_val_00044492.JPEG +ILSVRC2012_val_00021188.JPEG +ILSVRC2012_val_00041755.JPEG +ILSVRC2012_val_00001075.JPEG +ILSVRC2012_val_00033396.JPEG +ILSVRC2012_val_00002872.JPEG +ILSVRC2012_val_00009909.JPEG +ILSVRC2012_val_00017778.JPEG +ILSVRC2012_val_00014140.JPEG +ILSVRC2012_val_00010483.JPEG +ILSVRC2012_val_00005750.JPEG +ILSVRC2012_val_00031520.JPEG +ILSVRC2012_val_00044013.JPEG 
+ILSVRC2012_val_00026579.JPEG +ILSVRC2012_val_00033102.JPEG +ILSVRC2012_val_00047124.JPEG +ILSVRC2012_val_00005826.JPEG +ILSVRC2012_val_00039194.JPEG +ILSVRC2012_val_00049831.JPEG +ILSVRC2012_val_00043567.JPEG +ILSVRC2012_val_00005153.JPEG +ILSVRC2012_val_00040664.JPEG +ILSVRC2012_val_00016488.JPEG +ILSVRC2012_val_00030078.JPEG +ILSVRC2012_val_00017680.JPEG +ILSVRC2012_val_00043108.JPEG +ILSVRC2012_val_00043279.JPEG +ILSVRC2012_val_00026305.JPEG +ILSVRC2012_val_00009704.JPEG +ILSVRC2012_val_00011204.JPEG +ILSVRC2012_val_00019337.JPEG +ILSVRC2012_val_00015812.JPEG +ILSVRC2012_val_00033203.JPEG +ILSVRC2012_val_00012806.JPEG +ILSVRC2012_val_00007014.JPEG +ILSVRC2012_val_00008932.JPEG +ILSVRC2012_val_00048612.JPEG +ILSVRC2012_val_00018996.JPEG +ILSVRC2012_val_00010834.JPEG +ILSVRC2012_val_00014839.JPEG +ILSVRC2012_val_00039904.JPEG +ILSVRC2012_val_00048560.JPEG +ILSVRC2012_val_00017548.JPEG +ILSVRC2012_val_00030903.JPEG +ILSVRC2012_val_00001367.JPEG +ILSVRC2012_val_00041372.JPEG +ILSVRC2012_val_00011738.JPEG +ILSVRC2012_val_00031094.JPEG +ILSVRC2012_val_00005397.JPEG +ILSVRC2012_val_00034006.JPEG +ILSVRC2012_val_00017421.JPEG +ILSVRC2012_val_00024748.JPEG +ILSVRC2012_val_00019234.JPEG +ILSVRC2012_val_00007607.JPEG +ILSVRC2012_val_00003730.JPEG +ILSVRC2012_val_00034797.JPEG +ILSVRC2012_val_00042800.JPEG +ILSVRC2012_val_00009057.JPEG +ILSVRC2012_val_00030639.JPEG +ILSVRC2012_val_00021974.JPEG +ILSVRC2012_val_00044412.JPEG +ILSVRC2012_val_00023829.JPEG +ILSVRC2012_val_00030913.JPEG +ILSVRC2012_val_00000193.JPEG +ILSVRC2012_val_00021960.JPEG +ILSVRC2012_val_00039818.JPEG +ILSVRC2012_val_00012464.JPEG +ILSVRC2012_val_00025039.JPEG +ILSVRC2012_val_00025098.JPEG +ILSVRC2012_val_00007347.JPEG +ILSVRC2012_val_00037463.JPEG +ILSVRC2012_val_00000122.JPEG +ILSVRC2012_val_00022102.JPEG +ILSVRC2012_val_00010106.JPEG +ILSVRC2012_val_00040223.JPEG +ILSVRC2012_val_00025040.JPEG +ILSVRC2012_val_00017352.JPEG +ILSVRC2012_val_00023287.JPEG +ILSVRC2012_val_00045975.JPEG 
+ILSVRC2012_val_00046355.JPEG +ILSVRC2012_val_00019561.JPEG +ILSVRC2012_val_00008050.JPEG +ILSVRC2012_val_00016514.JPEG +ILSVRC2012_val_00027925.JPEG +ILSVRC2012_val_00041349.JPEG +ILSVRC2012_val_00037249.JPEG +ILSVRC2012_val_00048749.JPEG +ILSVRC2012_val_00033833.JPEG +ILSVRC2012_val_00031325.JPEG +ILSVRC2012_val_00023696.JPEG +ILSVRC2012_val_00004823.JPEG +ILSVRC2012_val_00007848.JPEG +ILSVRC2012_val_00046806.JPEG +ILSVRC2012_val_00028803.JPEG +ILSVRC2012_val_00006890.JPEG +ILSVRC2012_val_00002685.JPEG +ILSVRC2012_val_00034712.JPEG +ILSVRC2012_val_00010991.JPEG +ILSVRC2012_val_00026039.JPEG +ILSVRC2012_val_00029390.JPEG +ILSVRC2012_val_00021339.JPEG +ILSVRC2012_val_00020605.JPEG +ILSVRC2012_val_00020465.JPEG +ILSVRC2012_val_00032471.JPEG +ILSVRC2012_val_00011018.JPEG +ILSVRC2012_val_00015399.JPEG +ILSVRC2012_val_00000354.JPEG +ILSVRC2012_val_00007390.JPEG +ILSVRC2012_val_00005647.JPEG +ILSVRC2012_val_00025696.JPEG +ILSVRC2012_val_00039098.JPEG +ILSVRC2012_val_00048418.JPEG +ILSVRC2012_val_00042773.JPEG +ILSVRC2012_val_00026930.JPEG +ILSVRC2012_val_00012540.JPEG +ILSVRC2012_val_00030740.JPEG +ILSVRC2012_val_00043072.JPEG +ILSVRC2012_val_00040071.JPEG +ILSVRC2012_val_00015547.JPEG +ILSVRC2012_val_00045897.JPEG +ILSVRC2012_val_00022879.JPEG +ILSVRC2012_val_00025956.JPEG +ILSVRC2012_val_00022056.JPEG +ILSVRC2012_val_00011302.JPEG +ILSVRC2012_val_00004829.JPEG +ILSVRC2012_val_00011365.JPEG +ILSVRC2012_val_00032556.JPEG +ILSVRC2012_val_00021427.JPEG +ILSVRC2012_val_00032797.JPEG +ILSVRC2012_val_00048838.JPEG +ILSVRC2012_val_00019673.JPEG +ILSVRC2012_val_00037452.JPEG +ILSVRC2012_val_00036987.JPEG +ILSVRC2012_val_00017293.JPEG +ILSVRC2012_val_00014712.JPEG +ILSVRC2012_val_00042119.JPEG +ILSVRC2012_val_00005962.JPEG +ILSVRC2012_val_00026244.JPEG +ILSVRC2012_val_00046270.JPEG +ILSVRC2012_val_00007251.JPEG +ILSVRC2012_val_00024686.JPEG +ILSVRC2012_val_00020672.JPEG +ILSVRC2012_val_00043504.JPEG +ILSVRC2012_val_00044726.JPEG +ILSVRC2012_val_00030647.JPEG 
+ILSVRC2012_val_00038282.JPEG +ILSVRC2012_val_00010610.JPEG +ILSVRC2012_val_00018777.JPEG +ILSVRC2012_val_00004946.JPEG +ILSVRC2012_val_00044956.JPEG +ILSVRC2012_val_00015228.JPEG +ILSVRC2012_val_00024378.JPEG +ILSVRC2012_val_00024443.JPEG +ILSVRC2012_val_00025208.JPEG +ILSVRC2012_val_00048100.JPEG +ILSVRC2012_val_00046205.JPEG +ILSVRC2012_val_00041782.JPEG +ILSVRC2012_val_00033610.JPEG +ILSVRC2012_val_00008246.JPEG +ILSVRC2012_val_00003679.JPEG +ILSVRC2012_val_00043123.JPEG +ILSVRC2012_val_00000210.JPEG +ILSVRC2012_val_00025762.JPEG +ILSVRC2012_val_00027488.JPEG +ILSVRC2012_val_00048938.JPEG +ILSVRC2012_val_00025075.JPEG +ILSVRC2012_val_00041074.JPEG +ILSVRC2012_val_00022556.JPEG +ILSVRC2012_val_00018823.JPEG +ILSVRC2012_val_00045753.JPEG +ILSVRC2012_val_00013386.JPEG +ILSVRC2012_val_00026946.JPEG +ILSVRC2012_val_00048235.JPEG +ILSVRC2012_val_00025756.JPEG +ILSVRC2012_val_00041384.JPEG +ILSVRC2012_val_00013678.JPEG +ILSVRC2012_val_00047135.JPEG +ILSVRC2012_val_00026366.JPEG +ILSVRC2012_val_00016900.JPEG +ILSVRC2012_val_00011345.JPEG +ILSVRC2012_val_00009725.JPEG +ILSVRC2012_val_00013007.JPEG +ILSVRC2012_val_00004201.JPEG +ILSVRC2012_val_00009511.JPEG +ILSVRC2012_val_00036527.JPEG +ILSVRC2012_val_00047454.JPEG +ILSVRC2012_val_00015043.JPEG +ILSVRC2012_val_00009853.JPEG +ILSVRC2012_val_00017604.JPEG +ILSVRC2012_val_00032836.JPEG +ILSVRC2012_val_00033595.JPEG +ILSVRC2012_val_00048906.JPEG +ILSVRC2012_val_00011517.JPEG +ILSVRC2012_val_00048678.JPEG +ILSVRC2012_val_00045996.JPEG +ILSVRC2012_val_00040617.JPEG +ILSVRC2012_val_00003978.JPEG +ILSVRC2012_val_00020299.JPEG +ILSVRC2012_val_00024133.JPEG +ILSVRC2012_val_00028842.JPEG +ILSVRC2012_val_00011453.JPEG +ILSVRC2012_val_00021601.JPEG +ILSVRC2012_val_00024227.JPEG +ILSVRC2012_val_00045589.JPEG +ILSVRC2012_val_00014498.JPEG +ILSVRC2012_val_00004416.JPEG +ILSVRC2012_val_00032984.JPEG +ILSVRC2012_val_00026525.JPEG +ILSVRC2012_val_00042430.JPEG +ILSVRC2012_val_00016546.JPEG +ILSVRC2012_val_00043119.JPEG 
+ILSVRC2012_val_00017948.JPEG +ILSVRC2012_val_00048245.JPEG +ILSVRC2012_val_00009110.JPEG +ILSVRC2012_val_00025007.JPEG +ILSVRC2012_val_00015402.JPEG +ILSVRC2012_val_00029990.JPEG +ILSVRC2012_val_00004363.JPEG +ILSVRC2012_val_00002878.JPEG +ILSVRC2012_val_00029485.JPEG +ILSVRC2012_val_00031510.JPEG +ILSVRC2012_val_00014520.JPEG +ILSVRC2012_val_00041160.JPEG +ILSVRC2012_val_00047942.JPEG +ILSVRC2012_val_00026490.JPEG +ILSVRC2012_val_00005196.JPEG +ILSVRC2012_val_00042645.JPEG +ILSVRC2012_val_00016568.JPEG +ILSVRC2012_val_00011622.JPEG +ILSVRC2012_val_00009490.JPEG +ILSVRC2012_val_00047793.JPEG +ILSVRC2012_val_00031270.JPEG +ILSVRC2012_val_00038955.JPEG +ILSVRC2012_val_00027017.JPEG +ILSVRC2012_val_00039795.JPEG +ILSVRC2012_val_00031125.JPEG +ILSVRC2012_val_00046846.JPEG +ILSVRC2012_val_00033685.JPEG +ILSVRC2012_val_00044674.JPEG +ILSVRC2012_val_00037219.JPEG +ILSVRC2012_val_00041885.JPEG +ILSVRC2012_val_00027034.JPEG +ILSVRC2012_val_00034841.JPEG +ILSVRC2012_val_00045795.JPEG +ILSVRC2012_val_00003336.JPEG +ILSVRC2012_val_00048872.JPEG +ILSVRC2012_val_00018610.JPEG +ILSVRC2012_val_00047937.JPEG +ILSVRC2012_val_00016410.JPEG +ILSVRC2012_val_00006735.JPEG +ILSVRC2012_val_00047970.JPEG +ILSVRC2012_val_00007180.JPEG +ILSVRC2012_val_00043352.JPEG +ILSVRC2012_val_00045598.JPEG +ILSVRC2012_val_00044681.JPEG +ILSVRC2012_val_00016522.JPEG +ILSVRC2012_val_00018734.JPEG +ILSVRC2012_val_00017994.JPEG +ILSVRC2012_val_00011107.JPEG +ILSVRC2012_val_00032339.JPEG +ILSVRC2012_val_00005650.JPEG +ILSVRC2012_val_00046592.JPEG +ILSVRC2012_val_00016999.JPEG +ILSVRC2012_val_00007256.JPEG +ILSVRC2012_val_00031014.JPEG +ILSVRC2012_val_00043930.JPEG +ILSVRC2012_val_00042914.JPEG +ILSVRC2012_val_00015203.JPEG +ILSVRC2012_val_00007062.JPEG +ILSVRC2012_val_00048588.JPEG +ILSVRC2012_val_00041555.JPEG +ILSVRC2012_val_00024219.JPEG +ILSVRC2012_val_00017549.JPEG +ILSVRC2012_val_00012833.JPEG +ILSVRC2012_val_00013018.JPEG +ILSVRC2012_val_00013253.JPEG +ILSVRC2012_val_00023799.JPEG 
+ILSVRC2012_val_00041210.JPEG +ILSVRC2012_val_00047201.JPEG +ILSVRC2012_val_00023755.JPEG +ILSVRC2012_val_00006761.JPEG +ILSVRC2012_val_00005336.JPEG +ILSVRC2012_val_00002481.JPEG +ILSVRC2012_val_00036505.JPEG +ILSVRC2012_val_00042366.JPEG +ILSVRC2012_val_00015770.JPEG +ILSVRC2012_val_00005154.JPEG +ILSVRC2012_val_00032494.JPEG +ILSVRC2012_val_00013367.JPEG +ILSVRC2012_val_00011415.JPEG +ILSVRC2012_val_00030635.JPEG +ILSVRC2012_val_00010707.JPEG +ILSVRC2012_val_00023534.JPEG +ILSVRC2012_val_00027156.JPEG +ILSVRC2012_val_00038828.JPEG +ILSVRC2012_val_00021369.JPEG +ILSVRC2012_val_00009037.JPEG +ILSVRC2012_val_00025816.JPEG +ILSVRC2012_val_00000528.JPEG +ILSVRC2012_val_00015516.JPEG +ILSVRC2012_val_00013365.JPEG +ILSVRC2012_val_00001375.JPEG +ILSVRC2012_val_00013254.JPEG +ILSVRC2012_val_00024024.JPEG +ILSVRC2012_val_00049743.JPEG +ILSVRC2012_val_00024668.JPEG +ILSVRC2012_val_00019979.JPEG +ILSVRC2012_val_00049403.JPEG +ILSVRC2012_val_00033031.JPEG +ILSVRC2012_val_00026435.JPEG +ILSVRC2012_val_00005014.JPEG +ILSVRC2012_val_00044519.JPEG +ILSVRC2012_val_00041400.JPEG +ILSVRC2012_val_00023153.JPEG +ILSVRC2012_val_00023174.JPEG +ILSVRC2012_val_00036031.JPEG +ILSVRC2012_val_00003064.JPEG +ILSVRC2012_val_00009806.JPEG +ILSVRC2012_val_00001591.JPEG +ILSVRC2012_val_00037113.JPEG +ILSVRC2012_val_00010688.JPEG +ILSVRC2012_val_00017687.JPEG +ILSVRC2012_val_00028076.JPEG +ILSVRC2012_val_00007170.JPEG +ILSVRC2012_val_00000497.JPEG +ILSVRC2012_val_00025752.JPEG +ILSVRC2012_val_00039846.JPEG +ILSVRC2012_val_00047242.JPEG +ILSVRC2012_val_00036323.JPEG +ILSVRC2012_val_00042671.JPEG +ILSVRC2012_val_00028484.JPEG +ILSVRC2012_val_00021998.JPEG +ILSVRC2012_val_00027139.JPEG +ILSVRC2012_val_00039510.JPEG +ILSVRC2012_val_00009327.JPEG +ILSVRC2012_val_00022843.JPEG +ILSVRC2012_val_00031504.JPEG +ILSVRC2012_val_00017809.JPEG +ILSVRC2012_val_00030672.JPEG +ILSVRC2012_val_00000827.JPEG +ILSVRC2012_val_00011571.JPEG +ILSVRC2012_val_00049690.JPEG +ILSVRC2012_val_00045433.JPEG 
+ILSVRC2012_val_00033826.JPEG +ILSVRC2012_val_00001737.JPEG +ILSVRC2012_val_00027087.JPEG +ILSVRC2012_val_00032310.JPEG +ILSVRC2012_val_00047735.JPEG +ILSVRC2012_val_00036664.JPEG +ILSVRC2012_val_00039458.JPEG +ILSVRC2012_val_00026424.JPEG +ILSVRC2012_val_00012523.JPEG +ILSVRC2012_val_00015160.JPEG +ILSVRC2012_val_00010941.JPEG +ILSVRC2012_val_00039656.JPEG +ILSVRC2012_val_00031896.JPEG +ILSVRC2012_val_00017885.JPEG +ILSVRC2012_val_00038186.JPEG +ILSVRC2012_val_00021767.JPEG +ILSVRC2012_val_00036012.JPEG +ILSVRC2012_val_00048380.JPEG +ILSVRC2012_val_00013820.JPEG +ILSVRC2012_val_00026529.JPEG +ILSVRC2012_val_00006710.JPEG +ILSVRC2012_val_00020398.JPEG +ILSVRC2012_val_00047326.JPEG +ILSVRC2012_val_00012669.JPEG +ILSVRC2012_val_00026976.JPEG +ILSVRC2012_val_00004278.JPEG +ILSVRC2012_val_00016393.JPEG +ILSVRC2012_val_00047266.JPEG +ILSVRC2012_val_00032641.JPEG +ILSVRC2012_val_00006798.JPEG +ILSVRC2012_val_00024817.JPEG +ILSVRC2012_val_00029226.JPEG +ILSVRC2012_val_00034945.JPEG +ILSVRC2012_val_00035375.JPEG +ILSVRC2012_val_00004798.JPEG +ILSVRC2012_val_00013497.JPEG +ILSVRC2012_val_00025542.JPEG +ILSVRC2012_val_00020380.JPEG +ILSVRC2012_val_00002931.JPEG +ILSVRC2012_val_00018754.JPEG +ILSVRC2012_val_00042797.JPEG +ILSVRC2012_val_00030382.JPEG +ILSVRC2012_val_00023236.JPEG +ILSVRC2012_val_00037214.JPEG +ILSVRC2012_val_00000135.JPEG +ILSVRC2012_val_00008748.JPEG +ILSVRC2012_val_00011084.JPEG +ILSVRC2012_val_00037563.JPEG +ILSVRC2012_val_00011226.JPEG +ILSVRC2012_val_00023323.JPEG +ILSVRC2012_val_00001730.JPEG +ILSVRC2012_val_00006939.JPEG +ILSVRC2012_val_00016731.JPEG +ILSVRC2012_val_00033291.JPEG +ILSVRC2012_val_00005491.JPEG +ILSVRC2012_val_00014689.JPEG +ILSVRC2012_val_00005596.JPEG +ILSVRC2012_val_00014184.JPEG +ILSVRC2012_val_00031508.JPEG +ILSVRC2012_val_00002880.JPEG +ILSVRC2012_val_00002616.JPEG +ILSVRC2012_val_00013140.JPEG +ILSVRC2012_val_00029193.JPEG +ILSVRC2012_val_00001271.JPEG +ILSVRC2012_val_00029653.JPEG +ILSVRC2012_val_00036161.JPEG 
+ILSVRC2012_val_00018733.JPEG +ILSVRC2012_val_00023571.JPEG +ILSVRC2012_val_00037044.JPEG +ILSVRC2012_val_00047272.JPEG +ILSVRC2012_val_00045497.JPEG +ILSVRC2012_val_00021722.JPEG +ILSVRC2012_val_00013570.JPEG +ILSVRC2012_val_00031753.JPEG +ILSVRC2012_val_00034837.JPEG +ILSVRC2012_val_00001774.JPEG +ILSVRC2012_val_00048482.JPEG +ILSVRC2012_val_00003848.JPEG +ILSVRC2012_val_00009515.JPEG +ILSVRC2012_val_00028551.JPEG +ILSVRC2012_val_00013078.JPEG +ILSVRC2012_val_00043860.JPEG +ILSVRC2012_val_00025865.JPEG +ILSVRC2012_val_00027787.JPEG +ILSVRC2012_val_00027684.JPEG +ILSVRC2012_val_00034322.JPEG +ILSVRC2012_val_00026524.JPEG +ILSVRC2012_val_00047833.JPEG +ILSVRC2012_val_00012522.JPEG +ILSVRC2012_val_00033748.JPEG +ILSVRC2012_val_00037257.JPEG +ILSVRC2012_val_00036908.JPEG +ILSVRC2012_val_00015542.JPEG +ILSVRC2012_val_00025141.JPEG +ILSVRC2012_val_00030907.JPEG +ILSVRC2012_val_00048147.JPEG +ILSVRC2012_val_00043921.JPEG +ILSVRC2012_val_00010771.JPEG +ILSVRC2012_val_00019635.JPEG +ILSVRC2012_val_00005907.JPEG +ILSVRC2012_val_00024121.JPEG +ILSVRC2012_val_00002018.JPEG +ILSVRC2012_val_00014408.JPEG +ILSVRC2012_val_00046525.JPEG +ILSVRC2012_val_00032221.JPEG +ILSVRC2012_val_00036459.JPEG +ILSVRC2012_val_00003479.JPEG +ILSVRC2012_val_00005618.JPEG +ILSVRC2012_val_00002493.JPEG +ILSVRC2012_val_00015807.JPEG +ILSVRC2012_val_00011386.JPEG +ILSVRC2012_val_00020733.JPEG +ILSVRC2012_val_00002482.JPEG +ILSVRC2012_val_00034622.JPEG +ILSVRC2012_val_00012771.JPEG +ILSVRC2012_val_00043302.JPEG +ILSVRC2012_val_00022408.JPEG +ILSVRC2012_val_00041327.JPEG +ILSVRC2012_val_00019171.JPEG +ILSVRC2012_val_00043995.JPEG +ILSVRC2012_val_00041223.JPEG +ILSVRC2012_val_00008465.JPEG +ILSVRC2012_val_00038063.JPEG +ILSVRC2012_val_00039038.JPEG +ILSVRC2012_val_00033159.JPEG +ILSVRC2012_val_00011382.JPEG +ILSVRC2012_val_00044823.JPEG +ILSVRC2012_val_00007399.JPEG +ILSVRC2012_val_00047507.JPEG +ILSVRC2012_val_00031973.JPEG +ILSVRC2012_val_00032965.JPEG +ILSVRC2012_val_00004937.JPEG 
+ILSVRC2012_val_00024038.JPEG +ILSVRC2012_val_00008984.JPEG +ILSVRC2012_val_00005456.JPEG +ILSVRC2012_val_00014141.JPEG +ILSVRC2012_val_00049914.JPEG +ILSVRC2012_val_00033036.JPEG +ILSVRC2012_val_00035721.JPEG +ILSVRC2012_val_00009258.JPEG +ILSVRC2012_val_00009894.JPEG +ILSVRC2012_val_00038122.JPEG +ILSVRC2012_val_00023047.JPEG +ILSVRC2012_val_00005399.JPEG +ILSVRC2012_val_00029479.JPEG +ILSVRC2012_val_00008114.JPEG +ILSVRC2012_val_00037288.JPEG +ILSVRC2012_val_00014414.JPEG +ILSVRC2012_val_00049021.JPEG +ILSVRC2012_val_00032807.JPEG +ILSVRC2012_val_00011872.JPEG +ILSVRC2012_val_00020290.JPEG +ILSVRC2012_val_00025221.JPEG +ILSVRC2012_val_00008645.JPEG +ILSVRC2012_val_00023939.JPEG +ILSVRC2012_val_00030926.JPEG +ILSVRC2012_val_00039414.JPEG +ILSVRC2012_val_00036206.JPEG +ILSVRC2012_val_00010715.JPEG +ILSVRC2012_val_00044384.JPEG +ILSVRC2012_val_00024903.JPEG +ILSVRC2012_val_00010352.JPEG +ILSVRC2012_val_00001360.JPEG +ILSVRC2012_val_00047931.JPEG +ILSVRC2012_val_00032664.JPEG +ILSVRC2012_val_00029760.JPEG +ILSVRC2012_val_00041369.JPEG +ILSVRC2012_val_00044874.JPEG +ILSVRC2012_val_00028223.JPEG +ILSVRC2012_val_00047818.JPEG +ILSVRC2012_val_00016276.JPEG +ILSVRC2012_val_00017543.JPEG +ILSVRC2012_val_00006048.JPEG +ILSVRC2012_val_00037206.JPEG +ILSVRC2012_val_00010147.JPEG +ILSVRC2012_val_00014000.JPEG +ILSVRC2012_val_00010511.JPEG +ILSVRC2012_val_00036604.JPEG +ILSVRC2012_val_00025713.JPEG +ILSVRC2012_val_00037759.JPEG +ILSVRC2012_val_00038048.JPEG +ILSVRC2012_val_00043154.JPEG +ILSVRC2012_val_00035140.JPEG +ILSVRC2012_val_00030110.JPEG +ILSVRC2012_val_00037239.JPEG +ILSVRC2012_val_00046044.JPEG +ILSVRC2012_val_00038143.JPEG +ILSVRC2012_val_00014965.JPEG +ILSVRC2012_val_00016544.JPEG +ILSVRC2012_val_00019020.JPEG +ILSVRC2012_val_00035467.JPEG +ILSVRC2012_val_00027289.JPEG +ILSVRC2012_val_00023677.JPEG +ILSVRC2012_val_00048537.JPEG +ILSVRC2012_val_00041241.JPEG +ILSVRC2012_val_00044502.JPEG +ILSVRC2012_val_00003035.JPEG +ILSVRC2012_val_00047662.JPEG 
+ILSVRC2012_val_00001398.JPEG +ILSVRC2012_val_00039204.JPEG +ILSVRC2012_val_00022559.JPEG +ILSVRC2012_val_00008368.JPEG +ILSVRC2012_val_00033096.JPEG +ILSVRC2012_val_00047351.JPEG +ILSVRC2012_val_00039635.JPEG +ILSVRC2012_val_00039703.JPEG +ILSVRC2012_val_00029755.JPEG +ILSVRC2012_val_00022114.JPEG +ILSVRC2012_val_00005679.JPEG +ILSVRC2012_val_00047827.JPEG +ILSVRC2012_val_00032430.JPEG +ILSVRC2012_val_00011076.JPEG +ILSVRC2012_val_00047902.JPEG +ILSVRC2012_val_00003982.JPEG +ILSVRC2012_val_00004243.JPEG +ILSVRC2012_val_00043230.JPEG +ILSVRC2012_val_00020292.JPEG +ILSVRC2012_val_00044219.JPEG +ILSVRC2012_val_00019123.JPEG +ILSVRC2012_val_00032522.JPEG +ILSVRC2012_val_00017295.JPEG +ILSVRC2012_val_00035834.JPEG +ILSVRC2012_val_00037647.JPEG +ILSVRC2012_val_00026070.JPEG +ILSVRC2012_val_00041269.JPEG +ILSVRC2012_val_00031611.JPEG +ILSVRC2012_val_00029664.JPEG +ILSVRC2012_val_00013693.JPEG +ILSVRC2012_val_00001840.JPEG +ILSVRC2012_val_00040194.JPEG +ILSVRC2012_val_00040940.JPEG +ILSVRC2012_val_00049385.JPEG +ILSVRC2012_val_00001100.JPEG +ILSVRC2012_val_00003868.JPEG +ILSVRC2012_val_00001894.JPEG +ILSVRC2012_val_00038766.JPEG +ILSVRC2012_val_00017576.JPEG +ILSVRC2012_val_00017462.JPEG +ILSVRC2012_val_00017485.JPEG +ILSVRC2012_val_00045020.JPEG +ILSVRC2012_val_00027846.JPEG +ILSVRC2012_val_00026762.JPEG +ILSVRC2012_val_00022174.JPEG +ILSVRC2012_val_00038702.JPEG +ILSVRC2012_val_00032721.JPEG +ILSVRC2012_val_00038270.JPEG +ILSVRC2012_val_00049299.JPEG +ILSVRC2012_val_00033499.JPEG +ILSVRC2012_val_00046257.JPEG +ILSVRC2012_val_00018381.JPEG +ILSVRC2012_val_00009651.JPEG +ILSVRC2012_val_00032305.JPEG +ILSVRC2012_val_00003718.JPEG +ILSVRC2012_val_00006326.JPEG +ILSVRC2012_val_00001916.JPEG +ILSVRC2012_val_00039118.JPEG +ILSVRC2012_val_00048732.JPEG +ILSVRC2012_val_00011364.JPEG +ILSVRC2012_val_00041867.JPEG +ILSVRC2012_val_00031951.JPEG +ILSVRC2012_val_00010436.JPEG +ILSVRC2012_val_00044302.JPEG +ILSVRC2012_val_00036652.JPEG +ILSVRC2012_val_00017962.JPEG 
+ILSVRC2012_val_00015922.JPEG +ILSVRC2012_val_00028691.JPEG +ILSVRC2012_val_00023703.JPEG +ILSVRC2012_val_00027118.JPEG +ILSVRC2012_val_00031088.JPEG +ILSVRC2012_val_00048035.JPEG +ILSVRC2012_val_00008881.JPEG +ILSVRC2012_val_00049787.JPEG +ILSVRC2012_val_00000899.JPEG +ILSVRC2012_val_00003070.JPEG +ILSVRC2012_val_00038633.JPEG +ILSVRC2012_val_00035694.JPEG +ILSVRC2012_val_00029887.JPEG +ILSVRC2012_val_00013835.JPEG +ILSVRC2012_val_00031751.JPEG +ILSVRC2012_val_00016772.JPEG +ILSVRC2012_val_00028999.JPEG +ILSVRC2012_val_00018409.JPEG +ILSVRC2012_val_00026861.JPEG +ILSVRC2012_val_00034830.JPEG +ILSVRC2012_val_00034926.JPEG +ILSVRC2012_val_00043571.JPEG +ILSVRC2012_val_00003241.JPEG +ILSVRC2012_val_00028256.JPEG +ILSVRC2012_val_00030584.JPEG +ILSVRC2012_val_00005095.JPEG +ILSVRC2012_val_00018816.JPEG +ILSVRC2012_val_00005278.JPEG +ILSVRC2012_val_00001729.JPEG +ILSVRC2012_val_00016595.JPEG +ILSVRC2012_val_00029643.JPEG +ILSVRC2012_val_00039985.JPEG +ILSVRC2012_val_00043800.JPEG +ILSVRC2012_val_00003554.JPEG +ILSVRC2012_val_00012793.JPEG +ILSVRC2012_val_00009880.JPEG +ILSVRC2012_val_00048828.JPEG +ILSVRC2012_val_00029320.JPEG +ILSVRC2012_val_00044601.JPEG +ILSVRC2012_val_00036231.JPEG +ILSVRC2012_val_00001570.JPEG +ILSVRC2012_val_00008314.JPEG +ILSVRC2012_val_00003759.JPEG +ILSVRC2012_val_00045292.JPEG +ILSVRC2012_val_00001537.JPEG +ILSVRC2012_val_00016062.JPEG +ILSVRC2012_val_00010649.JPEG +ILSVRC2012_val_00005086.JPEG +ILSVRC2012_val_00020472.JPEG +ILSVRC2012_val_00016376.JPEG +ILSVRC2012_val_00006484.JPEG +ILSVRC2012_val_00006818.JPEG +ILSVRC2012_val_00005477.JPEG +ILSVRC2012_val_00000758.JPEG +ILSVRC2012_val_00034580.JPEG +ILSVRC2012_val_00040628.JPEG +ILSVRC2012_val_00025019.JPEG +ILSVRC2012_val_00046997.JPEG +ILSVRC2012_val_00039483.JPEG +ILSVRC2012_val_00012646.JPEG +ILSVRC2012_val_00022453.JPEG +ILSVRC2012_val_00002412.JPEG +ILSVRC2012_val_00006101.JPEG +ILSVRC2012_val_00022385.JPEG +ILSVRC2012_val_00030333.JPEG +ILSVRC2012_val_00036622.JPEG 
+ILSVRC2012_val_00011425.JPEG +ILSVRC2012_val_00029540.JPEG +ILSVRC2012_val_00003227.JPEG +ILSVRC2012_val_00038888.JPEG +ILSVRC2012_val_00034908.JPEG +ILSVRC2012_val_00042152.JPEG +ILSVRC2012_val_00024803.JPEG +ILSVRC2012_val_00026906.JPEG +ILSVRC2012_val_00041262.JPEG +ILSVRC2012_val_00015979.JPEG +ILSVRC2012_val_00028790.JPEG +ILSVRC2012_val_00040478.JPEG +ILSVRC2012_val_00036017.JPEG +ILSVRC2012_val_00030186.JPEG +ILSVRC2012_val_00035088.JPEG +ILSVRC2012_val_00049286.JPEG +ILSVRC2012_val_00006379.JPEG +ILSVRC2012_val_00042853.JPEG +ILSVRC2012_val_00004196.JPEG +ILSVRC2012_val_00037488.JPEG +ILSVRC2012_val_00001353.JPEG +ILSVRC2012_val_00036304.JPEG +ILSVRC2012_val_00045517.JPEG +ILSVRC2012_val_00003291.JPEG +ILSVRC2012_val_00029734.JPEG +ILSVRC2012_val_00009106.JPEG +ILSVRC2012_val_00020087.JPEG +ILSVRC2012_val_00003786.JPEG +ILSVRC2012_val_00044277.JPEG +ILSVRC2012_val_00043575.JPEG +ILSVRC2012_val_00045955.JPEG +ILSVRC2012_val_00042236.JPEG +ILSVRC2012_val_00007953.JPEG +ILSVRC2012_val_00006787.JPEG +ILSVRC2012_val_00021956.JPEG +ILSVRC2012_val_00000322.JPEG +ILSVRC2012_val_00031844.JPEG +ILSVRC2012_val_00036936.JPEG +ILSVRC2012_val_00034424.JPEG +ILSVRC2012_val_00006994.JPEG +ILSVRC2012_val_00040676.JPEG +ILSVRC2012_val_00011749.JPEG +ILSVRC2012_val_00003317.JPEG +ILSVRC2012_val_00009013.JPEG +ILSVRC2012_val_00023386.JPEG +ILSVRC2012_val_00025628.JPEG +ILSVRC2012_val_00015736.JPEG +ILSVRC2012_val_00030808.JPEG +ILSVRC2012_val_00031380.JPEG +ILSVRC2012_val_00035777.JPEG +ILSVRC2012_val_00024722.JPEG +ILSVRC2012_val_00013158.JPEG +ILSVRC2012_val_00047989.JPEG +ILSVRC2012_val_00043091.JPEG +ILSVRC2012_val_00003066.JPEG +ILSVRC2012_val_00017016.JPEG +ILSVRC2012_val_00030188.JPEG +ILSVRC2012_val_00042585.JPEG +ILSVRC2012_val_00037261.JPEG +ILSVRC2012_val_00013331.JPEG +ILSVRC2012_val_00004888.JPEG +ILSVRC2012_val_00028013.JPEG +ILSVRC2012_val_00017265.JPEG +ILSVRC2012_val_00003048.JPEG +ILSVRC2012_val_00013356.JPEG +ILSVRC2012_val_00024811.JPEG 
+ILSVRC2012_val_00023836.JPEG +ILSVRC2012_val_00020958.JPEG +ILSVRC2012_val_00007571.JPEG +ILSVRC2012_val_00022137.JPEG +ILSVRC2012_val_00049928.JPEG +ILSVRC2012_val_00000404.JPEG +ILSVRC2012_val_00035404.JPEG +ILSVRC2012_val_00022471.JPEG +ILSVRC2012_val_00020125.JPEG +ILSVRC2012_val_00002021.JPEG +ILSVRC2012_val_00009245.JPEG +ILSVRC2012_val_00022094.JPEG +ILSVRC2012_val_00045984.JPEG +ILSVRC2012_val_00017407.JPEG +ILSVRC2012_val_00006886.JPEG +ILSVRC2012_val_00046123.JPEG +ILSVRC2012_val_00048988.JPEG +ILSVRC2012_val_00047495.JPEG +ILSVRC2012_val_00000996.JPEG +ILSVRC2012_val_00027511.JPEG +ILSVRC2012_val_00037891.JPEG +ILSVRC2012_val_00009687.JPEG +ILSVRC2012_val_00043475.JPEG +ILSVRC2012_val_00034760.JPEG +ILSVRC2012_val_00005557.JPEG +ILSVRC2012_val_00043562.JPEG +ILSVRC2012_val_00010685.JPEG +ILSVRC2012_val_00044440.JPEG +ILSVRC2012_val_00028249.JPEG +ILSVRC2012_val_00020931.JPEG +ILSVRC2012_val_00001117.JPEG +ILSVRC2012_val_00033382.JPEG +ILSVRC2012_val_00049517.JPEG +ILSVRC2012_val_00012415.JPEG +ILSVRC2012_val_00046017.JPEG +ILSVRC2012_val_00048110.JPEG +ILSVRC2012_val_00029294.JPEG +ILSVRC2012_val_00021012.JPEG +ILSVRC2012_val_00002932.JPEG +ILSVRC2012_val_00023370.JPEG +ILSVRC2012_val_00044269.JPEG +ILSVRC2012_val_00041109.JPEG +ILSVRC2012_val_00003521.JPEG +ILSVRC2012_val_00041790.JPEG +ILSVRC2012_val_00026386.JPEG +ILSVRC2012_val_00031970.JPEG +ILSVRC2012_val_00036389.JPEG +ILSVRC2012_val_00016927.JPEG +ILSVRC2012_val_00045838.JPEG +ILSVRC2012_val_00029025.JPEG +ILSVRC2012_val_00029919.JPEG +ILSVRC2012_val_00045995.JPEG +ILSVRC2012_val_00049404.JPEG +ILSVRC2012_val_00028213.JPEG +ILSVRC2012_val_00021552.JPEG +ILSVRC2012_val_00011462.JPEG +ILSVRC2012_val_00012120.JPEG +ILSVRC2012_val_00031518.JPEG +ILSVRC2012_val_00001755.JPEG +ILSVRC2012_val_00043483.JPEG +ILSVRC2012_val_00047619.JPEG +ILSVRC2012_val_00035608.JPEG +ILSVRC2012_val_00045005.JPEG +ILSVRC2012_val_00043222.JPEG +ILSVRC2012_val_00016323.JPEG +ILSVRC2012_val_00043819.JPEG 
+ILSVRC2012_val_00006955.JPEG +ILSVRC2012_val_00034282.JPEG +ILSVRC2012_val_00039261.JPEG +ILSVRC2012_val_00006523.JPEG +ILSVRC2012_val_00025410.JPEG +ILSVRC2012_val_00001433.JPEG +ILSVRC2012_val_00047901.JPEG +ILSVRC2012_val_00009908.JPEG +ILSVRC2012_val_00007784.JPEG +ILSVRC2012_val_00008784.JPEG +ILSVRC2012_val_00033738.JPEG +ILSVRC2012_val_00004560.JPEG +ILSVRC2012_val_00044617.JPEG +ILSVRC2012_val_00043150.JPEG +ILSVRC2012_val_00049023.JPEG +ILSVRC2012_val_00038347.JPEG +ILSVRC2012_val_00039363.JPEG +ILSVRC2012_val_00037161.JPEG +ILSVRC2012_val_00012640.JPEG +ILSVRC2012_val_00027094.JPEG +ILSVRC2012_val_00029996.JPEG +ILSVRC2012_val_00046781.JPEG +ILSVRC2012_val_00013109.JPEG +ILSVRC2012_val_00035398.JPEG +ILSVRC2012_val_00042465.JPEG +ILSVRC2012_val_00003967.JPEG +ILSVRC2012_val_00022517.JPEG +ILSVRC2012_val_00007171.JPEG +ILSVRC2012_val_00046469.JPEG +ILSVRC2012_val_00042136.JPEG +ILSVRC2012_val_00048739.JPEG +ILSVRC2012_val_00024516.JPEG +ILSVRC2012_val_00030376.JPEG +ILSVRC2012_val_00022400.JPEG +ILSVRC2012_val_00029228.JPEG +ILSVRC2012_val_00002017.JPEG +ILSVRC2012_val_00049261.JPEG +ILSVRC2012_val_00035203.JPEG +ILSVRC2012_val_00016722.JPEG +ILSVRC2012_val_00047226.JPEG +ILSVRC2012_val_00001556.JPEG +ILSVRC2012_val_00039535.JPEG +ILSVRC2012_val_00023588.JPEG +ILSVRC2012_val_00036212.JPEG +ILSVRC2012_val_00026723.JPEG +ILSVRC2012_val_00032167.JPEG +ILSVRC2012_val_00017004.JPEG +ILSVRC2012_val_00045489.JPEG +ILSVRC2012_val_00003631.JPEG +ILSVRC2012_val_00042571.JPEG +ILSVRC2012_val_00041284.JPEG +ILSVRC2012_val_00035772.JPEG +ILSVRC2012_val_00033797.JPEG +ILSVRC2012_val_00011165.JPEG +ILSVRC2012_val_00032591.JPEG +ILSVRC2012_val_00016995.JPEG +ILSVRC2012_val_00010645.JPEG +ILSVRC2012_val_00049371.JPEG +ILSVRC2012_val_00020545.JPEG +ILSVRC2012_val_00039430.JPEG +ILSVRC2012_val_00019767.JPEG +ILSVRC2012_val_00039222.JPEG +ILSVRC2012_val_00016588.JPEG +ILSVRC2012_val_00039019.JPEG +ILSVRC2012_val_00021775.JPEG +ILSVRC2012_val_00045721.JPEG 
+ILSVRC2012_val_00001666.JPEG +ILSVRC2012_val_00041768.JPEG +ILSVRC2012_val_00014957.JPEG +ILSVRC2012_val_00038019.JPEG +ILSVRC2012_val_00038731.JPEG +ILSVRC2012_val_00049518.JPEG +ILSVRC2012_val_00028903.JPEG +ILSVRC2012_val_00047887.JPEG +ILSVRC2012_val_00049926.JPEG +ILSVRC2012_val_00032111.JPEG +ILSVRC2012_val_00034358.JPEG +ILSVRC2012_val_00016019.JPEG +ILSVRC2012_val_00010221.JPEG +ILSVRC2012_val_00042170.JPEG +ILSVRC2012_val_00000832.JPEG +ILSVRC2012_val_00043534.JPEG +ILSVRC2012_val_00049387.JPEG +ILSVRC2012_val_00043140.JPEG +ILSVRC2012_val_00022151.JPEG +ILSVRC2012_val_00014128.JPEG +ILSVRC2012_val_00002211.JPEG +ILSVRC2012_val_00025487.JPEG +ILSVRC2012_val_00034507.JPEG +ILSVRC2012_val_00038410.JPEG +ILSVRC2012_val_00028000.JPEG +ILSVRC2012_val_00049130.JPEG +ILSVRC2012_val_00038018.JPEG +ILSVRC2012_val_00002262.JPEG +ILSVRC2012_val_00043992.JPEG +ILSVRC2012_val_00007252.JPEG +ILSVRC2012_val_00022061.JPEG +ILSVRC2012_val_00017137.JPEG +ILSVRC2012_val_00042384.JPEG +ILSVRC2012_val_00039246.JPEG +ILSVRC2012_val_00001875.JPEG +ILSVRC2012_val_00032960.JPEG +ILSVRC2012_val_00003123.JPEG +ILSVRC2012_val_00006903.JPEG +ILSVRC2012_val_00023394.JPEG +ILSVRC2012_val_00001497.JPEG +ILSVRC2012_val_00026018.JPEG +ILSVRC2012_val_00018105.JPEG +ILSVRC2012_val_00035215.JPEG +ILSVRC2012_val_00039759.JPEG +ILSVRC2012_val_00046988.JPEG +ILSVRC2012_val_00012314.JPEG +ILSVRC2012_val_00033769.JPEG +ILSVRC2012_val_00045328.JPEG +ILSVRC2012_val_00016416.JPEG +ILSVRC2012_val_00000850.JPEG +ILSVRC2012_val_00004022.JPEG +ILSVRC2012_val_00036497.JPEG +ILSVRC2012_val_00025198.JPEG +ILSVRC2012_val_00005199.JPEG +ILSVRC2012_val_00012549.JPEG +ILSVRC2012_val_00047046.JPEG +ILSVRC2012_val_00045591.JPEG +ILSVRC2012_val_00014145.JPEG +ILSVRC2012_val_00009287.JPEG +ILSVRC2012_val_00016576.JPEG +ILSVRC2012_val_00046868.JPEG +ILSVRC2012_val_00049634.JPEG +ILSVRC2012_val_00039238.JPEG +ILSVRC2012_val_00033605.JPEG +ILSVRC2012_val_00011518.JPEG +ILSVRC2012_val_00023675.JPEG 
+ILSVRC2012_val_00029282.JPEG +ILSVRC2012_val_00022657.JPEG +ILSVRC2012_val_00013839.JPEG +ILSVRC2012_val_00013074.JPEG +ILSVRC2012_val_00014987.JPEG +ILSVRC2012_val_00044480.JPEG +ILSVRC2012_val_00014632.JPEG +ILSVRC2012_val_00002935.JPEG +ILSVRC2012_val_00014674.JPEG +ILSVRC2012_val_00039811.JPEG +ILSVRC2012_val_00037474.JPEG +ILSVRC2012_val_00033153.JPEG +ILSVRC2012_val_00005537.JPEG +ILSVRC2012_val_00021388.JPEG +ILSVRC2012_val_00035717.JPEG +ILSVRC2012_val_00027501.JPEG +ILSVRC2012_val_00033262.JPEG +ILSVRC2012_val_00028253.JPEG +ILSVRC2012_val_00006535.JPEG +ILSVRC2012_val_00014201.JPEG +ILSVRC2012_val_00011976.JPEG +ILSVRC2012_val_00019377.JPEG +ILSVRC2012_val_00033251.JPEG +ILSVRC2012_val_00011157.JPEG +ILSVRC2012_val_00008469.JPEG +ILSVRC2012_val_00046268.JPEG +ILSVRC2012_val_00039320.JPEG +ILSVRC2012_val_00015869.JPEG +ILSVRC2012_val_00001651.JPEG +ILSVRC2012_val_00046351.JPEG +ILSVRC2012_val_00024579.JPEG +ILSVRC2012_val_00003919.JPEG +ILSVRC2012_val_00002967.JPEG +ILSVRC2012_val_00044037.JPEG +ILSVRC2012_val_00012832.JPEG +ILSVRC2012_val_00048282.JPEG +ILSVRC2012_val_00013389.JPEG +ILSVRC2012_val_00038304.JPEG +ILSVRC2012_val_00033708.JPEG +ILSVRC2012_val_00032684.JPEG +ILSVRC2012_val_00028251.JPEG +ILSVRC2012_val_00006302.JPEG +ILSVRC2012_val_00018415.JPEG +ILSVRC2012_val_00009804.JPEG +ILSVRC2012_val_00018453.JPEG +ILSVRC2012_val_00025255.JPEG +ILSVRC2012_val_00039287.JPEG +ILSVRC2012_val_00007637.JPEG +ILSVRC2012_val_00023800.JPEG +ILSVRC2012_val_00001300.JPEG +ILSVRC2012_val_00016966.JPEG +ILSVRC2012_val_00005064.JPEG +ILSVRC2012_val_00045068.JPEG +ILSVRC2012_val_00005578.JPEG +ILSVRC2012_val_00036839.JPEG +ILSVRC2012_val_00008215.JPEG +ILSVRC2012_val_00016104.JPEG +ILSVRC2012_val_00006892.JPEG +ILSVRC2012_val_00003713.JPEG +ILSVRC2012_val_00037319.JPEG +ILSVRC2012_val_00009054.JPEG +ILSVRC2012_val_00040830.JPEG +ILSVRC2012_val_00007243.JPEG +ILSVRC2012_val_00030086.JPEG +ILSVRC2012_val_00005930.JPEG +ILSVRC2012_val_00002861.JPEG 
+ILSVRC2012_val_00024411.JPEG +ILSVRC2012_val_00038356.JPEG +ILSVRC2012_val_00034793.JPEG +ILSVRC2012_val_00012932.JPEG +ILSVRC2012_val_00016076.JPEG +ILSVRC2012_val_00015315.JPEG +ILSVRC2012_val_00011659.JPEG +ILSVRC2012_val_00002102.JPEG +ILSVRC2012_val_00024781.JPEG +ILSVRC2012_val_00017704.JPEG +ILSVRC2012_val_00014979.JPEG +ILSVRC2012_val_00018211.JPEG +ILSVRC2012_val_00045469.JPEG +ILSVRC2012_val_00009434.JPEG +ILSVRC2012_val_00012139.JPEG +ILSVRC2012_val_00042462.JPEG +ILSVRC2012_val_00046771.JPEG +ILSVRC2012_val_00032258.JPEG +ILSVRC2012_val_00041429.JPEG +ILSVRC2012_val_00007600.JPEG +ILSVRC2012_val_00048919.JPEG +ILSVRC2012_val_00045803.JPEG +ILSVRC2012_val_00028896.JPEG +ILSVRC2012_val_00037286.JPEG +ILSVRC2012_val_00042990.JPEG +ILSVRC2012_val_00025106.JPEG +ILSVRC2012_val_00049881.JPEG +ILSVRC2012_val_00044645.JPEG +ILSVRC2012_val_00044742.JPEG +ILSVRC2012_val_00021860.JPEG +ILSVRC2012_val_00034001.JPEG +ILSVRC2012_val_00031221.JPEG +ILSVRC2012_val_00023638.JPEG +ILSVRC2012_val_00028505.JPEG +ILSVRC2012_val_00047751.JPEG +ILSVRC2012_val_00008547.JPEG +ILSVRC2012_val_00041146.JPEG +ILSVRC2012_val_00021899.JPEG +ILSVRC2012_val_00034294.JPEG +ILSVRC2012_val_00043531.JPEG +ILSVRC2012_val_00021923.JPEG +ILSVRC2012_val_00024518.JPEG +ILSVRC2012_val_00009886.JPEG +ILSVRC2012_val_00004272.JPEG +ILSVRC2012_val_00029981.JPEG +ILSVRC2012_val_00034581.JPEG +ILSVRC2012_val_00003886.JPEG +ILSVRC2012_val_00037220.JPEG +ILSVRC2012_val_00003288.JPEG +ILSVRC2012_val_00039046.JPEG +ILSVRC2012_val_00043582.JPEG +ILSVRC2012_val_00030708.JPEG +ILSVRC2012_val_00043035.JPEG +ILSVRC2012_val_00049154.JPEG +ILSVRC2012_val_00028218.JPEG +ILSVRC2012_val_00013923.JPEG +ILSVRC2012_val_00013161.JPEG +ILSVRC2012_val_00014367.JPEG +ILSVRC2012_val_00019367.JPEG +ILSVRC2012_val_00035878.JPEG +ILSVRC2012_val_00009679.JPEG +ILSVRC2012_val_00007082.JPEG +ILSVRC2012_val_00012892.JPEG +ILSVRC2012_val_00033925.JPEG +ILSVRC2012_val_00009413.JPEG +ILSVRC2012_val_00046813.JPEG 
+ILSVRC2012_val_00001661.JPEG +ILSVRC2012_val_00042285.JPEG +ILSVRC2012_val_00040852.JPEG +ILSVRC2012_val_00015359.JPEG +ILSVRC2012_val_00020957.JPEG +ILSVRC2012_val_00025783.JPEG +ILSVRC2012_val_00030109.JPEG +ILSVRC2012_val_00022774.JPEG +ILSVRC2012_val_00034910.JPEG +ILSVRC2012_val_00007648.JPEG +ILSVRC2012_val_00006623.JPEG +ILSVRC2012_val_00020328.JPEG +ILSVRC2012_val_00002259.JPEG +ILSVRC2012_val_00036937.JPEG +ILSVRC2012_val_00035346.JPEG +ILSVRC2012_val_00026328.JPEG +ILSVRC2012_val_00048543.JPEG +ILSVRC2012_val_00040367.JPEG +ILSVRC2012_val_00040393.JPEG +ILSVRC2012_val_00022065.JPEG +ILSVRC2012_val_00049418.JPEG +ILSVRC2012_val_00045692.JPEG +ILSVRC2012_val_00037178.JPEG +ILSVRC2012_val_00025995.JPEG +ILSVRC2012_val_00026322.JPEG +ILSVRC2012_val_00030106.JPEG +ILSVRC2012_val_00013426.JPEG +ILSVRC2012_val_00021421.JPEG +ILSVRC2012_val_00005431.JPEG +ILSVRC2012_val_00008922.JPEG +ILSVRC2012_val_00041198.JPEG +ILSVRC2012_val_00042589.JPEG +ILSVRC2012_val_00043506.JPEG +ILSVRC2012_val_00026811.JPEG +ILSVRC2012_val_00004301.JPEG +ILSVRC2012_val_00032248.JPEG +ILSVRC2012_val_00039104.JPEG +ILSVRC2012_val_00000640.JPEG +ILSVRC2012_val_00033802.JPEG +ILSVRC2012_val_00021697.JPEG +ILSVRC2012_val_00041294.JPEG +ILSVRC2012_val_00000676.JPEG +ILSVRC2012_val_00036799.JPEG +ILSVRC2012_val_00025483.JPEG +ILSVRC2012_val_00036977.JPEG +ILSVRC2012_val_00042477.JPEG +ILSVRC2012_val_00013783.JPEG +ILSVRC2012_val_00034747.JPEG +ILSVRC2012_val_00047655.JPEG +ILSVRC2012_val_00005727.JPEG +ILSVRC2012_val_00025260.JPEG +ILSVRC2012_val_00007438.JPEG +ILSVRC2012_val_00033935.JPEG +ILSVRC2012_val_00013821.JPEG +ILSVRC2012_val_00000584.JPEG +ILSVRC2012_val_00042493.JPEG +ILSVRC2012_val_00002613.JPEG +ILSVRC2012_val_00015992.JPEG +ILSVRC2012_val_00023276.JPEG +ILSVRC2012_val_00018592.JPEG +ILSVRC2012_val_00035863.JPEG +ILSVRC2012_val_00010278.JPEG +ILSVRC2012_val_00035189.JPEG +ILSVRC2012_val_00037078.JPEG +ILSVRC2012_val_00017111.JPEG +ILSVRC2012_val_00034832.JPEG 
+ILSVRC2012_val_00045695.JPEG +ILSVRC2012_val_00019131.JPEG +ILSVRC2012_val_00036608.JPEG +ILSVRC2012_val_00029725.JPEG +ILSVRC2012_val_00016272.JPEG +ILSVRC2012_val_00046589.JPEG +ILSVRC2012_val_00010939.JPEG +ILSVRC2012_val_00002535.JPEG +ILSVRC2012_val_00005525.JPEG +ILSVRC2012_val_00011581.JPEG +ILSVRC2012_val_00039225.JPEG +ILSVRC2012_val_00000972.JPEG +ILSVRC2012_val_00035677.JPEG +ILSVRC2012_val_00009131.JPEG +ILSVRC2012_val_00003329.JPEG +ILSVRC2012_val_00028833.JPEG +ILSVRC2012_val_00045011.JPEG +ILSVRC2012_val_00012411.JPEG +ILSVRC2012_val_00004896.JPEG +ILSVRC2012_val_00014937.JPEG +ILSVRC2012_val_00009889.JPEG +ILSVRC2012_val_00037013.JPEG +ILSVRC2012_val_00009597.JPEG +ILSVRC2012_val_00039841.JPEG +ILSVRC2012_val_00017326.JPEG +ILSVRC2012_val_00006370.JPEG +ILSVRC2012_val_00043337.JPEG +ILSVRC2012_val_00004321.JPEG +ILSVRC2012_val_00018478.JPEG +ILSVRC2012_val_00013498.JPEG +ILSVRC2012_val_00003867.JPEG +ILSVRC2012_val_00036188.JPEG +ILSVRC2012_val_00039898.JPEG +ILSVRC2012_val_00000741.JPEG +ILSVRC2012_val_00017926.JPEG +ILSVRC2012_val_00017624.JPEG +ILSVRC2012_val_00021359.JPEG +ILSVRC2012_val_00048263.JPEG +ILSVRC2012_val_00021798.JPEG +ILSVRC2012_val_00031455.JPEG +ILSVRC2012_val_00018175.JPEG +ILSVRC2012_val_00016018.JPEG +ILSVRC2012_val_00046434.JPEG +ILSVRC2012_val_00001793.JPEG +ILSVRC2012_val_00034889.JPEG +ILSVRC2012_val_00014832.JPEG +ILSVRC2012_val_00011726.JPEG +ILSVRC2012_val_00044568.JPEG +ILSVRC2012_val_00035397.JPEG +ILSVRC2012_val_00023681.JPEG +ILSVRC2012_val_00037607.JPEG +ILSVRC2012_val_00021851.JPEG +ILSVRC2012_val_00036327.JPEG +ILSVRC2012_val_00031988.JPEG +ILSVRC2012_val_00046097.JPEG +ILSVRC2012_val_00040164.JPEG +ILSVRC2012_val_00036097.JPEG +ILSVRC2012_val_00005138.JPEG +ILSVRC2012_val_00033940.JPEG +ILSVRC2012_val_00018776.JPEG +ILSVRC2012_val_00010752.JPEG +ILSVRC2012_val_00036066.JPEG +ILSVRC2012_val_00019281.JPEG +ILSVRC2012_val_00005198.JPEG +ILSVRC2012_val_00001956.JPEG +ILSVRC2012_val_00037478.JPEG 
+ILSVRC2012_val_00034183.JPEG +ILSVRC2012_val_00031429.JPEG +ILSVRC2012_val_00015590.JPEG +ILSVRC2012_val_00008487.JPEG +ILSVRC2012_val_00024805.JPEG +ILSVRC2012_val_00020420.JPEG +ILSVRC2012_val_00043277.JPEG +ILSVRC2012_val_00048623.JPEG +ILSVRC2012_val_00037872.JPEG +ILSVRC2012_val_00013310.JPEG +ILSVRC2012_val_00042311.JPEG +ILSVRC2012_val_00044045.JPEG +ILSVRC2012_val_00033149.JPEG +ILSVRC2012_val_00011465.JPEG +ILSVRC2012_val_00043197.JPEG +ILSVRC2012_val_00037228.JPEG +ILSVRC2012_val_00026343.JPEG +ILSVRC2012_val_00001412.JPEG +ILSVRC2012_val_00042377.JPEG +ILSVRC2012_val_00015452.JPEG +ILSVRC2012_val_00002542.JPEG +ILSVRC2012_val_00018732.JPEG +ILSVRC2012_val_00029756.JPEG +ILSVRC2012_val_00006945.JPEG +ILSVRC2012_val_00047433.JPEG +ILSVRC2012_val_00028849.JPEG +ILSVRC2012_val_00016253.JPEG +ILSVRC2012_val_00014536.JPEG +ILSVRC2012_val_00008927.JPEG +ILSVRC2012_val_00010981.JPEG +ILSVRC2012_val_00018867.JPEG +ILSVRC2012_val_00040832.JPEG +ILSVRC2012_val_00045284.JPEG +ILSVRC2012_val_00011227.JPEG +ILSVRC2012_val_00012428.JPEG +ILSVRC2012_val_00000408.JPEG +ILSVRC2012_val_00002069.JPEG +ILSVRC2012_val_00016719.JPEG +ILSVRC2012_val_00032406.JPEG +ILSVRC2012_val_00023290.JPEG +ILSVRC2012_val_00029525.JPEG +ILSVRC2012_val_00007996.JPEG +ILSVRC2012_val_00002695.JPEG +ILSVRC2012_val_00023476.JPEG +ILSVRC2012_val_00016477.JPEG +ILSVRC2012_val_00012919.JPEG +ILSVRC2012_val_00032894.JPEG +ILSVRC2012_val_00029163.JPEG +ILSVRC2012_val_00045817.JPEG +ILSVRC2012_val_00005733.JPEG +ILSVRC2012_val_00004810.JPEG +ILSVRC2012_val_00032082.JPEG +ILSVRC2012_val_00027534.JPEG +ILSVRC2012_val_00033582.JPEG +ILSVRC2012_val_00023845.JPEG +ILSVRC2012_val_00015973.JPEG +ILSVRC2012_val_00027861.JPEG +ILSVRC2012_val_00027197.JPEG +ILSVRC2012_val_00025990.JPEG +ILSVRC2012_val_00038534.JPEG +ILSVRC2012_val_00000538.JPEG +ILSVRC2012_val_00031063.JPEG +ILSVRC2012_val_00012319.JPEG +ILSVRC2012_val_00040523.JPEG +ILSVRC2012_val_00010722.JPEG +ILSVRC2012_val_00023597.JPEG 
+ILSVRC2012_val_00036265.JPEG +ILSVRC2012_val_00046430.JPEG +ILSVRC2012_val_00023261.JPEG +ILSVRC2012_val_00043244.JPEG +ILSVRC2012_val_00042362.JPEG +ILSVRC2012_val_00042820.JPEG +ILSVRC2012_val_00001087.JPEG +ILSVRC2012_val_00001765.JPEG +ILSVRC2012_val_00017678.JPEG +ILSVRC2012_val_00020300.JPEG +ILSVRC2012_val_00012326.JPEG +ILSVRC2012_val_00043435.JPEG +ILSVRC2012_val_00039411.JPEG +ILSVRC2012_val_00035465.JPEG +ILSVRC2012_val_00020874.JPEG +ILSVRC2012_val_00042278.JPEG +ILSVRC2012_val_00012823.JPEG +ILSVRC2012_val_00042898.JPEG +ILSVRC2012_val_00029555.JPEG +ILSVRC2012_val_00034876.JPEG +ILSVRC2012_val_00041201.JPEG +ILSVRC2012_val_00002305.JPEG +ILSVRC2012_val_00025555.JPEG +ILSVRC2012_val_00038959.JPEG +ILSVRC2012_val_00046159.JPEG +ILSVRC2012_val_00016861.JPEG +ILSVRC2012_val_00034587.JPEG +ILSVRC2012_val_00049641.JPEG +ILSVRC2012_val_00043950.JPEG +ILSVRC2012_val_00042015.JPEG +ILSVRC2012_val_00007814.JPEG +ILSVRC2012_val_00011031.JPEG +ILSVRC2012_val_00014612.JPEG +ILSVRC2012_val_00031151.JPEG +ILSVRC2012_val_00038109.JPEG +ILSVRC2012_val_00044088.JPEG +ILSVRC2012_val_00021407.JPEG +ILSVRC2012_val_00016533.JPEG +ILSVRC2012_val_00048332.JPEG +ILSVRC2012_val_00008000.JPEG +ILSVRC2012_val_00022665.JPEG +ILSVRC2012_val_00048307.JPEG +ILSVRC2012_val_00048891.JPEG +ILSVRC2012_val_00019852.JPEG +ILSVRC2012_val_00046392.JPEG +ILSVRC2012_val_00016310.JPEG +ILSVRC2012_val_00017278.JPEG +ILSVRC2012_val_00020295.JPEG +ILSVRC2012_val_00003997.JPEG +ILSVRC2012_val_00026287.JPEG +ILSVRC2012_val_00015198.JPEG +ILSVRC2012_val_00027570.JPEG +ILSVRC2012_val_00002781.JPEG +ILSVRC2012_val_00044608.JPEG +ILSVRC2012_val_00019060.JPEG +ILSVRC2012_val_00038736.JPEG +ILSVRC2012_val_00040963.JPEG +ILSVRC2012_val_00017830.JPEG +ILSVRC2012_val_00024490.JPEG +ILSVRC2012_val_00047997.JPEG +ILSVRC2012_val_00017848.JPEG +ILSVRC2012_val_00016739.JPEG +ILSVRC2012_val_00000007.JPEG +ILSVRC2012_val_00008158.JPEG +ILSVRC2012_val_00024590.JPEG +ILSVRC2012_val_00007949.JPEG 
+ILSVRC2012_val_00043265.JPEG +ILSVRC2012_val_00025466.JPEG +ILSVRC2012_val_00005010.JPEG +ILSVRC2012_val_00047905.JPEG +ILSVRC2012_val_00043630.JPEG +ILSVRC2012_val_00016562.JPEG +ILSVRC2012_val_00003398.JPEG +ILSVRC2012_val_00015303.JPEG +ILSVRC2012_val_00015506.JPEG +ILSVRC2012_val_00019185.JPEG +ILSVRC2012_val_00028190.JPEG +ILSVRC2012_val_00004174.JPEG +ILSVRC2012_val_00007706.JPEG +ILSVRC2012_val_00021925.JPEG +ILSVRC2012_val_00040168.JPEG +ILSVRC2012_val_00036239.JPEG +ILSVRC2012_val_00026085.JPEG +ILSVRC2012_val_00013524.JPEG +ILSVRC2012_val_00002824.JPEG +ILSVRC2012_val_00009481.JPEG +ILSVRC2012_val_00016620.JPEG +ILSVRC2012_val_00033835.JPEG +ILSVRC2012_val_00032340.JPEG +ILSVRC2012_val_00034041.JPEG +ILSVRC2012_val_00018314.JPEG +ILSVRC2012_val_00037179.JPEG +ILSVRC2012_val_00034887.JPEG +ILSVRC2012_val_00017774.JPEG +ILSVRC2012_val_00009150.JPEG +ILSVRC2012_val_00015379.JPEG +ILSVRC2012_val_00027718.JPEG +ILSVRC2012_val_00004909.JPEG +ILSVRC2012_val_00046667.JPEG +ILSVRC2012_val_00025375.JPEG +ILSVRC2012_val_00049046.JPEG +ILSVRC2012_val_00040377.JPEG +ILSVRC2012_val_00000813.JPEG +ILSVRC2012_val_00043741.JPEG +ILSVRC2012_val_00030649.JPEG +ILSVRC2012_val_00030041.JPEG +ILSVRC2012_val_00029215.JPEG +ILSVRC2012_val_00026627.JPEG +ILSVRC2012_val_00042799.JPEG +ILSVRC2012_val_00010080.JPEG +ILSVRC2012_val_00011126.JPEG +ILSVRC2012_val_00000701.JPEG +ILSVRC2012_val_00048640.JPEG +ILSVRC2012_val_00039790.JPEG +ILSVRC2012_val_00011526.JPEG +ILSVRC2012_val_00040771.JPEG +ILSVRC2012_val_00030741.JPEG +ILSVRC2012_val_00047941.JPEG +ILSVRC2012_val_00026715.JPEG +ILSVRC2012_val_00006480.JPEG +ILSVRC2012_val_00015423.JPEG +ILSVRC2012_val_00048920.JPEG +ILSVRC2012_val_00021962.JPEG +ILSVRC2012_val_00045123.JPEG +ILSVRC2012_val_00034014.JPEG +ILSVRC2012_val_00039732.JPEG +ILSVRC2012_val_00029396.JPEG +ILSVRC2012_val_00049761.JPEG +ILSVRC2012_val_00008554.JPEG +ILSVRC2012_val_00009544.JPEG +ILSVRC2012_val_00022706.JPEG +ILSVRC2012_val_00015067.JPEG 
+ILSVRC2012_val_00022111.JPEG +ILSVRC2012_val_00013753.JPEG +ILSVRC2012_val_00023437.JPEG +ILSVRC2012_val_00020230.JPEG +ILSVRC2012_val_00022695.JPEG +ILSVRC2012_val_00023556.JPEG +ILSVRC2012_val_00033260.JPEG +ILSVRC2012_val_00034869.JPEG +ILSVRC2012_val_00048210.JPEG +ILSVRC2012_val_00023271.JPEG +ILSVRC2012_val_00044259.JPEG +ILSVRC2012_val_00028763.JPEG +ILSVRC2012_val_00039042.JPEG +ILSVRC2012_val_00019533.JPEG +ILSVRC2012_val_00011604.JPEG +ILSVRC2012_val_00031606.JPEG +ILSVRC2012_val_00010515.JPEG +ILSVRC2012_val_00018280.JPEG +ILSVRC2012_val_00035390.JPEG +ILSVRC2012_val_00009250.JPEG +ILSVRC2012_val_00024357.JPEG +ILSVRC2012_val_00004945.JPEG +ILSVRC2012_val_00013656.JPEG +ILSVRC2012_val_00006426.JPEG +ILSVRC2012_val_00014107.JPEG +ILSVRC2012_val_00017468.JPEG +ILSVRC2012_val_00010105.JPEG +ILSVRC2012_val_00020783.JPEG +ILSVRC2012_val_00036102.JPEG +ILSVRC2012_val_00027944.JPEG +ILSVRC2012_val_00043232.JPEG +ILSVRC2012_val_00030122.JPEG +ILSVRC2012_val_00011667.JPEG +ILSVRC2012_val_00028176.JPEG +ILSVRC2012_val_00010395.JPEG +ILSVRC2012_val_00034171.JPEG +ILSVRC2012_val_00025805.JPEG +ILSVRC2012_val_00026844.JPEG +ILSVRC2012_val_00018613.JPEG +ILSVRC2012_val_00045652.JPEG +ILSVRC2012_val_00046440.JPEG +ILSVRC2012_val_00045073.JPEG +ILSVRC2012_val_00039977.JPEG +ILSVRC2012_val_00006598.JPEG +ILSVRC2012_val_00007708.JPEG +ILSVRC2012_val_00005548.JPEG +ILSVRC2012_val_00042847.JPEG +ILSVRC2012_val_00022633.JPEG +ILSVRC2012_val_00013301.JPEG +ILSVRC2012_val_00019176.JPEG +ILSVRC2012_val_00028555.JPEG +ILSVRC2012_val_00030312.JPEG +ILSVRC2012_val_00030491.JPEG +ILSVRC2012_val_00040131.JPEG +ILSVRC2012_val_00035885.JPEG +ILSVRC2012_val_00003773.JPEG +ILSVRC2012_val_00018493.JPEG +ILSVRC2012_val_00032322.JPEG +ILSVRC2012_val_00008842.JPEG +ILSVRC2012_val_00000151.JPEG +ILSVRC2012_val_00043934.JPEG +ILSVRC2012_val_00037913.JPEG +ILSVRC2012_val_00003989.JPEG +ILSVRC2012_val_00041099.JPEG +ILSVRC2012_val_00025018.JPEG +ILSVRC2012_val_00042740.JPEG 
+ILSVRC2012_val_00043037.JPEG +ILSVRC2012_val_00013606.JPEG +ILSVRC2012_val_00021696.JPEG +ILSVRC2012_val_00030148.JPEG +ILSVRC2012_val_00007642.JPEG +ILSVRC2012_val_00016800.JPEG +ILSVRC2012_val_00014917.JPEG +ILSVRC2012_val_00026741.JPEG +ILSVRC2012_val_00036556.JPEG +ILSVRC2012_val_00032220.JPEG +ILSVRC2012_val_00042101.JPEG +ILSVRC2012_val_00036835.JPEG +ILSVRC2012_val_00009642.JPEG +ILSVRC2012_val_00015884.JPEG +ILSVRC2012_val_00032685.JPEG +ILSVRC2012_val_00038136.JPEG +ILSVRC2012_val_00017441.JPEG +ILSVRC2012_val_00021790.JPEG +ILSVRC2012_val_00037527.JPEG +ILSVRC2012_val_00003913.JPEG +ILSVRC2012_val_00030847.JPEG +ILSVRC2012_val_00019957.JPEG +ILSVRC2012_val_00035803.JPEG +ILSVRC2012_val_00010848.JPEG +ILSVRC2012_val_00029081.JPEG +ILSVRC2012_val_00037781.JPEG +ILSVRC2012_val_00018150.JPEG +ILSVRC2012_val_00025549.JPEG +ILSVRC2012_val_00039997.JPEG +ILSVRC2012_val_00033091.JPEG +ILSVRC2012_val_00001197.JPEG +ILSVRC2012_val_00019917.JPEG +ILSVRC2012_val_00019656.JPEG +ILSVRC2012_val_00015943.JPEG +ILSVRC2012_val_00018599.JPEG +ILSVRC2012_val_00007115.JPEG +ILSVRC2012_val_00009978.JPEG +ILSVRC2012_val_00031996.JPEG +ILSVRC2012_val_00019861.JPEG +ILSVRC2012_val_00044035.JPEG +ILSVRC2012_val_00025362.JPEG +ILSVRC2012_val_00006173.JPEG +ILSVRC2012_val_00022823.JPEG +ILSVRC2012_val_00011548.JPEG +ILSVRC2012_val_00048957.JPEG +ILSVRC2012_val_00034696.JPEG +ILSVRC2012_val_00003917.JPEG +ILSVRC2012_val_00021278.JPEG +ILSVRC2012_val_00009724.JPEG +ILSVRC2012_val_00021233.JPEG +ILSVRC2012_val_00020508.JPEG +ILSVRC2012_val_00040063.JPEG +ILSVRC2012_val_00039723.JPEG +ILSVRC2012_val_00040822.JPEG +ILSVRC2012_val_00047665.JPEG +ILSVRC2012_val_00045685.JPEG +ILSVRC2012_val_00026496.JPEG +ILSVRC2012_val_00004171.JPEG +ILSVRC2012_val_00037372.JPEG +ILSVRC2012_val_00011123.JPEG +ILSVRC2012_val_00037934.JPEG +ILSVRC2012_val_00044138.JPEG +ILSVRC2012_val_00005467.JPEG +ILSVRC2012_val_00003455.JPEG +ILSVRC2012_val_00031310.JPEG +ILSVRC2012_val_00004360.JPEG 
+ILSVRC2012_val_00030006.JPEG +ILSVRC2012_val_00024457.JPEG +ILSVRC2012_val_00046698.JPEG +ILSVRC2012_val_00009702.JPEG +ILSVRC2012_val_00026143.JPEG +ILSVRC2012_val_00016738.JPEG +ILSVRC2012_val_00004521.JPEG +ILSVRC2012_val_00030193.JPEG +ILSVRC2012_val_00017935.JPEG +ILSVRC2012_val_00004971.JPEG +ILSVRC2012_val_00037526.JPEG +ILSVRC2012_val_00043813.JPEG +ILSVRC2012_val_00029214.JPEG +ILSVRC2012_val_00039217.JPEG +ILSVRC2012_val_00000768.JPEG +ILSVRC2012_val_00015691.JPEG +ILSVRC2012_val_00026030.JPEG +ILSVRC2012_val_00013971.JPEG +ILSVRC2012_val_00006877.JPEG +ILSVRC2012_val_00005660.JPEG +ILSVRC2012_val_00013123.JPEG +ILSVRC2012_val_00039773.JPEG +ILSVRC2012_val_00028955.JPEG +ILSVRC2012_val_00025920.JPEG +ILSVRC2012_val_00007405.JPEG +ILSVRC2012_val_00033619.JPEG +ILSVRC2012_val_00008855.JPEG +ILSVRC2012_val_00032726.JPEG +ILSVRC2012_val_00044311.JPEG +ILSVRC2012_val_00018504.JPEG +ILSVRC2012_val_00037893.JPEG +ILSVRC2012_val_00037155.JPEG +ILSVRC2012_val_00035389.JPEG +ILSVRC2012_val_00045952.JPEG +ILSVRC2012_val_00020696.JPEG +ILSVRC2012_val_00009075.JPEG +ILSVRC2012_val_00040739.JPEG +ILSVRC2012_val_00030388.JPEG +ILSVRC2012_val_00002860.JPEG +ILSVRC2012_val_00049495.JPEG +ILSVRC2012_val_00036344.JPEG +ILSVRC2012_val_00003256.JPEG +ILSVRC2012_val_00027451.JPEG +ILSVRC2012_val_00046013.JPEG +ILSVRC2012_val_00046965.JPEG +ILSVRC2012_val_00042759.JPEG +ILSVRC2012_val_00034002.JPEG +ILSVRC2012_val_00023770.JPEG +ILSVRC2012_val_00038224.JPEG +ILSVRC2012_val_00039105.JPEG +ILSVRC2012_val_00013660.JPEG +ILSVRC2012_val_00011443.JPEG +ILSVRC2012_val_00041718.JPEG +ILSVRC2012_val_00033603.JPEG +ILSVRC2012_val_00042073.JPEG +ILSVRC2012_val_00008938.JPEG +ILSVRC2012_val_00045441.JPEG +ILSVRC2012_val_00019583.JPEG +ILSVRC2012_val_00036478.JPEG +ILSVRC2012_val_00013031.JPEG +ILSVRC2012_val_00011509.JPEG +ILSVRC2012_val_00006751.JPEG +ILSVRC2012_val_00008393.JPEG +ILSVRC2012_val_00019392.JPEG +ILSVRC2012_val_00017521.JPEG +ILSVRC2012_val_00002956.JPEG 
+ILSVRC2012_val_00011917.JPEG +ILSVRC2012_val_00047946.JPEG +ILSVRC2012_val_00005777.JPEG +ILSVRC2012_val_00026429.JPEG +ILSVRC2012_val_00008500.JPEG +ILSVRC2012_val_00032415.JPEG +ILSVRC2012_val_00029953.JPEG +ILSVRC2012_val_00031913.JPEG +ILSVRC2012_val_00035699.JPEG +ILSVRC2012_val_00013979.JPEG +ILSVRC2012_val_00030252.JPEG +ILSVRC2012_val_00000002.JPEG +ILSVRC2012_val_00000249.JPEG +ILSVRC2012_val_00024868.JPEG +ILSVRC2012_val_00000746.JPEG +ILSVRC2012_val_00042299.JPEG +ILSVRC2012_val_00025202.JPEG +ILSVRC2012_val_00011821.JPEG +ILSVRC2012_val_00011823.JPEG +ILSVRC2012_val_00039350.JPEG +ILSVRC2012_val_00033621.JPEG +ILSVRC2012_val_00037042.JPEG +ILSVRC2012_val_00001225.JPEG +ILSVRC2012_val_00031402.JPEG +ILSVRC2012_val_00041030.JPEG +ILSVRC2012_val_00022640.JPEG +ILSVRC2012_val_00009604.JPEG +ILSVRC2012_val_00017887.JPEG +ILSVRC2012_val_00046914.JPEG +ILSVRC2012_val_00033330.JPEG +ILSVRC2012_val_00017722.JPEG +ILSVRC2012_val_00046098.JPEG +ILSVRC2012_val_00040989.JPEG +ILSVRC2012_val_00038606.JPEG +ILSVRC2012_val_00049360.JPEG +ILSVRC2012_val_00040100.JPEG +ILSVRC2012_val_00012515.JPEG +ILSVRC2012_val_00031511.JPEG +ILSVRC2012_val_00000063.JPEG +ILSVRC2012_val_00027789.JPEG +ILSVRC2012_val_00003099.JPEG +ILSVRC2012_val_00023599.JPEG +ILSVRC2012_val_00033530.JPEG +ILSVRC2012_val_00039334.JPEG +ILSVRC2012_val_00034502.JPEG +ILSVRC2012_val_00023765.JPEG +ILSVRC2012_val_00004850.JPEG +ILSVRC2012_val_00013734.JPEG +ILSVRC2012_val_00042960.JPEG +ILSVRC2012_val_00033570.JPEG +ILSVRC2012_val_00036463.JPEG +ILSVRC2012_val_00017703.JPEG +ILSVRC2012_val_00034923.JPEG +ILSVRC2012_val_00019086.JPEG +ILSVRC2012_val_00035992.JPEG +ILSVRC2012_val_00026581.JPEG +ILSVRC2012_val_00019749.JPEG +ILSVRC2012_val_00030202.JPEG +ILSVRC2012_val_00020436.JPEG +ILSVRC2012_val_00007309.JPEG +ILSVRC2012_val_00035494.JPEG +ILSVRC2012_val_00026793.JPEG +ILSVRC2012_val_00014550.JPEG +ILSVRC2012_val_00038161.JPEG +ILSVRC2012_val_00041986.JPEG +ILSVRC2012_val_00019191.JPEG 
+ILSVRC2012_val_00003257.JPEG +ILSVRC2012_val_00042534.JPEG +ILSVRC2012_val_00012204.JPEG +ILSVRC2012_val_00018832.JPEG +ILSVRC2012_val_00012226.JPEG +ILSVRC2012_val_00041502.JPEG +ILSVRC2012_val_00022322.JPEG +ILSVRC2012_val_00017550.JPEG +ILSVRC2012_val_00048829.JPEG +ILSVRC2012_val_00043492.JPEG +ILSVRC2012_val_00023031.JPEG +ILSVRC2012_val_00049887.JPEG +ILSVRC2012_val_00043151.JPEG +ILSVRC2012_val_00018210.JPEG +ILSVRC2012_val_00011060.JPEG +ILSVRC2012_val_00027583.JPEG +ILSVRC2012_val_00014022.JPEG +ILSVRC2012_val_00048060.JPEG +ILSVRC2012_val_00029797.JPEG +ILSVRC2012_val_00001215.JPEG +ILSVRC2012_val_00031263.JPEG +ILSVRC2012_val_00001408.JPEG +ILSVRC2012_val_00006449.JPEG +ILSVRC2012_val_00027335.JPEG +ILSVRC2012_val_00044155.JPEG +ILSVRC2012_val_00009628.JPEG +ILSVRC2012_val_00049251.JPEG +ILSVRC2012_val_00022372.JPEG +ILSVRC2012_val_00022974.JPEG +ILSVRC2012_val_00021480.JPEG +ILSVRC2012_val_00022084.JPEG +ILSVRC2012_val_00011021.JPEG +ILSVRC2012_val_00021862.JPEG +ILSVRC2012_val_00018916.JPEG +ILSVRC2012_val_00030438.JPEG +ILSVRC2012_val_00003538.JPEG +ILSVRC2012_val_00011813.JPEG +ILSVRC2012_val_00041916.JPEG +ILSVRC2012_val_00048309.JPEG +ILSVRC2012_val_00013402.JPEG +ILSVRC2012_val_00049415.JPEG +ILSVRC2012_val_00010180.JPEG +ILSVRC2012_val_00046481.JPEG +ILSVRC2012_val_00046623.JPEG +ILSVRC2012_val_00019188.JPEG +ILSVRC2012_val_00029558.JPEG +ILSVRC2012_val_00015857.JPEG +ILSVRC2012_val_00042830.JPEG +ILSVRC2012_val_00028580.JPEG +ILSVRC2012_val_00010304.JPEG +ILSVRC2012_val_00048142.JPEG +ILSVRC2012_val_00012781.JPEG +ILSVRC2012_val_00014713.JPEG +ILSVRC2012_val_00020200.JPEG +ILSVRC2012_val_00008176.JPEG +ILSVRC2012_val_00026935.JPEG +ILSVRC2012_val_00008485.JPEG +ILSVRC2012_val_00037193.JPEG +ILSVRC2012_val_00041900.JPEG +ILSVRC2012_val_00007301.JPEG +ILSVRC2012_val_00004249.JPEG +ILSVRC2012_val_00010186.JPEG +ILSVRC2012_val_00002263.JPEG +ILSVRC2012_val_00023103.JPEG +ILSVRC2012_val_00018000.JPEG +ILSVRC2012_val_00037533.JPEG 
+ILSVRC2012_val_00022510.JPEG +ILSVRC2012_val_00002720.JPEG +ILSVRC2012_val_00026682.JPEG +ILSVRC2012_val_00042967.JPEG +ILSVRC2012_val_00025359.JPEG +ILSVRC2012_val_00049733.JPEG +ILSVRC2012_val_00025565.JPEG +ILSVRC2012_val_00017306.JPEG +ILSVRC2012_val_00021377.JPEG +ILSVRC2012_val_00047861.JPEG +ILSVRC2012_val_00031099.JPEG +ILSVRC2012_val_00042069.JPEG +ILSVRC2012_val_00012595.JPEG +ILSVRC2012_val_00039834.JPEG +ILSVRC2012_val_00015377.JPEG +ILSVRC2012_val_00014046.JPEG +ILSVRC2012_val_00034051.JPEG +ILSVRC2012_val_00006873.JPEG +ILSVRC2012_val_00032656.JPEG +ILSVRC2012_val_00028007.JPEG +ILSVRC2012_val_00035359.JPEG +ILSVRC2012_val_00031873.JPEG +ILSVRC2012_val_00022043.JPEG +ILSVRC2012_val_00037948.JPEG +ILSVRC2012_val_00003247.JPEG +ILSVRC2012_val_00047504.JPEG +ILSVRC2012_val_00040452.JPEG +ILSVRC2012_val_00033075.JPEG +ILSVRC2012_val_00026605.JPEG +ILSVRC2012_val_00033053.JPEG +ILSVRC2012_val_00027979.JPEG +ILSVRC2012_val_00021478.JPEG +ILSVRC2012_val_00023061.JPEG +ILSVRC2012_val_00047144.JPEG +ILSVRC2012_val_00021470.JPEG +ILSVRC2012_val_00007220.JPEG +ILSVRC2012_val_00028592.JPEG +ILSVRC2012_val_00027154.JPEG +ILSVRC2012_val_00047516.JPEG +ILSVRC2012_val_00041814.JPEG +ILSVRC2012_val_00033942.JPEG +ILSVRC2012_val_00045097.JPEG +ILSVRC2012_val_00021309.JPEG +ILSVRC2012_val_00031713.JPEG +ILSVRC2012_val_00028385.JPEG +ILSVRC2012_val_00003924.JPEG +ILSVRC2012_val_00026537.JPEG +ILSVRC2012_val_00035575.JPEG +ILSVRC2012_val_00031255.JPEG +ILSVRC2012_val_00007129.JPEG +ILSVRC2012_val_00008454.JPEG +ILSVRC2012_val_00040310.JPEG +ILSVRC2012_val_00003672.JPEG +ILSVRC2012_val_00024218.JPEG +ILSVRC2012_val_00014890.JPEG +ILSVRC2012_val_00001309.JPEG +ILSVRC2012_val_00016827.JPEG +ILSVRC2012_val_00039597.JPEG +ILSVRC2012_val_00039115.JPEG +ILSVRC2012_val_00001111.JPEG +ILSVRC2012_val_00033182.JPEG +ILSVRC2012_val_00011423.JPEG +ILSVRC2012_val_00049656.JPEG +ILSVRC2012_val_00024267.JPEG +ILSVRC2012_val_00039240.JPEG +ILSVRC2012_val_00031881.JPEG 
+ILSVRC2012_val_00035369.JPEG +ILSVRC2012_val_00015008.JPEG +ILSVRC2012_val_00030774.JPEG +ILSVRC2012_val_00038595.JPEG +ILSVRC2012_val_00021479.JPEG +ILSVRC2012_val_00040700.JPEG +ILSVRC2012_val_00034549.JPEG +ILSVRC2012_val_00034417.JPEG +ILSVRC2012_val_00006986.JPEG +ILSVRC2012_val_00045019.JPEG +ILSVRC2012_val_00045171.JPEG +ILSVRC2012_val_00006626.JPEG +ILSVRC2012_val_00026932.JPEG +ILSVRC2012_val_00004447.JPEG +ILSVRC2012_val_00003591.JPEG +ILSVRC2012_val_00018806.JPEG +ILSVRC2012_val_00009713.JPEG +ILSVRC2012_val_00012685.JPEG +ILSVRC2012_val_00018192.JPEG +ILSVRC2012_val_00022154.JPEG +ILSVRC2012_val_00026218.JPEG +ILSVRC2012_val_00001202.JPEG +ILSVRC2012_val_00028922.JPEG +ILSVRC2012_val_00018299.JPEG +ILSVRC2012_val_00028400.JPEG +ILSVRC2012_val_00022498.JPEG +ILSVRC2012_val_00002800.JPEG +ILSVRC2012_val_00040443.JPEG +ILSVRC2012_val_00006725.JPEG +ILSVRC2012_val_00049358.JPEG +ILSVRC2012_val_00044275.JPEG +ILSVRC2012_val_00027021.JPEG +ILSVRC2012_val_00033014.JPEG +ILSVRC2012_val_00011543.JPEG +ILSVRC2012_val_00009659.JPEG +ILSVRC2012_val_00024484.JPEG +ILSVRC2012_val_00016832.JPEG +ILSVRC2012_val_00047057.JPEG +ILSVRC2012_val_00013913.JPEG +ILSVRC2012_val_00048594.JPEG +ILSVRC2012_val_00034579.JPEG +ILSVRC2012_val_00017197.JPEG +ILSVRC2012_val_00045680.JPEG +ILSVRC2012_val_00016504.JPEG +ILSVRC2012_val_00040253.JPEG +ILSVRC2012_val_00031865.JPEG +ILSVRC2012_val_00039173.JPEG +ILSVRC2012_val_00046283.JPEG +ILSVRC2012_val_00034635.JPEG +ILSVRC2012_val_00041962.JPEG +ILSVRC2012_val_00046309.JPEG +ILSVRC2012_val_00044748.JPEG +ILSVRC2012_val_00033677.JPEG +ILSVRC2012_val_00040090.JPEG +ILSVRC2012_val_00003964.JPEG +ILSVRC2012_val_00033617.JPEG +ILSVRC2012_val_00046636.JPEG +ILSVRC2012_val_00010208.JPEG +ILSVRC2012_val_00020616.JPEG +ILSVRC2012_val_00019755.JPEG +ILSVRC2012_val_00025190.JPEG +ILSVRC2012_val_00011660.JPEG +ILSVRC2012_val_00010980.JPEG +ILSVRC2012_val_00027738.JPEG +ILSVRC2012_val_00007407.JPEG +ILSVRC2012_val_00027767.JPEG 
+ILSVRC2012_val_00036709.JPEG +ILSVRC2012_val_00017960.JPEG +ILSVRC2012_val_00044210.JPEG +ILSVRC2012_val_00047592.JPEG +ILSVRC2012_val_00013477.JPEG +ILSVRC2012_val_00046335.JPEG +ILSVRC2012_val_00031119.JPEG +ILSVRC2012_val_00014251.JPEG +ILSVRC2012_val_00030507.JPEG +ILSVRC2012_val_00014294.JPEG +ILSVRC2012_val_00002661.JPEG +ILSVRC2012_val_00013706.JPEG +ILSVRC2012_val_00045757.JPEG +ILSVRC2012_val_00044579.JPEG +ILSVRC2012_val_00014794.JPEG +ILSVRC2012_val_00032113.JPEG +ILSVRC2012_val_00020273.JPEG +ILSVRC2012_val_00019194.JPEG +ILSVRC2012_val_00006226.JPEG +ILSVRC2012_val_00023819.JPEG +ILSVRC2012_val_00009454.JPEG +ILSVRC2012_val_00001694.JPEG +ILSVRC2012_val_00049875.JPEG +ILSVRC2012_val_00037422.JPEG +ILSVRC2012_val_00027200.JPEG +ILSVRC2012_val_00016601.JPEG +ILSVRC2012_val_00018466.JPEG +ILSVRC2012_val_00026461.JPEG +ILSVRC2012_val_00034253.JPEG +ILSVRC2012_val_00039198.JPEG +ILSVRC2012_val_00037684.JPEG +ILSVRC2012_val_00029359.JPEG +ILSVRC2012_val_00039126.JPEG +ILSVRC2012_val_00035121.JPEG +ILSVRC2012_val_00015535.JPEG +ILSVRC2012_val_00043716.JPEG +ILSVRC2012_val_00016745.JPEG +ILSVRC2012_val_00024215.JPEG +ILSVRC2012_val_00041629.JPEG +ILSVRC2012_val_00036155.JPEG +ILSVRC2012_val_00029893.JPEG +ILSVRC2012_val_00042097.JPEG +ILSVRC2012_val_00048932.JPEG +ILSVRC2012_val_00001627.JPEG +ILSVRC2012_val_00014189.JPEG +ILSVRC2012_val_00013213.JPEG +ILSVRC2012_val_00006850.JPEG +ILSVRC2012_val_00048203.JPEG +ILSVRC2012_val_00022452.JPEG +ILSVRC2012_val_00003160.JPEG +ILSVRC2012_val_00039910.JPEG +ILSVRC2012_val_00029912.JPEG +ILSVRC2012_val_00038263.JPEG +ILSVRC2012_val_00013167.JPEG +ILSVRC2012_val_00016072.JPEG +ILSVRC2012_val_00040084.JPEG +ILSVRC2012_val_00020101.JPEG +ILSVRC2012_val_00033436.JPEG +ILSVRC2012_val_00017349.JPEG +ILSVRC2012_val_00000127.JPEG +ILSVRC2012_val_00044962.JPEG +ILSVRC2012_val_00029389.JPEG +ILSVRC2012_val_00028303.JPEG +ILSVRC2012_val_00044342.JPEG +ILSVRC2012_val_00010850.JPEG +ILSVRC2012_val_00048519.JPEG 
+ILSVRC2012_val_00036144.JPEG +ILSVRC2012_val_00002877.JPEG +ILSVRC2012_val_00019010.JPEG +ILSVRC2012_val_00012936.JPEG +ILSVRC2012_val_00034484.JPEG +ILSVRC2012_val_00017857.JPEG +ILSVRC2012_val_00044344.JPEG +ILSVRC2012_val_00047646.JPEG +ILSVRC2012_val_00018187.JPEG +ILSVRC2012_val_00046922.JPEG +ILSVRC2012_val_00036215.JPEG +ILSVRC2012_val_00036531.JPEG +ILSVRC2012_val_00025011.JPEG +ILSVRC2012_val_00008329.JPEG +ILSVRC2012_val_00008603.JPEG +ILSVRC2012_val_00010589.JPEG +ILSVRC2012_val_00024808.JPEG +ILSVRC2012_val_00027010.JPEG +ILSVRC2012_val_00002155.JPEG +ILSVRC2012_val_00007965.JPEG +ILSVRC2012_val_00043215.JPEG +ILSVRC2012_val_00011850.JPEG +ILSVRC2012_val_00045827.JPEG +ILSVRC2012_val_00006399.JPEG +ILSVRC2012_val_00027151.JPEG +ILSVRC2012_val_00036553.JPEG +ILSVRC2012_val_00031971.JPEG +ILSVRC2012_val_00012419.JPEG +ILSVRC2012_val_00023320.JPEG +ILSVRC2012_val_00039879.JPEG +ILSVRC2012_val_00049440.JPEG +ILSVRC2012_val_00002208.JPEG +ILSVRC2012_val_00020194.JPEG +ILSVRC2012_val_00034750.JPEG +ILSVRC2012_val_00013873.JPEG +ILSVRC2012_val_00030380.JPEG +ILSVRC2012_val_00018324.JPEG +ILSVRC2012_val_00001547.JPEG +ILSVRC2012_val_00048402.JPEG +ILSVRC2012_val_00013462.JPEG +ILSVRC2012_val_00014753.JPEG +ILSVRC2012_val_00008464.JPEG +ILSVRC2012_val_00004005.JPEG +ILSVRC2012_val_00004210.JPEG +ILSVRC2012_val_00023149.JPEG +ILSVRC2012_val_00022880.JPEG +ILSVRC2012_val_00020548.JPEG +ILSVRC2012_val_00015765.JPEG +ILSVRC2012_val_00029325.JPEG +ILSVRC2012_val_00025625.JPEG +ILSVRC2012_val_00010698.JPEG +ILSVRC2012_val_00017142.JPEG +ILSVRC2012_val_00006021.JPEG +ILSVRC2012_val_00043655.JPEG +ILSVRC2012_val_00004718.JPEG +ILSVRC2012_val_00026994.JPEG +ILSVRC2012_val_00015108.JPEG +ILSVRC2012_val_00031857.JPEG +ILSVRC2012_val_00005342.JPEG +ILSVRC2012_val_00022196.JPEG +ILSVRC2012_val_00044817.JPEG +ILSVRC2012_val_00049354.JPEG +ILSVRC2012_val_00023664.JPEG +ILSVRC2012_val_00005434.JPEG +ILSVRC2012_val_00025434.JPEG +ILSVRC2012_val_00049450.JPEG 
+ILSVRC2012_val_00043589.JPEG +ILSVRC2012_val_00016161.JPEG +ILSVRC2012_val_00032953.JPEG +ILSVRC2012_val_00032295.JPEG +ILSVRC2012_val_00035905.JPEG +ILSVRC2012_val_00000432.JPEG +ILSVRC2012_val_00001400.JPEG +ILSVRC2012_val_00008612.JPEG +ILSVRC2012_val_00020404.JPEG +ILSVRC2012_val_00049372.JPEG +ILSVRC2012_val_00046608.JPEG +ILSVRC2012_val_00030469.JPEG +ILSVRC2012_val_00017257.JPEG +ILSVRC2012_val_00036589.JPEG +ILSVRC2012_val_00039988.JPEG +ILSVRC2012_val_00003127.JPEG +ILSVRC2012_val_00022513.JPEG +ILSVRC2012_val_00026121.JPEG +ILSVRC2012_val_00041435.JPEG +ILSVRC2012_val_00016540.JPEG +ILSVRC2012_val_00039264.JPEG +ILSVRC2012_val_00023257.JPEG +ILSVRC2012_val_00003296.JPEG +ILSVRC2012_val_00039274.JPEG +ILSVRC2012_val_00013522.JPEG +ILSVRC2012_val_00000142.JPEG +ILSVRC2012_val_00042563.JPEG +ILSVRC2012_val_00031699.JPEG +ILSVRC2012_val_00024106.JPEG +ILSVRC2012_val_00043605.JPEG +ILSVRC2012_val_00015044.JPEG +ILSVRC2012_val_00035566.JPEG +ILSVRC2012_val_00010551.JPEG +ILSVRC2012_val_00009509.JPEG +ILSVRC2012_val_00044025.JPEG +ILSVRC2012_val_00041457.JPEG +ILSVRC2012_val_00011337.JPEG +ILSVRC2012_val_00018817.JPEG +ILSVRC2012_val_00001514.JPEG +ILSVRC2012_val_00009207.JPEG +ILSVRC2012_val_00034862.JPEG +ILSVRC2012_val_00024506.JPEG +ILSVRC2012_val_00009623.JPEG +ILSVRC2012_val_00042796.JPEG +ILSVRC2012_val_00008336.JPEG +ILSVRC2012_val_00002768.JPEG +ILSVRC2012_val_00006342.JPEG +ILSVRC2012_val_00037332.JPEG +ILSVRC2012_val_00000953.JPEG +ILSVRC2012_val_00010010.JPEG +ILSVRC2012_val_00027278.JPEG +ILSVRC2012_val_00039665.JPEG +ILSVRC2012_val_00016192.JPEG +ILSVRC2012_val_00016357.JPEG +ILSVRC2012_val_00039631.JPEG +ILSVRC2012_val_00016212.JPEG +ILSVRC2012_val_00022097.JPEG +ILSVRC2012_val_00043217.JPEG +ILSVRC2012_val_00017181.JPEG +ILSVRC2012_val_00042165.JPEG +ILSVRC2012_val_00003311.JPEG +ILSVRC2012_val_00024256.JPEG +ILSVRC2012_val_00012751.JPEG +ILSVRC2012_val_00041663.JPEG +ILSVRC2012_val_00025404.JPEG +ILSVRC2012_val_00004150.JPEG 
+ILSVRC2012_val_00000264.JPEG +ILSVRC2012_val_00039373.JPEG +ILSVRC2012_val_00041121.JPEG +ILSVRC2012_val_00045936.JPEG +ILSVRC2012_val_00027325.JPEG +ILSVRC2012_val_00029260.JPEG +ILSVRC2012_val_00026313.JPEG +ILSVRC2012_val_00012699.JPEG +ILSVRC2012_val_00034694.JPEG +ILSVRC2012_val_00005143.JPEG +ILSVRC2012_val_00006307.JPEG +ILSVRC2012_val_00046872.JPEG +ILSVRC2012_val_00013641.JPEG +ILSVRC2012_val_00032044.JPEG +ILSVRC2012_val_00048346.JPEG +ILSVRC2012_val_00012222.JPEG +ILSVRC2012_val_00018795.JPEG +ILSVRC2012_val_00032371.JPEG +ILSVRC2012_val_00020208.JPEG +ILSVRC2012_val_00044014.JPEG +ILSVRC2012_val_00008482.JPEG +ILSVRC2012_val_00041116.JPEG +ILSVRC2012_val_00016812.JPEG +ILSVRC2012_val_00048799.JPEG +ILSVRC2012_val_00017315.JPEG +ILSVRC2012_val_00030619.JPEG +ILSVRC2012_val_00002801.JPEG +ILSVRC2012_val_00047117.JPEG +ILSVRC2012_val_00009314.JPEG +ILSVRC2012_val_00024301.JPEG +ILSVRC2012_val_00038550.JPEG +ILSVRC2012_val_00019064.JPEG +ILSVRC2012_val_00026907.JPEG +ILSVRC2012_val_00045064.JPEG +ILSVRC2012_val_00009298.JPEG +ILSVRC2012_val_00015649.JPEG +ILSVRC2012_val_00024736.JPEG +ILSVRC2012_val_00003846.JPEG +ILSVRC2012_val_00023875.JPEG +ILSVRC2012_val_00020888.JPEG +ILSVRC2012_val_00043434.JPEG +ILSVRC2012_val_00031492.JPEG +ILSVRC2012_val_00037451.JPEG +ILSVRC2012_val_00040197.JPEG +ILSVRC2012_val_00039620.JPEG +ILSVRC2012_val_00027406.JPEG +ILSVRC2012_val_00007624.JPEG +ILSVRC2012_val_00017394.JPEG +ILSVRC2012_val_00028734.JPEG +ILSVRC2012_val_00022831.JPEG +ILSVRC2012_val_00037785.JPEG +ILSVRC2012_val_00028788.JPEG +ILSVRC2012_val_00020536.JPEG +ILSVRC2012_val_00020850.JPEG +ILSVRC2012_val_00048599.JPEG +ILSVRC2012_val_00045907.JPEG +ILSVRC2012_val_00044768.JPEG +ILSVRC2012_val_00009949.JPEG +ILSVRC2012_val_00008787.JPEG +ILSVRC2012_val_00018258.JPEG +ILSVRC2012_val_00011663.JPEG +ILSVRC2012_val_00019565.JPEG +ILSVRC2012_val_00047562.JPEG +ILSVRC2012_val_00034209.JPEG +ILSVRC2012_val_00027263.JPEG +ILSVRC2012_val_00033845.JPEG 
+ILSVRC2012_val_00048850.JPEG +ILSVRC2012_val_00042094.JPEG +ILSVRC2012_val_00011818.JPEG +ILSVRC2012_val_00042615.JPEG +ILSVRC2012_val_00014744.JPEG +ILSVRC2012_val_00003113.JPEG +ILSVRC2012_val_00024137.JPEG +ILSVRC2012_val_00029187.JPEG +ILSVRC2012_val_00015432.JPEG +ILSVRC2012_val_00008661.JPEG +ILSVRC2012_val_00047873.JPEG +ILSVRC2012_val_00027153.JPEG +ILSVRC2012_val_00028475.JPEG +ILSVRC2012_val_00016343.JPEG +ILSVRC2012_val_00043242.JPEG +ILSVRC2012_val_00002889.JPEG +ILSVRC2012_val_00007235.JPEG +ILSVRC2012_val_00028480.JPEG +ILSVRC2012_val_00021361.JPEG +ILSVRC2012_val_00010173.JPEG +ILSVRC2012_val_00036213.JPEG +ILSVRC2012_val_00008468.JPEG +ILSVRC2012_val_00001095.JPEG +ILSVRC2012_val_00029547.JPEG +ILSVRC2012_val_00018840.JPEG +ILSVRC2012_val_00044828.JPEG +ILSVRC2012_val_00019206.JPEG +ILSVRC2012_val_00025868.JPEG +ILSVRC2012_val_00006190.JPEG +ILSVRC2012_val_00031178.JPEG +ILSVRC2012_val_00033301.JPEG +ILSVRC2012_val_00037678.JPEG +ILSVRC2012_val_00015072.JPEG +ILSVRC2012_val_00034668.JPEG +ILSVRC2012_val_00036975.JPEG +ILSVRC2012_val_00019641.JPEG +ILSVRC2012_val_00034269.JPEG +ILSVRC2012_val_00015374.JPEG +ILSVRC2012_val_00006112.JPEG +ILSVRC2012_val_00025769.JPEG +ILSVRC2012_val_00017523.JPEG +ILSVRC2012_val_00010676.JPEG +ILSVRC2012_val_00010986.JPEG +ILSVRC2012_val_00029917.JPEG +ILSVRC2012_val_00010552.JPEG +ILSVRC2012_val_00003502.JPEG +ILSVRC2012_val_00001983.JPEG +ILSVRC2012_val_00047022.JPEG +ILSVRC2012_val_00028227.JPEG +ILSVRC2012_val_00031145.JPEG +ILSVRC2012_val_00002522.JPEG +ILSVRC2012_val_00021449.JPEG +ILSVRC2012_val_00009153.JPEG +ILSVRC2012_val_00004889.JPEG +ILSVRC2012_val_00035894.JPEG +ILSVRC2012_val_00014470.JPEG \ No newline at end of file diff --git a/big_vision/evaluators/proj/uvim/coltran_fid_data/reference_file_names.txt b/big_vision/evaluators/proj/uvim/coltran_fid_data/reference_file_names.txt new file mode 100644 index 0000000000000000000000000000000000000000..8c4faa20f48bfb29e6f8c4e4d70823795c846672 --- /dev/null +++ 
b/big_vision/evaluators/proj/uvim/coltran_fid_data/reference_file_names.txt @@ -0,0 +1,5000 @@ +ILSVRC2012_val_00013328.JPEG +ILSVRC2012_val_00004694.JPEG +ILSVRC2012_val_00003310.JPEG +ILSVRC2012_val_00013349.JPEG +ILSVRC2012_val_00020452.JPEG +ILSVRC2012_val_00015433.JPEG +ILSVRC2012_val_00039333.JPEG +ILSVRC2012_val_00035263.JPEG +ILSVRC2012_val_00040982.JPEG +ILSVRC2012_val_00035719.JPEG +ILSVRC2012_val_00022596.JPEG +ILSVRC2012_val_00033811.JPEG +ILSVRC2012_val_00012865.JPEG +ILSVRC2012_val_00003372.JPEG +ILSVRC2012_val_00015268.JPEG +ILSVRC2012_val_00030296.JPEG +ILSVRC2012_val_00000872.JPEG +ILSVRC2012_val_00005382.JPEG +ILSVRC2012_val_00027656.JPEG +ILSVRC2012_val_00019275.JPEG +ILSVRC2012_val_00007499.JPEG +ILSVRC2012_val_00000185.JPEG +ILSVRC2012_val_00012704.JPEG +ILSVRC2012_val_00016046.JPEG +ILSVRC2012_val_00014027.JPEG +ILSVRC2012_val_00030748.JPEG +ILSVRC2012_val_00001459.JPEG +ILSVRC2012_val_00021147.JPEG +ILSVRC2012_val_00046948.JPEG +ILSVRC2012_val_00028531.JPEG +ILSVRC2012_val_00004539.JPEG +ILSVRC2012_val_00007052.JPEG +ILSVRC2012_val_00031553.JPEG +ILSVRC2012_val_00000052.JPEG +ILSVRC2012_val_00010415.JPEG +ILSVRC2012_val_00012947.JPEG +ILSVRC2012_val_00037959.JPEG +ILSVRC2012_val_00032520.JPEG +ILSVRC2012_val_00019509.JPEG +ILSVRC2012_val_00008496.JPEG +ILSVRC2012_val_00041804.JPEG +ILSVRC2012_val_00007337.JPEG +ILSVRC2012_val_00022134.JPEG +ILSVRC2012_val_00021473.JPEG +ILSVRC2012_val_00042479.JPEG +ILSVRC2012_val_00016086.JPEG +ILSVRC2012_val_00012076.JPEG +ILSVRC2012_val_00043842.JPEG +ILSVRC2012_val_00022098.JPEG +ILSVRC2012_val_00008187.JPEG +ILSVRC2012_val_00001139.JPEG +ILSVRC2012_val_00044871.JPEG +ILSVRC2012_val_00046803.JPEG +ILSVRC2012_val_00033729.JPEG +ILSVRC2012_val_00049602.JPEG +ILSVRC2012_val_00005731.JPEG +ILSVRC2012_val_00010441.JPEG +ILSVRC2012_val_00044442.JPEG +ILSVRC2012_val_00005255.JPEG +ILSVRC2012_val_00047726.JPEG +ILSVRC2012_val_00016222.JPEG +ILSVRC2012_val_00004195.JPEG +ILSVRC2012_val_00013645.JPEG 
+ILSVRC2012_val_00006255.JPEG +ILSVRC2012_val_00029684.JPEG +ILSVRC2012_val_00031407.JPEG +ILSVRC2012_val_00012361.JPEG +ILSVRC2012_val_00045108.JPEG +ILSVRC2012_val_00049224.JPEG +ILSVRC2012_val_00018743.JPEG +ILSVRC2012_val_00013493.JPEG +ILSVRC2012_val_00036398.JPEG +ILSVRC2012_val_00006869.JPEG +ILSVRC2012_val_00020791.JPEG +ILSVRC2012_val_00031228.JPEG +ILSVRC2012_val_00033066.JPEG +ILSVRC2012_val_00019237.JPEG +ILSVRC2012_val_00028829.JPEG +ILSVRC2012_val_00020190.JPEG +ILSVRC2012_val_00039176.JPEG +ILSVRC2012_val_00029959.JPEG +ILSVRC2012_val_00011778.JPEG +ILSVRC2012_val_00036200.JPEG +ILSVRC2012_val_00016974.JPEG +ILSVRC2012_val_00036309.JPEG +ILSVRC2012_val_00045377.JPEG +ILSVRC2012_val_00032040.JPEG +ILSVRC2012_val_00014636.JPEG +ILSVRC2012_val_00005688.JPEG +ILSVRC2012_val_00043394.JPEG +ILSVRC2012_val_00045810.JPEG +ILSVRC2012_val_00014678.JPEG +ILSVRC2012_val_00031641.JPEG +ILSVRC2012_val_00023449.JPEG +ILSVRC2012_val_00040348.JPEG +ILSVRC2012_val_00040341.JPEG +ILSVRC2012_val_00033165.JPEG +ILSVRC2012_val_00040844.JPEG +ILSVRC2012_val_00020275.JPEG +ILSVRC2012_val_00036696.JPEG +ILSVRC2012_val_00006550.JPEG +ILSVRC2012_val_00028867.JPEG +ILSVRC2012_val_00012770.JPEG +ILSVRC2012_val_00014127.JPEG +ILSVRC2012_val_00016361.JPEG +ILSVRC2012_val_00009377.JPEG +ILSVRC2012_val_00009420.JPEG +ILSVRC2012_val_00002079.JPEG +ILSVRC2012_val_00048545.JPEG +ILSVRC2012_val_00000254.JPEG +ILSVRC2012_val_00038946.JPEG +ILSVRC2012_val_00048195.JPEG +ILSVRC2012_val_00021193.JPEG +ILSVRC2012_val_00019822.JPEG +ILSVRC2012_val_00010096.JPEG +ILSVRC2012_val_00005640.JPEG +ILSVRC2012_val_00001719.JPEG +ILSVRC2012_val_00001133.JPEG +ILSVRC2012_val_00020074.JPEG +ILSVRC2012_val_00001860.JPEG +ILSVRC2012_val_00024352.JPEG +ILSVRC2012_val_00001885.JPEG +ILSVRC2012_val_00039482.JPEG +ILSVRC2012_val_00012018.JPEG +ILSVRC2012_val_00025839.JPEG +ILSVRC2012_val_00031129.JPEG +ILSVRC2012_val_00012574.JPEG +ILSVRC2012_val_00039938.JPEG +ILSVRC2012_val_00045970.JPEG 
+ILSVRC2012_val_00047004.JPEG +ILSVRC2012_val_00048687.JPEG +ILSVRC2012_val_00014727.JPEG +ILSVRC2012_val_00032578.JPEG +ILSVRC2012_val_00013640.JPEG +ILSVRC2012_val_00042617.JPEG +ILSVRC2012_val_00036829.JPEG +ILSVRC2012_val_00000810.JPEG +ILSVRC2012_val_00002450.JPEG +ILSVRC2012_val_00028800.JPEG +ILSVRC2012_val_00000146.JPEG +ILSVRC2012_val_00043510.JPEG +ILSVRC2012_val_00037849.JPEG +ILSVRC2012_val_00040470.JPEG +ILSVRC2012_val_00013798.JPEG +ILSVRC2012_val_00013390.JPEG +ILSVRC2012_val_00005308.JPEG +ILSVRC2012_val_00013492.JPEG +ILSVRC2012_val_00031463.JPEG +ILSVRC2012_val_00002149.JPEG +ILSVRC2012_val_00002249.JPEG +ILSVRC2012_val_00036568.JPEG +ILSVRC2012_val_00009829.JPEG +ILSVRC2012_val_00008771.JPEG +ILSVRC2012_val_00049780.JPEG +ILSVRC2012_val_00043443.JPEG +ILSVRC2012_val_00048502.JPEG +ILSVRC2012_val_00048709.JPEG +ILSVRC2012_val_00041308.JPEG +ILSVRC2012_val_00020565.JPEG +ILSVRC2012_val_00005961.JPEG +ILSVRC2012_val_00038221.JPEG +ILSVRC2012_val_00036623.JPEG +ILSVRC2012_val_00004032.JPEG +ILSVRC2012_val_00007345.JPEG +ILSVRC2012_val_00045071.JPEG +ILSVRC2012_val_00015176.JPEG +ILSVRC2012_val_00039111.JPEG +ILSVRC2012_val_00044927.JPEG +ILSVRC2012_val_00046333.JPEG +ILSVRC2012_val_00032433.JPEG +ILSVRC2012_val_00001469.JPEG +ILSVRC2012_val_00048878.JPEG +ILSVRC2012_val_00012356.JPEG +ILSVRC2012_val_00016038.JPEG +ILSVRC2012_val_00013199.JPEG +ILSVRC2012_val_00022162.JPEG +ILSVRC2012_val_00011582.JPEG +ILSVRC2012_val_00031488.JPEG +ILSVRC2012_val_00028854.JPEG +ILSVRC2012_val_00016352.JPEG +ILSVRC2012_val_00038173.JPEG +ILSVRC2012_val_00036319.JPEG +ILSVRC2012_val_00037175.JPEG +ILSVRC2012_val_00037432.JPEG +ILSVRC2012_val_00047920.JPEG +ILSVRC2012_val_00025699.JPEG +ILSVRC2012_val_00016401.JPEG +ILSVRC2012_val_00024466.JPEG +ILSVRC2012_val_00018903.JPEG +ILSVRC2012_val_00020774.JPEG +ILSVRC2012_val_00028132.JPEG +ILSVRC2012_val_00023014.JPEG +ILSVRC2012_val_00018632.JPEG +ILSVRC2012_val_00037601.JPEG +ILSVRC2012_val_00007931.JPEG 
+ILSVRC2012_val_00049718.JPEG +ILSVRC2012_val_00007455.JPEG +ILSVRC2012_val_00040279.JPEG +ILSVRC2012_val_00046160.JPEG +ILSVRC2012_val_00036153.JPEG +ILSVRC2012_val_00015636.JPEG +ILSVRC2012_val_00019977.JPEG +ILSVRC2012_val_00013131.JPEG +ILSVRC2012_val_00008197.JPEG +ILSVRC2012_val_00020310.JPEG +ILSVRC2012_val_00013001.JPEG +ILSVRC2012_val_00030418.JPEG +ILSVRC2012_val_00039147.JPEG +ILSVRC2012_val_00009474.JPEG +ILSVRC2012_val_00047626.JPEG +ILSVRC2012_val_00018852.JPEG +ILSVRC2012_val_00032244.JPEG +ILSVRC2012_val_00026260.JPEG +ILSVRC2012_val_00026168.JPEG +ILSVRC2012_val_00040284.JPEG +ILSVRC2012_val_00020104.JPEG +ILSVRC2012_val_00010432.JPEG +ILSVRC2012_val_00030982.JPEG +ILSVRC2012_val_00047787.JPEG +ILSVRC2012_val_00027703.JPEG +ILSVRC2012_val_00011851.JPEG +ILSVRC2012_val_00047036.JPEG +ILSVRC2012_val_00034914.JPEG +ILSVRC2012_val_00027009.JPEG +ILSVRC2012_val_00046641.JPEG +ILSVRC2012_val_00049817.JPEG +ILSVRC2012_val_00011935.JPEG +ILSVRC2012_val_00028469.JPEG +ILSVRC2012_val_00009408.JPEG +ILSVRC2012_val_00013004.JPEG +ILSVRC2012_val_00024904.JPEG +ILSVRC2012_val_00049922.JPEG +ILSVRC2012_val_00019697.JPEG +ILSVRC2012_val_00035775.JPEG +ILSVRC2012_val_00031556.JPEG +ILSVRC2012_val_00045696.JPEG +ILSVRC2012_val_00045262.JPEG +ILSVRC2012_val_00031628.JPEG +ILSVRC2012_val_00014104.JPEG +ILSVRC2012_val_00015490.JPEG +ILSVRC2012_val_00033609.JPEG +ILSVRC2012_val_00004806.JPEG +ILSVRC2012_val_00023445.JPEG +ILSVRC2012_val_00023326.JPEG +ILSVRC2012_val_00023901.JPEG +ILSVRC2012_val_00000415.JPEG +ILSVRC2012_val_00031661.JPEG +ILSVRC2012_val_00048406.JPEG +ILSVRC2012_val_00001601.JPEG +ILSVRC2012_val_00038759.JPEG +ILSVRC2012_val_00032195.JPEG +ILSVRC2012_val_00026862.JPEG +ILSVRC2012_val_00032633.JPEG +ILSVRC2012_val_00011208.JPEG +ILSVRC2012_val_00034221.JPEG +ILSVRC2012_val_00017102.JPEG +ILSVRC2012_val_00021565.JPEG +ILSVRC2012_val_00003858.JPEG +ILSVRC2012_val_00015391.JPEG +ILSVRC2012_val_00030691.JPEG +ILSVRC2012_val_00024900.JPEG 
+ILSVRC2012_val_00023382.JPEG +ILSVRC2012_val_00015981.JPEG +ILSVRC2012_val_00011930.JPEG +ILSVRC2012_val_00013544.JPEG +ILSVRC2012_val_00012859.JPEG +ILSVRC2012_val_00030222.JPEG +ILSVRC2012_val_00018005.JPEG +ILSVRC2012_val_00017577.JPEG +ILSVRC2012_val_00017990.JPEG +ILSVRC2012_val_00017411.JPEG +ILSVRC2012_val_00026116.JPEG +ILSVRC2012_val_00004732.JPEG +ILSVRC2012_val_00034439.JPEG +ILSVRC2012_val_00029943.JPEG +ILSVRC2012_val_00021347.JPEG +ILSVRC2012_val_00036789.JPEG +ILSVRC2012_val_00037660.JPEG +ILSVRC2012_val_00003619.JPEG +ILSVRC2012_val_00006007.JPEG +ILSVRC2012_val_00029155.JPEG +ILSVRC2012_val_00018470.JPEG +ILSVRC2012_val_00019468.JPEG +ILSVRC2012_val_00040620.JPEG +ILSVRC2012_val_00037112.JPEG +ILSVRC2012_val_00035282.JPEG +ILSVRC2012_val_00043815.JPEG +ILSVRC2012_val_00049331.JPEG +ILSVRC2012_val_00004630.JPEG +ILSVRC2012_val_00011301.JPEG +ILSVRC2012_val_00005362.JPEG +ILSVRC2012_val_00027871.JPEG +ILSVRC2012_val_00014763.JPEG +ILSVRC2012_val_00020871.JPEG +ILSVRC2012_val_00006257.JPEG +ILSVRC2012_val_00018205.JPEG +ILSVRC2012_val_00049529.JPEG +ILSVRC2012_val_00046735.JPEG +ILSVRC2012_val_00013541.JPEG +ILSVRC2012_val_00030979.JPEG +ILSVRC2012_val_00011006.JPEG +ILSVRC2012_val_00004503.JPEG +ILSVRC2012_val_00003820.JPEG +ILSVRC2012_val_00018799.JPEG +ILSVRC2012_val_00025967.JPEG +ILSVRC2012_val_00021991.JPEG +ILSVRC2012_val_00016479.JPEG +ILSVRC2012_val_00048801.JPEG +ILSVRC2012_val_00026953.JPEG +ILSVRC2012_val_00043834.JPEG +ILSVRC2012_val_00037570.JPEG +ILSVRC2012_val_00006073.JPEG +ILSVRC2012_val_00043273.JPEG +ILSVRC2012_val_00019660.JPEG +ILSVRC2012_val_00009255.JPEG +ILSVRC2012_val_00038850.JPEG +ILSVRC2012_val_00025518.JPEG +ILSVRC2012_val_00000141.JPEG +ILSVRC2012_val_00008068.JPEG +ILSVRC2012_val_00043958.JPEG +ILSVRC2012_val_00008819.JPEG +ILSVRC2012_val_00049792.JPEG +ILSVRC2012_val_00043262.JPEG +ILSVRC2012_val_00002568.JPEG +ILSVRC2012_val_00009693.JPEG +ILSVRC2012_val_00035164.JPEG +ILSVRC2012_val_00018668.JPEG 
+ILSVRC2012_val_00008893.JPEG +ILSVRC2012_val_00044853.JPEG +ILSVRC2012_val_00017452.JPEG +ILSVRC2012_val_00026717.JPEG +ILSVRC2012_val_00022723.JPEG +ILSVRC2012_val_00037281.JPEG +ILSVRC2012_val_00001781.JPEG +ILSVRC2012_val_00041738.JPEG +ILSVRC2012_val_00031383.JPEG +ILSVRC2012_val_00021551.JPEG +ILSVRC2012_val_00037482.JPEG +ILSVRC2012_val_00013219.JPEG +ILSVRC2012_val_00026667.JPEG +ILSVRC2012_val_00004125.JPEG +ILSVRC2012_val_00013258.JPEG +ILSVRC2012_val_00017875.JPEG +ILSVRC2012_val_00018075.JPEG +ILSVRC2012_val_00039140.JPEG +ILSVRC2012_val_00038464.JPEG +ILSVRC2012_val_00016215.JPEG +ILSVRC2012_val_00009944.JPEG +ILSVRC2012_val_00000652.JPEG +ILSVRC2012_val_00042080.JPEG +ILSVRC2012_val_00028831.JPEG +ILSVRC2012_val_00016598.JPEG +ILSVRC2012_val_00030277.JPEG +ILSVRC2012_val_00013475.JPEG +ILSVRC2012_val_00048101.JPEG +ILSVRC2012_val_00048907.JPEG +ILSVRC2012_val_00023981.JPEG +ILSVRC2012_val_00042433.JPEG +ILSVRC2012_val_00019444.JPEG +ILSVRC2012_val_00027383.JPEG +ILSVRC2012_val_00031473.JPEG +ILSVRC2012_val_00040874.JPEG +ILSVRC2012_val_00026102.JPEG +ILSVRC2012_val_00007559.JPEG +ILSVRC2012_val_00030454.JPEG +ILSVRC2012_val_00005157.JPEG +ILSVRC2012_val_00006143.JPEG +ILSVRC2012_val_00043168.JPEG +ILSVRC2012_val_00039060.JPEG +ILSVRC2012_val_00010301.JPEG +ILSVRC2012_val_00043584.JPEG +ILSVRC2012_val_00011728.JPEG +ILSVRC2012_val_00034195.JPEG +ILSVRC2012_val_00026980.JPEG +ILSVRC2012_val_00021896.JPEG +ILSVRC2012_val_00038169.JPEG +ILSVRC2012_val_00022481.JPEG +ILSVRC2012_val_00046712.JPEG +ILSVRC2012_val_00017338.JPEG +ILSVRC2012_val_00002958.JPEG +ILSVRC2012_val_00005191.JPEG +ILSVRC2012_val_00012353.JPEG +ILSVRC2012_val_00010460.JPEG +ILSVRC2012_val_00010603.JPEG +ILSVRC2012_val_00010746.JPEG +ILSVRC2012_val_00025094.JPEG +ILSVRC2012_val_00036942.JPEG +ILSVRC2012_val_00022762.JPEG +ILSVRC2012_val_00000197.JPEG +ILSVRC2012_val_00043579.JPEG +ILSVRC2012_val_00042051.JPEG +ILSVRC2012_val_00029050.JPEG +ILSVRC2012_val_00047557.JPEG 
+ILSVRC2012_val_00025645.JPEG +ILSVRC2012_val_00040607.JPEG +ILSVRC2012_val_00046645.JPEG +ILSVRC2012_val_00012109.JPEG +ILSVRC2012_val_00000096.JPEG +ILSVRC2012_val_00049252.JPEG +ILSVRC2012_val_00030869.JPEG +ILSVRC2012_val_00004865.JPEG +ILSVRC2012_val_00031418.JPEG +ILSVRC2012_val_00005550.JPEG +ILSVRC2012_val_00038057.JPEG +ILSVRC2012_val_00029723.JPEG +ILSVRC2012_val_00042153.JPEG +ILSVRC2012_val_00020339.JPEG +ILSVRC2012_val_00024141.JPEG +ILSVRC2012_val_00036085.JPEG +ILSVRC2012_val_00043160.JPEG +ILSVRC2012_val_00012720.JPEG +ILSVRC2012_val_00008933.JPEG +ILSVRC2012_val_00006335.JPEG +ILSVRC2012_val_00021204.JPEG +ILSVRC2012_val_00046353.JPEG +ILSVRC2012_val_00003776.JPEG +ILSVRC2012_val_00014170.JPEG +ILSVRC2012_val_00044746.JPEG +ILSVRC2012_val_00003216.JPEG +ILSVRC2012_val_00049777.JPEG +ILSVRC2012_val_00001540.JPEG +ILSVRC2012_val_00022168.JPEG +ILSVRC2012_val_00048843.JPEG +ILSVRC2012_val_00033131.JPEG +ILSVRC2012_val_00011687.JPEG +ILSVRC2012_val_00019213.JPEG +ILSVRC2012_val_00024975.JPEG +ILSVRC2012_val_00044339.JPEG +ILSVRC2012_val_00005806.JPEG +ILSVRC2012_val_00023026.JPEG +ILSVRC2012_val_00030853.JPEG +ILSVRC2012_val_00023861.JPEG +ILSVRC2012_val_00046455.JPEG +ILSVRC2012_val_00044393.JPEG +ILSVRC2012_val_00017040.JPEG +ILSVRC2012_val_00048292.JPEG +ILSVRC2012_val_00043459.JPEG +ILSVRC2012_val_00026858.JPEG +ILSVRC2012_val_00007285.JPEG +ILSVRC2012_val_00001862.JPEG +ILSVRC2012_val_00042415.JPEG +ILSVRC2012_val_00003452.JPEG +ILSVRC2012_val_00029466.JPEG +ILSVRC2012_val_00003005.JPEG +ILSVRC2012_val_00009494.JPEG +ILSVRC2012_val_00044722.JPEG +ILSVRC2012_val_00042329.JPEG +ILSVRC2012_val_00030807.JPEG +ILSVRC2012_val_00033536.JPEG +ILSVRC2012_val_00044872.JPEG +ILSVRC2012_val_00014964.JPEG +ILSVRC2012_val_00014256.JPEG +ILSVRC2012_val_00004170.JPEG +ILSVRC2012_val_00025334.JPEG +ILSVRC2012_val_00003406.JPEG +ILSVRC2012_val_00034174.JPEG +ILSVRC2012_val_00008848.JPEG +ILSVRC2012_val_00016537.JPEG +ILSVRC2012_val_00000033.JPEG 
+ILSVRC2012_val_00023443.JPEG +ILSVRC2012_val_00034067.JPEG +ILSVRC2012_val_00037479.JPEG +ILSVRC2012_val_00032622.JPEG +ILSVRC2012_val_00021389.JPEG +ILSVRC2012_val_00015747.JPEG +ILSVRC2012_val_00022116.JPEG +ILSVRC2012_val_00009714.JPEG +ILSVRC2012_val_00032229.JPEG +ILSVRC2012_val_00015642.JPEG +ILSVRC2012_val_00037000.JPEG +ILSVRC2012_val_00007695.JPEG +ILSVRC2012_val_00015392.JPEG +ILSVRC2012_val_00020355.JPEG +ILSVRC2012_val_00027357.JPEG +ILSVRC2012_val_00009705.JPEG +ILSVRC2012_val_00023495.JPEG +ILSVRC2012_val_00033003.JPEG +ILSVRC2012_val_00010103.JPEG +ILSVRC2012_val_00029847.JPEG +ILSVRC2012_val_00037718.JPEG +ILSVRC2012_val_00029644.JPEG +ILSVRC2012_val_00046509.JPEG +ILSVRC2012_val_00017656.JPEG +ILSVRC2012_val_00045212.JPEG +ILSVRC2012_val_00032062.JPEG +ILSVRC2012_val_00010095.JPEG +ILSVRC2012_val_00002946.JPEG +ILSVRC2012_val_00008392.JPEG +ILSVRC2012_val_00035667.JPEG +ILSVRC2012_val_00022626.JPEG +ILSVRC2012_val_00026119.JPEG +ILSVRC2012_val_00027545.JPEG +ILSVRC2012_val_00035934.JPEG +ILSVRC2012_val_00018136.JPEG +ILSVRC2012_val_00042711.JPEG +ILSVRC2012_val_00007536.JPEG +ILSVRC2012_val_00038740.JPEG +ILSVRC2012_val_00031926.JPEG +ILSVRC2012_val_00029345.JPEG +ILSVRC2012_val_00027520.JPEG +ILSVRC2012_val_00040858.JPEG +ILSVRC2012_val_00023646.JPEG +ILSVRC2012_val_00032650.JPEG +ILSVRC2012_val_00009026.JPEG +ILSVRC2012_val_00003345.JPEG +ILSVRC2012_val_00008413.JPEG +ILSVRC2012_val_00024318.JPEG +ILSVRC2012_val_00001210.JPEG +ILSVRC2012_val_00040288.JPEG +ILSVRC2012_val_00014168.JPEG +ILSVRC2012_val_00047428.JPEG +ILSVRC2012_val_00004307.JPEG +ILSVRC2012_val_00046472.JPEG +ILSVRC2012_val_00049379.JPEG +ILSVRC2012_val_00032817.JPEG +ILSVRC2012_val_00012418.JPEG +ILSVRC2012_val_00024498.JPEG +ILSVRC2012_val_00014487.JPEG +ILSVRC2012_val_00043557.JPEG +ILSVRC2012_val_00028422.JPEG +ILSVRC2012_val_00011171.JPEG +ILSVRC2012_val_00031391.JPEG +ILSVRC2012_val_00022619.JPEG +ILSVRC2012_val_00021014.JPEG +ILSVRC2012_val_00037745.JPEG 
+ILSVRC2012_val_00025754.JPEG +ILSVRC2012_val_00043945.JPEG +ILSVRC2012_val_00031022.JPEG +ILSVRC2012_val_00044597.JPEG +ILSVRC2012_val_00003095.JPEG +ILSVRC2012_val_00003261.JPEG +ILSVRC2012_val_00049784.JPEG +ILSVRC2012_val_00004474.JPEG +ILSVRC2012_val_00042535.JPEG +ILSVRC2012_val_00004488.JPEG +ILSVRC2012_val_00008735.JPEG +ILSVRC2012_val_00048223.JPEG +ILSVRC2012_val_00000061.JPEG +ILSVRC2012_val_00016239.JPEG +ILSVRC2012_val_00012542.JPEG +ILSVRC2012_val_00021794.JPEG +ILSVRC2012_val_00023532.JPEG +ILSVRC2012_val_00028411.JPEG +ILSVRC2012_val_00046427.JPEG +ILSVRC2012_val_00010531.JPEG +ILSVRC2012_val_00029585.JPEG +ILSVRC2012_val_00003187.JPEG +ILSVRC2012_val_00027146.JPEG +ILSVRC2012_val_00043587.JPEG +ILSVRC2012_val_00023413.JPEG +ILSVRC2012_val_00006372.JPEG +ILSVRC2012_val_00028036.JPEG +ILSVRC2012_val_00002455.JPEG +ILSVRC2012_val_00018932.JPEG +ILSVRC2012_val_00015674.JPEG +ILSVRC2012_val_00040049.JPEG +ILSVRC2012_val_00048880.JPEG +ILSVRC2012_val_00026134.JPEG +ILSVRC2012_val_00042736.JPEG +ILSVRC2012_val_00021531.JPEG +ILSVRC2012_val_00038649.JPEG +ILSVRC2012_val_00027518.JPEG +ILSVRC2012_val_00040016.JPEG +ILSVRC2012_val_00000964.JPEG +ILSVRC2012_val_00001245.JPEG +ILSVRC2012_val_00047893.JPEG +ILSVRC2012_val_00008989.JPEG +ILSVRC2012_val_00012004.JPEG +ILSVRC2012_val_00001246.JPEG +ILSVRC2012_val_00034390.JPEG +ILSVRC2012_val_00023032.JPEG +ILSVRC2012_val_00021748.JPEG +ILSVRC2012_val_00004393.JPEG +ILSVRC2012_val_00002282.JPEG +ILSVRC2012_val_00047645.JPEG +ILSVRC2012_val_00037661.JPEG +ILSVRC2012_val_00024855.JPEG +ILSVRC2012_val_00013482.JPEG +ILSVRC2012_val_00010475.JPEG +ILSVRC2012_val_00006027.JPEG +ILSVRC2012_val_00046760.JPEG +ILSVRC2012_val_00013471.JPEG +ILSVRC2012_val_00019314.JPEG +ILSVRC2012_val_00032652.JPEG +ILSVRC2012_val_00003065.JPEG +ILSVRC2012_val_00027455.JPEG +ILSVRC2012_val_00013957.JPEG +ILSVRC2012_val_00041348.JPEG +ILSVRC2012_val_00043270.JPEG +ILSVRC2012_val_00024909.JPEG +ILSVRC2012_val_00011724.JPEG 
+ILSVRC2012_val_00018445.JPEG +ILSVRC2012_val_00035671.JPEG +ILSVRC2012_val_00020578.JPEG +ILSVRC2012_val_00021418.JPEG +ILSVRC2012_val_00017242.JPEG +ILSVRC2012_val_00039311.JPEG +ILSVRC2012_val_00035391.JPEG +ILSVRC2012_val_00038461.JPEG +ILSVRC2012_val_00021274.JPEG +ILSVRC2012_val_00040670.JPEG +ILSVRC2012_val_00032153.JPEG +ILSVRC2012_val_00007142.JPEG +ILSVRC2012_val_00017765.JPEG +ILSVRC2012_val_00040587.JPEG +ILSVRC2012_val_00025493.JPEG +ILSVRC2012_val_00049322.JPEG +ILSVRC2012_val_00012376.JPEG +ILSVRC2012_val_00008728.JPEG +ILSVRC2012_val_00046573.JPEG +ILSVRC2012_val_00013374.JPEG +ILSVRC2012_val_00036914.JPEG +ILSVRC2012_val_00019963.JPEG +ILSVRC2012_val_00004382.JPEG +ILSVRC2012_val_00046751.JPEG +ILSVRC2012_val_00011422.JPEG +ILSVRC2012_val_00041235.JPEG +ILSVRC2012_val_00034671.JPEG +ILSVRC2012_val_00008471.JPEG +ILSVRC2012_val_00004615.JPEG +ILSVRC2012_val_00043965.JPEG +ILSVRC2012_val_00031171.JPEG +ILSVRC2012_val_00011906.JPEG +ILSVRC2012_val_00003002.JPEG +ILSVRC2012_val_00032735.JPEG +ILSVRC2012_val_00038157.JPEG +ILSVRC2012_val_00035637.JPEG +ILSVRC2012_val_00013439.JPEG +ILSVRC2012_val_00023431.JPEG +ILSVRC2012_val_00013782.JPEG +ILSVRC2012_val_00028347.JPEG +ILSVRC2012_val_00000106.JPEG +ILSVRC2012_val_00012958.JPEG +ILSVRC2012_val_00048730.JPEG +ILSVRC2012_val_00021293.JPEG +ILSVRC2012_val_00048357.JPEG +ILSVRC2012_val_00019361.JPEG +ILSVRC2012_val_00033830.JPEG +ILSVRC2012_val_00021712.JPEG +ILSVRC2012_val_00031799.JPEG +ILSVRC2012_val_00025750.JPEG +ILSVRC2012_val_00011113.JPEG +ILSVRC2012_val_00032075.JPEG +ILSVRC2012_val_00025937.JPEG +ILSVRC2012_val_00028029.JPEG +ILSVRC2012_val_00019636.JPEG +ILSVRC2012_val_00024501.JPEG +ILSVRC2012_val_00048757.JPEG +ILSVRC2012_val_00034592.JPEG +ILSVRC2012_val_00039260.JPEG +ILSVRC2012_val_00007394.JPEG +ILSVRC2012_val_00033093.JPEG +ILSVRC2012_val_00045565.JPEG +ILSVRC2012_val_00030855.JPEG +ILSVRC2012_val_00021626.JPEG +ILSVRC2012_val_00014051.JPEG +ILSVRC2012_val_00009183.JPEG 
+ILSVRC2012_val_00025021.JPEG +ILSVRC2012_val_00023096.JPEG +ILSVRC2012_val_00017681.JPEG +ILSVRC2012_val_00014345.JPEG +ILSVRC2012_val_00015115.JPEG +ILSVRC2012_val_00008333.JPEG +ILSVRC2012_val_00020503.JPEG +ILSVRC2012_val_00029377.JPEG +ILSVRC2012_val_00007577.JPEG +ILSVRC2012_val_00034982.JPEG +ILSVRC2012_val_00018824.JPEG +ILSVRC2012_val_00019263.JPEG +ILSVRC2012_val_00024669.JPEG +ILSVRC2012_val_00010241.JPEG +ILSVRC2012_val_00022155.JPEG +ILSVRC2012_val_00039279.JPEG +ILSVRC2012_val_00049327.JPEG +ILSVRC2012_val_00017593.JPEG +ILSVRC2012_val_00020975.JPEG +ILSVRC2012_val_00020843.JPEG +ILSVRC2012_val_00049234.JPEG +ILSVRC2012_val_00045672.JPEG +ILSVRC2012_val_00028180.JPEG +ILSVRC2012_val_00014089.JPEG +ILSVRC2012_val_00028845.JPEG +ILSVRC2012_val_00014178.JPEG +ILSVRC2012_val_00004291.JPEG +ILSVRC2012_val_00049144.JPEG +ILSVRC2012_val_00029182.JPEG +ILSVRC2012_val_00005314.JPEG +ILSVRC2012_val_00024361.JPEG +ILSVRC2012_val_00004773.JPEG +ILSVRC2012_val_00022108.JPEG +ILSVRC2012_val_00001718.JPEG +ILSVRC2012_val_00045977.JPEG +ILSVRC2012_val_00024617.JPEG +ILSVRC2012_val_00045051.JPEG +ILSVRC2012_val_00000957.JPEG +ILSVRC2012_val_00035873.JPEG +ILSVRC2012_val_00017502.JPEG +ILSVRC2012_val_00047733.JPEG +ILSVRC2012_val_00045112.JPEG +ILSVRC2012_val_00012019.JPEG +ILSVRC2012_val_00022905.JPEG +ILSVRC2012_val_00035993.JPEG +ILSVRC2012_val_00001808.JPEG +ILSVRC2012_val_00029765.JPEG +ILSVRC2012_val_00034236.JPEG +ILSVRC2012_val_00013272.JPEG +ILSVRC2012_val_00040912.JPEG +ILSVRC2012_val_00016753.JPEG +ILSVRC2012_val_00023155.JPEG +ILSVRC2012_val_00014580.JPEG +ILSVRC2012_val_00019964.JPEG +ILSVRC2012_val_00036636.JPEG +ILSVRC2012_val_00043122.JPEG +ILSVRC2012_val_00048027.JPEG +ILSVRC2012_val_00017355.JPEG +ILSVRC2012_val_00014113.JPEG +ILSVRC2012_val_00017609.JPEG +ILSVRC2012_val_00049040.JPEG +ILSVRC2012_val_00026775.JPEG +ILSVRC2012_val_00043421.JPEG +ILSVRC2012_val_00032876.JPEG +ILSVRC2012_val_00022636.JPEG +ILSVRC2012_val_00020020.JPEG 
+ILSVRC2012_val_00000335.JPEG +ILSVRC2012_val_00018120.JPEG +ILSVRC2012_val_00046791.JPEG +ILSVRC2012_val_00036034.JPEG +ILSVRC2012_val_00011776.JPEG +ILSVRC2012_val_00014485.JPEG +ILSVRC2012_val_00024108.JPEG +ILSVRC2012_val_00001662.JPEG +ILSVRC2012_val_00035450.JPEG +ILSVRC2012_val_00015730.JPEG +ILSVRC2012_val_00031574.JPEG +ILSVRC2012_val_00033714.JPEG +ILSVRC2012_val_00024711.JPEG +ILSVRC2012_val_00046949.JPEG +ILSVRC2012_val_00034448.JPEG +ILSVRC2012_val_00047984.JPEG +ILSVRC2012_val_00029957.JPEG +ILSVRC2012_val_00039937.JPEG +ILSVRC2012_val_00043380.JPEG +ILSVRC2012_val_00030600.JPEG +ILSVRC2012_val_00004588.JPEG +ILSVRC2012_val_00023248.JPEG +ILSVRC2012_val_00012479.JPEG +ILSVRC2012_val_00049933.JPEG +ILSVRC2012_val_00001443.JPEG +ILSVRC2012_val_00000110.JPEG +ILSVRC2012_val_00049937.JPEG +ILSVRC2012_val_00042268.JPEG +ILSVRC2012_val_00049565.JPEG +ILSVRC2012_val_00042029.JPEG +ILSVRC2012_val_00035629.JPEG +ILSVRC2012_val_00022964.JPEG +ILSVRC2012_val_00012830.JPEG +ILSVRC2012_val_00025724.JPEG +ILSVRC2012_val_00032899.JPEG +ILSVRC2012_val_00010183.JPEG +ILSVRC2012_val_00015458.JPEG +ILSVRC2012_val_00000034.JPEG +ILSVRC2012_val_00018697.JPEG +ILSVRC2012_val_00007565.JPEG +ILSVRC2012_val_00020730.JPEG +ILSVRC2012_val_00016443.JPEG +ILSVRC2012_val_00033391.JPEG +ILSVRC2012_val_00015386.JPEG +ILSVRC2012_val_00029529.JPEG +ILSVRC2012_val_00040022.JPEG +ILSVRC2012_val_00009184.JPEG +ILSVRC2012_val_00019396.JPEG +ILSVRC2012_val_00008580.JPEG +ILSVRC2012_val_00049789.JPEG +ILSVRC2012_val_00049460.JPEG +ILSVRC2012_val_00011552.JPEG +ILSVRC2012_val_00011049.JPEG +ILSVRC2012_val_00037289.JPEG +ILSVRC2012_val_00016035.JPEG +ILSVRC2012_val_00015091.JPEG +ILSVRC2012_val_00049413.JPEG +ILSVRC2012_val_00001110.JPEG +ILSVRC2012_val_00029186.JPEG +ILSVRC2012_val_00035556.JPEG +ILSVRC2012_val_00043502.JPEG +ILSVRC2012_val_00011212.JPEG +ILSVRC2012_val_00048885.JPEG +ILSVRC2012_val_00044047.JPEG +ILSVRC2012_val_00014318.JPEG +ILSVRC2012_val_00014572.JPEG 
+ILSVRC2012_val_00047917.JPEG +ILSVRC2012_val_00045678.JPEG +ILSVRC2012_val_00043381.JPEG +ILSVRC2012_val_00005031.JPEG +ILSVRC2012_val_00043843.JPEG +ILSVRC2012_val_00047304.JPEG +ILSVRC2012_val_00003823.JPEG +ILSVRC2012_val_00049192.JPEG +ILSVRC2012_val_00035205.JPEG +ILSVRC2012_val_00035857.JPEG +ILSVRC2012_val_00030657.JPEG +ILSVRC2012_val_00048081.JPEG +ILSVRC2012_val_00004305.JPEG +ILSVRC2012_val_00013827.JPEG +ILSVRC2012_val_00024881.JPEG +ILSVRC2012_val_00027233.JPEG +ILSVRC2012_val_00015903.JPEG +ILSVRC2012_val_00033345.JPEG +ILSVRC2012_val_00014879.JPEG +ILSVRC2012_val_00011437.JPEG +ILSVRC2012_val_00007687.JPEG +ILSVRC2012_val_00031954.JPEG +ILSVRC2012_val_00037649.JPEG +ILSVRC2012_val_00014304.JPEG +ILSVRC2012_val_00035790.JPEG +ILSVRC2012_val_00042613.JPEG +ILSVRC2012_val_00047355.JPEG +ILSVRC2012_val_00039347.JPEG +ILSVRC2012_val_00016003.JPEG +ILSVRC2012_val_00039802.JPEG +ILSVRC2012_val_00015499.JPEG +ILSVRC2012_val_00003242.JPEG +ILSVRC2012_val_00029116.JPEG +ILSVRC2012_val_00020677.JPEG +ILSVRC2012_val_00013024.JPEG +ILSVRC2012_val_00039386.JPEG +ILSVRC2012_val_00025253.JPEG +ILSVRC2012_val_00025753.JPEG +ILSVRC2012_val_00010280.JPEG +ILSVRC2012_val_00014221.JPEG +ILSVRC2012_val_00024867.JPEG +ILSVRC2012_val_00009997.JPEG +ILSVRC2012_val_00011203.JPEG +ILSVRC2012_val_00044864.JPEG +ILSVRC2012_val_00020250.JPEG +ILSVRC2012_val_00018772.JPEG +ILSVRC2012_val_00021073.JPEG +ILSVRC2012_val_00047096.JPEG +ILSVRC2012_val_00040422.JPEG +ILSVRC2012_val_00049727.JPEG +ILSVRC2012_val_00037138.JPEG +ILSVRC2012_val_00025918.JPEG +ILSVRC2012_val_00037445.JPEG +ILSVRC2012_val_00041897.JPEG +ILSVRC2012_val_00008712.JPEG +ILSVRC2012_val_00009020.JPEG +ILSVRC2012_val_00012321.JPEG +ILSVRC2012_val_00023863.JPEG +ILSVRC2012_val_00029978.JPEG +ILSVRC2012_val_00042137.JPEG +ILSVRC2012_val_00009681.JPEG +ILSVRC2012_val_00020386.JPEG +ILSVRC2012_val_00000880.JPEG +ILSVRC2012_val_00047324.JPEG +ILSVRC2012_val_00005240.JPEG +ILSVRC2012_val_00039332.JPEG 
+ILSVRC2012_val_00022237.JPEG +ILSVRC2012_val_00026674.JPEG +ILSVRC2012_val_00039785.JPEG +ILSVRC2012_val_00044521.JPEG +ILSVRC2012_val_00039655.JPEG +ILSVRC2012_val_00014741.JPEG +ILSVRC2012_val_00003441.JPEG +ILSVRC2012_val_00023021.JPEG +ILSVRC2012_val_00032374.JPEG +ILSVRC2012_val_00048606.JPEG +ILSVRC2012_val_00009274.JPEG +ILSVRC2012_val_00030216.JPEG +ILSVRC2012_val_00006135.JPEG +ILSVRC2012_val_00004852.JPEG +ILSVRC2012_val_00024458.JPEG +ILSVRC2012_val_00015780.JPEG +ILSVRC2012_val_00001270.JPEG +ILSVRC2012_val_00028533.JPEG +ILSVRC2012_val_00027821.JPEG +ILSVRC2012_val_00009427.JPEG +ILSVRC2012_val_00049047.JPEG +ILSVRC2012_val_00004800.JPEG +ILSVRC2012_val_00012167.JPEG +ILSVRC2012_val_00030358.JPEG +ILSVRC2012_val_00022473.JPEG +ILSVRC2012_val_00003700.JPEG +ILSVRC2012_val_00005868.JPEG +ILSVRC2012_val_00030033.JPEG +ILSVRC2012_val_00015706.JPEG +ILSVRC2012_val_00024073.JPEG +ILSVRC2012_val_00044884.JPEG +ILSVRC2012_val_00048099.JPEG +ILSVRC2012_val_00026694.JPEG +ILSVRC2012_val_00044648.JPEG +ILSVRC2012_val_00020837.JPEG +ILSVRC2012_val_00013410.JPEG +ILSVRC2012_val_00037172.JPEG +ILSVRC2012_val_00048735.JPEG +ILSVRC2012_val_00037466.JPEG +ILSVRC2012_val_00041907.JPEG +ILSVRC2012_val_00026299.JPEG +ILSVRC2012_val_00029499.JPEG +ILSVRC2012_val_00023702.JPEG +ILSVRC2012_val_00023477.JPEG +ILSVRC2012_val_00014210.JPEG +ILSVRC2012_val_00035416.JPEG +ILSVRC2012_val_00022253.JPEG +ILSVRC2012_val_00009231.JPEG +ILSVRC2012_val_00027673.JPEG +ILSVRC2012_val_00046445.JPEG +ILSVRC2012_val_00000463.JPEG +ILSVRC2012_val_00044907.JPEG +ILSVRC2012_val_00041806.JPEG +ILSVRC2012_val_00035911.JPEG +ILSVRC2012_val_00048053.JPEG +ILSVRC2012_val_00040285.JPEG +ILSVRC2012_val_00021884.JPEG +ILSVRC2012_val_00011632.JPEG +ILSVRC2012_val_00029929.JPEG +ILSVRC2012_val_00003594.JPEG +ILSVRC2012_val_00008442.JPEG +ILSVRC2012_val_00000603.JPEG +ILSVRC2012_val_00038656.JPEG +ILSVRC2012_val_00047360.JPEG +ILSVRC2012_val_00045255.JPEG +ILSVRC2012_val_00014562.JPEG 
+ILSVRC2012_val_00019235.JPEG +ILSVRC2012_val_00026514.JPEG +ILSVRC2012_val_00006431.JPEG +ILSVRC2012_val_00039560.JPEG +ILSVRC2012_val_00031476.JPEG +ILSVRC2012_val_00014263.JPEG +ILSVRC2012_val_00048386.JPEG +ILSVRC2012_val_00015585.JPEG +ILSVRC2012_val_00049025.JPEG +ILSVRC2012_val_00017243.JPEG +ILSVRC2012_val_00028661.JPEG +ILSVRC2012_val_00041593.JPEG +ILSVRC2012_val_00032537.JPEG +ILSVRC2012_val_00005215.JPEG +ILSVRC2012_val_00046532.JPEG +ILSVRC2012_val_00005108.JPEG +ILSVRC2012_val_00041666.JPEG +ILSVRC2012_val_00043962.JPEG +ILSVRC2012_val_00045564.JPEG +ILSVRC2012_val_00015731.JPEG +ILSVRC2012_val_00004915.JPEG +ILSVRC2012_val_00038977.JPEG +ILSVRC2012_val_00013169.JPEG +ILSVRC2012_val_00049621.JPEG +ILSVRC2012_val_00008696.JPEG +ILSVRC2012_val_00020036.JPEG +ILSVRC2012_val_00043610.JPEG +ILSVRC2012_val_00016032.JPEG +ILSVRC2012_val_00030250.JPEG +ILSVRC2012_val_00005182.JPEG +ILSVRC2012_val_00046048.JPEG +ILSVRC2012_val_00039951.JPEG +ILSVRC2012_val_00029911.JPEG +ILSVRC2012_val_00046045.JPEG +ILSVRC2012_val_00012248.JPEG +ILSVRC2012_val_00000891.JPEG +ILSVRC2012_val_00043299.JPEG +ILSVRC2012_val_00034833.JPEG +ILSVRC2012_val_00025658.JPEG +ILSVRC2012_val_00013799.JPEG +ILSVRC2012_val_00033302.JPEG +ILSVRC2012_val_00023000.JPEG +ILSVRC2012_val_00047792.JPEG +ILSVRC2012_val_00002337.JPEG +ILSVRC2012_val_00031595.JPEG +ILSVRC2012_val_00006951.JPEG +ILSVRC2012_val_00021820.JPEG +ILSVRC2012_val_00028514.JPEG +ILSVRC2012_val_00004194.JPEG +ILSVRC2012_val_00001479.JPEG +ILSVRC2012_val_00024952.JPEG +ILSVRC2012_val_00034246.JPEG +ILSVRC2012_val_00016254.JPEG +ILSVRC2012_val_00024387.JPEG +ILSVRC2012_val_00011783.JPEG +ILSVRC2012_val_00010738.JPEG +ILSVRC2012_val_00023267.JPEG +ILSVRC2012_val_00003158.JPEG +ILSVRC2012_val_00048290.JPEG +ILSVRC2012_val_00016447.JPEG +ILSVRC2012_val_00001045.JPEG +ILSVRC2012_val_00032891.JPEG +ILSVRC2012_val_00003038.JPEG +ILSVRC2012_val_00026778.JPEG +ILSVRC2012_val_00033241.JPEG +ILSVRC2012_val_00000046.JPEG 
+ILSVRC2012_val_00020219.JPEG +ILSVRC2012_val_00047696.JPEG +ILSVRC2012_val_00001374.JPEG +ILSVRC2012_val_00006124.JPEG +ILSVRC2012_val_00001167.JPEG +ILSVRC2012_val_00002128.JPEG +ILSVRC2012_val_00007132.JPEG +ILSVRC2012_val_00032760.JPEG +ILSVRC2012_val_00019465.JPEG +ILSVRC2012_val_00010205.JPEG +ILSVRC2012_val_00019246.JPEG +ILSVRC2012_val_00018740.JPEG +ILSVRC2012_val_00043027.JPEG +ILSVRC2012_val_00001453.JPEG +ILSVRC2012_val_00009958.JPEG +ILSVRC2012_val_00004473.JPEG +ILSVRC2012_val_00015569.JPEG +ILSVRC2012_val_00031001.JPEG +ILSVRC2012_val_00002122.JPEG +ILSVRC2012_val_00044792.JPEG +ILSVRC2012_val_00028301.JPEG +ILSVRC2012_val_00008301.JPEG +ILSVRC2012_val_00007855.JPEG +ILSVRC2012_val_00004490.JPEG +ILSVRC2012_val_00004411.JPEG +ILSVRC2012_val_00040039.JPEG +ILSVRC2012_val_00049706.JPEG +ILSVRC2012_val_00003533.JPEG +ILSVRC2012_val_00025795.JPEG +ILSVRC2012_val_00031809.JPEG +ILSVRC2012_val_00006164.JPEG +ILSVRC2012_val_00007424.JPEG +ILSVRC2012_val_00013284.JPEG +ILSVRC2012_val_00008497.JPEG +ILSVRC2012_val_00030123.JPEG +ILSVRC2012_val_00019691.JPEG +ILSVRC2012_val_00028116.JPEG +ILSVRC2012_val_00032683.JPEG +ILSVRC2012_val_00004316.JPEG +ILSVRC2012_val_00003574.JPEG +ILSVRC2012_val_00027155.JPEG +ILSVRC2012_val_00043707.JPEG +ILSVRC2012_val_00000068.JPEG +ILSVRC2012_val_00039327.JPEG +ILSVRC2012_val_00011424.JPEG +ILSVRC2012_val_00048057.JPEG +ILSVRC2012_val_00002978.JPEG +ILSVRC2012_val_00001669.JPEG +ILSVRC2012_val_00048109.JPEG +ILSVRC2012_val_00009540.JPEG +ILSVRC2012_val_00048666.JPEG +ILSVRC2012_val_00015456.JPEG +ILSVRC2012_val_00035305.JPEG +ILSVRC2012_val_00011890.JPEG +ILSVRC2012_val_00002748.JPEG +ILSVRC2012_val_00020898.JPEG +ILSVRC2012_val_00020896.JPEG +ILSVRC2012_val_00012327.JPEG +ILSVRC2012_val_00013363.JPEG +ILSVRC2012_val_00019826.JPEG +ILSVRC2012_val_00000831.JPEG +ILSVRC2012_val_00013281.JPEG +ILSVRC2012_val_00003636.JPEG +ILSVRC2012_val_00030405.JPEG +ILSVRC2012_val_00036384.JPEG +ILSVRC2012_val_00002342.JPEG 
+ILSVRC2012_val_00017539.JPEG +ILSVRC2012_val_00044980.JPEG +ILSVRC2012_val_00024601.JPEG +ILSVRC2012_val_00049249.JPEG +ILSVRC2012_val_00020280.JPEG +ILSVRC2012_val_00019626.JPEG +ILSVRC2012_val_00017597.JPEG +ILSVRC2012_val_00004240.JPEG +ILSVRC2012_val_00019376.JPEG +ILSVRC2012_val_00008694.JPEG +ILSVRC2012_val_00027175.JPEG +ILSVRC2012_val_00048991.JPEG +ILSVRC2012_val_00012558.JPEG +ILSVRC2012_val_00027245.JPEG +ILSVRC2012_val_00021184.JPEG +ILSVRC2012_val_00016824.JPEG +ILSVRC2012_val_00046863.JPEG +ILSVRC2012_val_00020920.JPEG +ILSVRC2012_val_00020880.JPEG +ILSVRC2012_val_00015476.JPEG +ILSVRC2012_val_00021953.JPEG +ILSVRC2012_val_00003480.JPEG +ILSVRC2012_val_00037790.JPEG +ILSVRC2012_val_00027480.JPEG +ILSVRC2012_val_00027368.JPEG +ILSVRC2012_val_00049195.JPEG +ILSVRC2012_val_00025414.JPEG +ILSVRC2012_val_00022392.JPEG +ILSVRC2012_val_00040525.JPEG +ILSVRC2012_val_00041917.JPEG +ILSVRC2012_val_00041134.JPEG +ILSVRC2012_val_00046297.JPEG +ILSVRC2012_val_00028008.JPEG +ILSVRC2012_val_00031084.JPEG +ILSVRC2012_val_00018097.JPEG +ILSVRC2012_val_00021356.JPEG +ILSVRC2012_val_00049174.JPEG +ILSVRC2012_val_00022656.JPEG +ILSVRC2012_val_00022742.JPEG +ILSVRC2012_val_00013947.JPEG +ILSVRC2012_val_00043190.JPEG +ILSVRC2012_val_00000476.JPEG +ILSVRC2012_val_00019329.JPEG +ILSVRC2012_val_00016924.JPEG +ILSVRC2012_val_00033474.JPEG +ILSVRC2012_val_00037051.JPEG +ILSVRC2012_val_00049000.JPEG +ILSVRC2012_val_00000858.JPEG +ILSVRC2012_val_00030824.JPEG +ILSVRC2012_val_00049422.JPEG +ILSVRC2012_val_00026945.JPEG +ILSVRC2012_val_00031627.JPEG +ILSVRC2012_val_00040169.JPEG +ILSVRC2012_val_00039567.JPEG +ILSVRC2012_val_00013496.JPEG +ILSVRC2012_val_00027670.JPEG +ILSVRC2012_val_00033211.JPEG +ILSVRC2012_val_00019725.JPEG +ILSVRC2012_val_00005459.JPEG +ILSVRC2012_val_00026872.JPEG +ILSVRC2012_val_00006385.JPEG +ILSVRC2012_val_00030199.JPEG +ILSVRC2012_val_00019031.JPEG +ILSVRC2012_val_00043078.JPEG +ILSVRC2012_val_00037941.JPEG +ILSVRC2012_val_00027229.JPEG 
+ILSVRC2012_val_00001372.JPEG +ILSVRC2012_val_00013758.JPEG +ILSVRC2012_val_00020873.JPEG +ILSVRC2012_val_00048847.JPEG +ILSVRC2012_val_00037750.JPEG +ILSVRC2012_val_00023629.JPEG +ILSVRC2012_val_00031537.JPEG +ILSVRC2012_val_00037485.JPEG +ILSVRC2012_val_00027601.JPEG +ILSVRC2012_val_00044159.JPEG +ILSVRC2012_val_00029496.JPEG +ILSVRC2012_val_00038135.JPEG +ILSVRC2012_val_00014488.JPEG +ILSVRC2012_val_00023364.JPEG +ILSVRC2012_val_00008582.JPEG +ILSVRC2012_val_00031338.JPEG +ILSVRC2012_val_00024431.JPEG +ILSVRC2012_val_00020827.JPEG +ILSVRC2012_val_00037575.JPEG +ILSVRC2012_val_00034848.JPEG +ILSVRC2012_val_00022847.JPEG +ILSVRC2012_val_00048288.JPEG +ILSVRC2012_val_00019199.JPEG +ILSVRC2012_val_00009918.JPEG +ILSVRC2012_val_00015293.JPEG +ILSVRC2012_val_00021764.JPEG +ILSVRC2012_val_00001932.JPEG +ILSVRC2012_val_00042085.JPEG +ILSVRC2012_val_00005144.JPEG +ILSVRC2012_val_00007186.JPEG +ILSVRC2012_val_00025462.JPEG +ILSVRC2012_val_00049597.JPEG +ILSVRC2012_val_00020803.JPEG +ILSVRC2012_val_00003540.JPEG +ILSVRC2012_val_00029109.JPEG +ILSVRC2012_val_00022049.JPEG +ILSVRC2012_val_00038406.JPEG +ILSVRC2012_val_00029105.JPEG +ILSVRC2012_val_00027294.JPEG +ILSVRC2012_val_00004508.JPEG +ILSVRC2012_val_00015262.JPEG +ILSVRC2012_val_00024010.JPEG +ILSVRC2012_val_00015860.JPEG +ILSVRC2012_val_00023410.JPEG +ILSVRC2012_val_00035749.JPEG +ILSVRC2012_val_00006547.JPEG +ILSVRC2012_val_00027555.JPEG +ILSVRC2012_val_00023558.JPEG +ILSVRC2012_val_00014261.JPEG +ILSVRC2012_val_00007033.JPEG +ILSVRC2012_val_00037196.JPEG +ILSVRC2012_val_00007358.JPEG +ILSVRC2012_val_00004898.JPEG +ILSVRC2012_val_00022131.JPEG +ILSVRC2012_val_00017461.JPEG +ILSVRC2012_val_00027207.JPEG +ILSVRC2012_val_00012592.JPEG +ILSVRC2012_val_00027412.JPEG +ILSVRC2012_val_00047350.JPEG +ILSVRC2012_val_00011540.JPEG +ILSVRC2012_val_00044294.JPEG +ILSVRC2012_val_00038681.JPEG +ILSVRC2012_val_00020984.JPEG +ILSVRC2012_val_00011143.JPEG +ILSVRC2012_val_00037983.JPEG +ILSVRC2012_val_00027672.JPEG 
+ILSVRC2012_val_00010295.JPEG +ILSVRC2012_val_00019585.JPEG +ILSVRC2012_val_00030201.JPEG +ILSVRC2012_val_00032166.JPEG +ILSVRC2012_val_00037838.JPEG +ILSVRC2012_val_00033023.JPEG +ILSVRC2012_val_00023609.JPEG +ILSVRC2012_val_00041888.JPEG +ILSVRC2012_val_00024407.JPEG +ILSVRC2012_val_00033868.JPEG +ILSVRC2012_val_00008077.JPEG +ILSVRC2012_val_00006123.JPEG +ILSVRC2012_val_00021422.JPEG +ILSVRC2012_val_00009333.JPEG +ILSVRC2012_val_00006441.JPEG +ILSVRC2012_val_00006915.JPEG +ILSVRC2012_val_00036236.JPEG +ILSVRC2012_val_00010766.JPEG +ILSVRC2012_val_00009605.JPEG +ILSVRC2012_val_00025017.JPEG +ILSVRC2012_val_00039541.JPEG +ILSVRC2012_val_00016971.JPEG +ILSVRC2012_val_00011704.JPEG +ILSVRC2012_val_00004905.JPEG +ILSVRC2012_val_00009423.JPEG +ILSVRC2012_val_00002514.JPEG +ILSVRC2012_val_00016353.JPEG +ILSVRC2012_val_00049638.JPEG +ILSVRC2012_val_00031006.JPEG +ILSVRC2012_val_00032495.JPEG +ILSVRC2012_val_00044850.JPEG +ILSVRC2012_val_00024969.JPEG +ILSVRC2012_val_00022435.JPEG +ILSVRC2012_val_00012531.JPEG +ILSVRC2012_val_00030042.JPEG +ILSVRC2012_val_00004051.JPEG +ILSVRC2012_val_00009862.JPEG +ILSVRC2012_val_00019659.JPEG +ILSVRC2012_val_00039381.JPEG +ILSVRC2012_val_00019863.JPEG +ILSVRC2012_val_00040835.JPEG +ILSVRC2012_val_00014557.JPEG +ILSVRC2012_val_00015502.JPEG +ILSVRC2012_val_00005669.JPEG +ILSVRC2012_val_00030985.JPEG +ILSVRC2012_val_00032979.JPEG +ILSVRC2012_val_00027059.JPEG +ILSVRC2012_val_00025393.JPEG +ILSVRC2012_val_00033253.JPEG +ILSVRC2012_val_00032781.JPEG +ILSVRC2012_val_00021311.JPEG +ILSVRC2012_val_00027044.JPEG +ILSVRC2012_val_00011606.JPEG +ILSVRC2012_val_00037412.JPEG +ILSVRC2012_val_00033129.JPEG +ILSVRC2012_val_00007665.JPEG +ILSVRC2012_val_00021959.JPEG +ILSVRC2012_val_00046146.JPEG +ILSVRC2012_val_00000544.JPEG +ILSVRC2012_val_00004846.JPEG +ILSVRC2012_val_00040116.JPEG +ILSVRC2012_val_00025351.JPEG +ILSVRC2012_val_00006217.JPEG +ILSVRC2012_val_00021113.JPEG +ILSVRC2012_val_00017837.JPEG +ILSVRC2012_val_00018889.JPEG 
+ILSVRC2012_val_00041968.JPEG +ILSVRC2012_val_00027792.JPEG +ILSVRC2012_val_00016871.JPEG +ILSVRC2012_val_00009309.JPEG +ILSVRC2012_val_00000760.JPEG +ILSVRC2012_val_00024761.JPEG +ILSVRC2012_val_00001825.JPEG +ILSVRC2012_val_00031410.JPEG +ILSVRC2012_val_00017157.JPEG +ILSVRC2012_val_00022739.JPEG +ILSVRC2012_val_00035986.JPEG +ILSVRC2012_val_00033572.JPEG +ILSVRC2012_val_00035129.JPEG +ILSVRC2012_val_00035392.JPEG +ILSVRC2012_val_00033647.JPEG +ILSVRC2012_val_00007375.JPEG +ILSVRC2012_val_00010651.JPEG +ILSVRC2012_val_00023065.JPEG +ILSVRC2012_val_00041864.JPEG +ILSVRC2012_val_00026613.JPEG +ILSVRC2012_val_00025104.JPEG +ILSVRC2012_val_00001190.JPEG +ILSVRC2012_val_00019568.JPEG +ILSVRC2012_val_00023751.JPEG +ILSVRC2012_val_00012502.JPEG +ILSVRC2012_val_00010917.JPEG +ILSVRC2012_val_00016395.JPEG +ILSVRC2012_val_00040208.JPEG +ILSVRC2012_val_00008481.JPEG +ILSVRC2012_val_00033625.JPEG +ILSVRC2012_val_00014993.JPEG +ILSVRC2012_val_00023145.JPEG +ILSVRC2012_val_00010136.JPEG +ILSVRC2012_val_00048387.JPEG +ILSVRC2012_val_00023683.JPEG +ILSVRC2012_val_00013677.JPEG +ILSVRC2012_val_00008788.JPEG +ILSVRC2012_val_00013581.JPEG +ILSVRC2012_val_00004118.JPEG +ILSVRC2012_val_00048236.JPEG +ILSVRC2012_val_00016914.JPEG +ILSVRC2012_val_00035300.JPEG +ILSVRC2012_val_00041040.JPEG +ILSVRC2012_val_00006987.JPEG +ILSVRC2012_val_00047603.JPEG +ILSVRC2012_val_00017860.JPEG +ILSVRC2012_val_00033990.JPEG +ILSVRC2012_val_00044708.JPEG +ILSVRC2012_val_00024770.JPEG +ILSVRC2012_val_00001391.JPEG +ILSVRC2012_val_00011119.JPEG +ILSVRC2012_val_00000657.JPEG +ILSVRC2012_val_00003660.JPEG +ILSVRC2012_val_00036353.JPEG +ILSVRC2012_val_00019544.JPEG +ILSVRC2012_val_00049562.JPEG +ILSVRC2012_val_00022560.JPEG +ILSVRC2012_val_00025905.JPEG +ILSVRC2012_val_00009180.JPEG +ILSVRC2012_val_00044414.JPEG +ILSVRC2012_val_00049084.JPEG +ILSVRC2012_val_00027078.JPEG +ILSVRC2012_val_00049375.JPEG +ILSVRC2012_val_00019184.JPEG +ILSVRC2012_val_00028972.JPEG +ILSVRC2012_val_00008515.JPEG 
+ILSVRC2012_val_00016329.JPEG +ILSVRC2012_val_00005092.JPEG +ILSVRC2012_val_00043021.JPEG +ILSVRC2012_val_00039953.JPEG +ILSVRC2012_val_00038742.JPEG +ILSVRC2012_val_00038640.JPEG +ILSVRC2012_val_00038196.JPEG +ILSVRC2012_val_00047247.JPEG +ILSVRC2012_val_00010830.JPEG +ILSVRC2012_val_00046855.JPEG +ILSVRC2012_val_00028233.JPEG +ILSVRC2012_val_00001358.JPEG +ILSVRC2012_val_00012754.JPEG +ILSVRC2012_val_00005453.JPEG +ILSVRC2012_val_00002951.JPEG +ILSVRC2012_val_00008466.JPEG +ILSVRC2012_val_00012437.JPEG +ILSVRC2012_val_00037242.JPEG +ILSVRC2012_val_00035684.JPEG +ILSVRC2012_val_00049739.JPEG +ILSVRC2012_val_00013811.JPEG +ILSVRC2012_val_00015414.JPEG +ILSVRC2012_val_00019614.JPEG +ILSVRC2012_val_00044192.JPEG +ILSVRC2012_val_00006600.JPEG +ILSVRC2012_val_00007809.JPEG +ILSVRC2012_val_00006923.JPEG +ILSVRC2012_val_00021265.JPEG +ILSVRC2012_val_00020835.JPEG +ILSVRC2012_val_00008453.JPEG +ILSVRC2012_val_00004341.JPEG +ILSVRC2012_val_00003252.JPEG +ILSVRC2012_val_00029289.JPEG +ILSVRC2012_val_00038792.JPEG +ILSVRC2012_val_00009847.JPEG +ILSVRC2012_val_00042081.JPEG +ILSVRC2012_val_00048779.JPEG +ILSVRC2012_val_00030397.JPEG +ILSVRC2012_val_00044164.JPEG +ILSVRC2012_val_00026657.JPEG +ILSVRC2012_val_00024522.JPEG +ILSVRC2012_val_00007374.JPEG +ILSVRC2012_val_00012784.JPEG +ILSVRC2012_val_00004007.JPEG +ILSVRC2012_val_00025950.JPEG +ILSVRC2012_val_00038231.JPEG +ILSVRC2012_val_00033997.JPEG +ILSVRC2012_val_00004693.JPEG +ILSVRC2012_val_00040140.JPEG +ILSVRC2012_val_00012355.JPEG +ILSVRC2012_val_00039524.JPEG +ILSVRC2012_val_00009627.JPEG +ILSVRC2012_val_00049546.JPEG +ILSVRC2012_val_00011791.JPEG +ILSVRC2012_val_00032643.JPEG +ILSVRC2012_val_00033111.JPEG +ILSVRC2012_val_00035975.JPEG +ILSVRC2012_val_00040178.JPEG +ILSVRC2012_val_00025820.JPEG +ILSVRC2012_val_00028959.JPEG +ILSVRC2012_val_00008470.JPEG +ILSVRC2012_val_00012351.JPEG +ILSVRC2012_val_00018332.JPEG +ILSVRC2012_val_00003639.JPEG +ILSVRC2012_val_00047310.JPEG +ILSVRC2012_val_00046409.JPEG 
+ILSVRC2012_val_00003090.JPEG +ILSVRC2012_val_00006742.JPEG +ILSVRC2012_val_00000379.JPEG +ILSVRC2012_val_00041333.JPEG +ILSVRC2012_val_00042387.JPEG +ILSVRC2012_val_00011042.JPEG +ILSVRC2012_val_00020778.JPEG +ILSVRC2012_val_00043248.JPEG +ILSVRC2012_val_00027585.JPEG +ILSVRC2012_val_00030925.JPEG +ILSVRC2012_val_00024703.JPEG +ILSVRC2012_val_00027295.JPEG +ILSVRC2012_val_00035019.JPEG +ILSVRC2012_val_00014590.JPEG +ILSVRC2012_val_00040475.JPEG +ILSVRC2012_val_00029417.JPEG +ILSVRC2012_val_00011610.JPEG +ILSVRC2012_val_00020955.JPEG +ILSVRC2012_val_00003567.JPEG +ILSVRC2012_val_00041793.JPEG +ILSVRC2012_val_00016089.JPEG +ILSVRC2012_val_00044772.JPEG +ILSVRC2012_val_00041473.JPEG +ILSVRC2012_val_00005309.JPEG +ILSVRC2012_val_00040785.JPEG +ILSVRC2012_val_00024764.JPEG +ILSVRC2012_val_00023879.JPEG +ILSVRC2012_val_00038680.JPEG +ILSVRC2012_val_00037820.JPEG +ILSVRC2012_val_00007513.JPEG +ILSVRC2012_val_00015525.JPEG +ILSVRC2012_val_00011181.JPEG +ILSVRC2012_val_00017180.JPEG +ILSVRC2012_val_00030045.JPEG +ILSVRC2012_val_00028672.JPEG +ILSVRC2012_val_00022604.JPEG +ILSVRC2012_val_00030840.JPEG +ILSVRC2012_val_00025700.JPEG +ILSVRC2012_val_00009433.JPEG +ILSVRC2012_val_00042953.JPEG +ILSVRC2012_val_00012491.JPEG +ILSVRC2012_val_00039638.JPEG +ILSVRC2012_val_00039893.JPEG +ILSVRC2012_val_00017903.JPEG +ILSVRC2012_val_00018539.JPEG +ILSVRC2012_val_00026638.JPEG +ILSVRC2012_val_00046456.JPEG +ILSVRC2012_val_00001033.JPEG +ILSVRC2012_val_00035545.JPEG +ILSVRC2012_val_00030731.JPEG +ILSVRC2012_val_00022647.JPEG +ILSVRC2012_val_00035785.JPEG +ILSVRC2012_val_00004794.JPEG +ILSVRC2012_val_00013608.JPEG +ILSVRC2012_val_00049724.JPEG +ILSVRC2012_val_00031262.JPEG +ILSVRC2012_val_00012488.JPEG +ILSVRC2012_val_00023969.JPEG +ILSVRC2012_val_00021538.JPEG +ILSVRC2012_val_00034563.JPEG +ILSVRC2012_val_00045101.JPEG +ILSVRC2012_val_00032731.JPEG +ILSVRC2012_val_00012648.JPEG +ILSVRC2012_val_00034120.JPEG +ILSVRC2012_val_00044447.JPEG +ILSVRC2012_val_00025913.JPEG 
+ILSVRC2012_val_00018927.JPEG +ILSVRC2012_val_00011128.JPEG +ILSVRC2012_val_00007963.JPEG +ILSVRC2012_val_00026408.JPEG +ILSVRC2012_val_00040742.JPEG +ILSVRC2012_val_00049705.JPEG +ILSVRC2012_val_00047396.JPEG +ILSVRC2012_val_00024005.JPEG +ILSVRC2012_val_00044795.JPEG +ILSVRC2012_val_00029603.JPEG +ILSVRC2012_val_00034734.JPEG +ILSVRC2012_val_00043458.JPEG +ILSVRC2012_val_00045295.JPEG +ILSVRC2012_val_00016834.JPEG +ILSVRC2012_val_00048866.JPEG +ILSVRC2012_val_00041484.JPEG +ILSVRC2012_val_00014190.JPEG +ILSVRC2012_val_00006984.JPEG +ILSVRC2012_val_00049770.JPEG +ILSVRC2012_val_00008920.JPEG +ILSVRC2012_val_00025115.JPEG +ILSVRC2012_val_00046169.JPEG +ILSVRC2012_val_00034856.JPEG +ILSVRC2012_val_00014984.JPEG +ILSVRC2012_val_00000160.JPEG +ILSVRC2012_val_00007010.JPEG +ILSVRC2012_val_00048995.JPEG +ILSVRC2012_val_00047020.JPEG +ILSVRC2012_val_00023124.JPEG +ILSVRC2012_val_00012600.JPEG +ILSVRC2012_val_00048526.JPEG +ILSVRC2012_val_00019264.JPEG +ILSVRC2012_val_00042419.JPEG +ILSVRC2012_val_00038849.JPEG +ILSVRC2012_val_00000912.JPEG +ILSVRC2012_val_00016837.JPEG +ILSVRC2012_val_00021367.JPEG +ILSVRC2012_val_00015327.JPEG +ILSVRC2012_val_00042886.JPEG +ILSVRC2012_val_00028530.JPEG +ILSVRC2012_val_00007758.JPEG +ILSVRC2012_val_00020557.JPEG +ILSVRC2012_val_00024393.JPEG +ILSVRC2012_val_00017058.JPEG +ILSVRC2012_val_00041259.JPEG +ILSVRC2012_val_00040007.JPEG +ILSVRC2012_val_00002236.JPEG +ILSVRC2012_val_00023033.JPEG +ILSVRC2012_val_00005068.JPEG +ILSVRC2012_val_00047074.JPEG +ILSVRC2012_val_00036330.JPEG +ILSVRC2012_val_00016801.JPEG +ILSVRC2012_val_00014770.JPEG +ILSVRC2012_val_00039189.JPEG +ILSVRC2012_val_00033788.JPEG +ILSVRC2012_val_00011621.JPEG +ILSVRC2012_val_00043847.JPEG +ILSVRC2012_val_00030609.JPEG +ILSVRC2012_val_00037916.JPEG +ILSVRC2012_val_00000200.JPEG +ILSVRC2012_val_00027119.JPEG +ILSVRC2012_val_00017924.JPEG +ILSVRC2012_val_00026920.JPEG +ILSVRC2012_val_00006128.JPEG +ILSVRC2012_val_00018993.JPEG +ILSVRC2012_val_00027998.JPEG 
+ILSVRC2012_val_00009235.JPEG +ILSVRC2012_val_00022470.JPEG +ILSVRC2012_val_00031456.JPEG +ILSVRC2012_val_00007651.JPEG +ILSVRC2012_val_00049215.JPEG +ILSVRC2012_val_00024268.JPEG +ILSVRC2012_val_00020388.JPEG +ILSVRC2012_val_00027322.JPEG +ILSVRC2012_val_00018653.JPEG +ILSVRC2012_val_00017224.JPEG +ILSVRC2012_val_00031812.JPEG +ILSVRC2012_val_00009253.JPEG +ILSVRC2012_val_00038614.JPEG +ILSVRC2012_val_00025177.JPEG +ILSVRC2012_val_00013805.JPEG +ILSVRC2012_val_00019282.JPEG +ILSVRC2012_val_00029954.JPEG +ILSVRC2012_val_00003651.JPEG +ILSVRC2012_val_00039336.JPEG +ILSVRC2012_val_00045992.JPEG +ILSVRC2012_val_00043590.JPEG +ILSVRC2012_val_00042129.JPEG +ILSVRC2012_val_00007510.JPEG +ILSVRC2012_val_00030648.JPEG +ILSVRC2012_val_00020456.JPEG +ILSVRC2012_val_00045164.JPEG +ILSVRC2012_val_00010959.JPEG +ILSVRC2012_val_00021593.JPEG +ILSVRC2012_val_00015892.JPEG +ILSVRC2012_val_00047703.JPEG +ILSVRC2012_val_00033439.JPEG +ILSVRC2012_val_00036671.JPEG +ILSVRC2012_val_00035619.JPEG +ILSVRC2012_val_00020626.JPEG +ILSVRC2012_val_00023871.JPEG +ILSVRC2012_val_00017213.JPEG +ILSVRC2012_val_00041516.JPEG +ILSVRC2012_val_00018612.JPEG +ILSVRC2012_val_00024587.JPEG +ILSVRC2012_val_00020043.JPEG +ILSVRC2012_val_00024567.JPEG +ILSVRC2012_val_00016465.JPEG +ILSVRC2012_val_00020606.JPEG +ILSVRC2012_val_00049841.JPEG +ILSVRC2012_val_00040394.JPEG +ILSVRC2012_val_00046121.JPEG +ILSVRC2012_val_00023685.JPEG +ILSVRC2012_val_00002727.JPEG +ILSVRC2012_val_00032208.JPEG +ILSVRC2012_val_00008069.JPEG +ILSVRC2012_val_00025738.JPEG +ILSVRC2012_val_00019809.JPEG +ILSVRC2012_val_00032393.JPEG +ILSVRC2012_val_00039466.JPEG +ILSVRC2012_val_00011883.JPEG +ILSVRC2012_val_00007157.JPEG +ILSVRC2012_val_00003200.JPEG +ILSVRC2012_val_00043471.JPEG +ILSVRC2012_val_00001770.JPEG +ILSVRC2012_val_00047525.JPEG +ILSVRC2012_val_00007162.JPEG +ILSVRC2012_val_00040076.JPEG +ILSVRC2012_val_00010565.JPEG +ILSVRC2012_val_00001132.JPEG +ILSVRC2012_val_00026242.JPEG +ILSVRC2012_val_00025191.JPEG 
+ILSVRC2012_val_00000718.JPEG +ILSVRC2012_val_00012762.JPEG +ILSVRC2012_val_00023529.JPEG +ILSVRC2012_val_00016960.JPEG +ILSVRC2012_val_00036817.JPEG +ILSVRC2012_val_00043528.JPEG +ILSVRC2012_val_00034911.JPEG +ILSVRC2012_val_00012783.JPEG +ILSVRC2012_val_00044915.JPEG +ILSVRC2012_val_00006744.JPEG +ILSVRC2012_val_00048440.JPEG +ILSVRC2012_val_00008135.JPEG +ILSVRC2012_val_00005905.JPEG +ILSVRC2012_val_00025916.JPEG +ILSVRC2012_val_00035927.JPEG +ILSVRC2012_val_00049010.JPEG +ILSVRC2012_val_00011652.JPEG +ILSVRC2012_val_00010121.JPEG +ILSVRC2012_val_00024606.JPEG +ILSVRC2012_val_00010345.JPEG +ILSVRC2012_val_00049138.JPEG +ILSVRC2012_val_00002873.JPEG +ILSVRC2012_val_00023016.JPEG +ILSVRC2012_val_00020173.JPEG +ILSVRC2012_val_00043130.JPEG +ILSVRC2012_val_00023503.JPEG +ILSVRC2012_val_00000014.JPEG +ILSVRC2012_val_00006696.JPEG +ILSVRC2012_val_00015286.JPEG +ILSVRC2012_val_00023275.JPEG +ILSVRC2012_val_00017709.JPEG +ILSVRC2012_val_00007655.JPEG +ILSVRC2012_val_00030317.JPEG +ILSVRC2012_val_00008948.JPEG +ILSVRC2012_val_00008886.JPEG +ILSVRC2012_val_00011664.JPEG +ILSVRC2012_val_00037828.JPEG +ILSVRC2012_val_00039835.JPEG +ILSVRC2012_val_00034094.JPEG +ILSVRC2012_val_00011436.JPEG +ILSVRC2012_val_00047486.JPEG +ILSVRC2012_val_00022825.JPEG +ILSVRC2012_val_00038413.JPEG +ILSVRC2012_val_00017027.JPEG +ILSVRC2012_val_00047128.JPEG +ILSVRC2012_val_00034625.JPEG +ILSVRC2012_val_00036417.JPEG +ILSVRC2012_val_00047505.JPEG +ILSVRC2012_val_00042264.JPEG +ILSVRC2012_val_00036710.JPEG +ILSVRC2012_val_00039203.JPEG +ILSVRC2012_val_00030386.JPEG +ILSVRC2012_val_00000591.JPEG +ILSVRC2012_val_00002275.JPEG +ILSVRC2012_val_00042574.JPEG +ILSVRC2012_val_00039503.JPEG +ILSVRC2012_val_00037725.JPEG +ILSVRC2012_val_00037749.JPEG +ILSVRC2012_val_00021411.JPEG +ILSVRC2012_val_00016421.JPEG +ILSVRC2012_val_00028254.JPEG +ILSVRC2012_val_00013898.JPEG +ILSVRC2012_val_00006844.JPEG +ILSVRC2012_val_00046517.JPEG +ILSVRC2012_val_00010396.JPEG +ILSVRC2012_val_00015518.JPEG 
+ILSVRC2012_val_00029209.JPEG +ILSVRC2012_val_00001852.JPEG +ILSVRC2012_val_00009975.JPEG +ILSVRC2012_val_00001263.JPEG +ILSVRC2012_val_00032101.JPEG +ILSVRC2012_val_00037708.JPEG +ILSVRC2012_val_00037695.JPEG +ILSVRC2012_val_00021911.JPEG +ILSVRC2012_val_00047659.JPEG +ILSVRC2012_val_00006518.JPEG +ILSVRC2012_val_00036617.JPEG +ILSVRC2012_val_00023452.JPEG +ILSVRC2012_val_00002985.JPEG +ILSVRC2012_val_00010644.JPEG +ILSVRC2012_val_00046139.JPEG +ILSVRC2012_val_00040246.JPEG +ILSVRC2012_val_00009395.JPEG +ILSVRC2012_val_00036168.JPEG +ILSVRC2012_val_00028730.JPEG +ILSVRC2012_val_00006525.JPEG +ILSVRC2012_val_00032985.JPEG +ILSVRC2012_val_00024908.JPEG +ILSVRC2012_val_00011745.JPEG +ILSVRC2012_val_00008640.JPEG +ILSVRC2012_val_00005011.JPEG +ILSVRC2012_val_00031723.JPEG +ILSVRC2012_val_00006041.JPEG +ILSVRC2012_val_00007315.JPEG +ILSVRC2012_val_00032030.JPEG +ILSVRC2012_val_00049704.JPEG +ILSVRC2012_val_00002392.JPEG +ILSVRC2012_val_00035173.JPEG +ILSVRC2012_val_00036986.JPEG +ILSVRC2012_val_00010591.JPEG +ILSVRC2012_val_00007005.JPEG +ILSVRC2012_val_00002796.JPEG +ILSVRC2012_val_00004133.JPEG +ILSVRC2012_val_00027888.JPEG +ILSVRC2012_val_00019845.JPEG +ILSVRC2012_val_00037014.JPEG +ILSVRC2012_val_00022161.JPEG +ILSVRC2012_val_00023092.JPEG +ILSVRC2012_val_00014990.JPEG +ILSVRC2012_val_00044650.JPEG +ILSVRC2012_val_00030321.JPEG +ILSVRC2012_val_00021371.JPEG +ILSVRC2012_val_00049003.JPEG +ILSVRC2012_val_00023783.JPEG +ILSVRC2012_val_00038968.JPEG +ILSVRC2012_val_00046700.JPEG +ILSVRC2012_val_00013819.JPEG +ILSVRC2012_val_00000876.JPEG +ILSVRC2012_val_00048756.JPEG +ILSVRC2012_val_00036119.JPEG +ILSVRC2012_val_00025627.JPEG +ILSVRC2012_val_00044511.JPEG +ILSVRC2012_val_00045106.JPEG +ILSVRC2012_val_00014766.JPEG +ILSVRC2012_val_00042745.JPEG +ILSVRC2012_val_00015278.JPEG +ILSVRC2012_val_00009760.JPEG +ILSVRC2012_val_00021871.JPEG +ILSVRC2012_val_00003273.JPEG +ILSVRC2012_val_00030424.JPEG +ILSVRC2012_val_00044453.JPEG +ILSVRC2012_val_00007746.JPEG 
+ILSVRC2012_val_00021539.JPEG +ILSVRC2012_val_00001693.JPEG +ILSVRC2012_val_00034332.JPEG +ILSVRC2012_val_00037973.JPEG +ILSVRC2012_val_00011034.JPEG +ILSVRC2012_val_00030434.JPEG +ILSVRC2012_val_00032585.JPEG +ILSVRC2012_val_00049644.JPEG +ILSVRC2012_val_00038043.JPEG +ILSVRC2012_val_00018525.JPEG +ILSVRC2012_val_00013699.JPEG +ILSVRC2012_val_00044388.JPEG +ILSVRC2012_val_00047727.JPEG +ILSVRC2012_val_00045396.JPEG +ILSVRC2012_val_00031725.JPEG +ILSVRC2012_val_00019019.JPEG +ILSVRC2012_val_00045825.JPEG +ILSVRC2012_val_00029786.JPEG +ILSVRC2012_val_00018128.JPEG +ILSVRC2012_val_00014218.JPEG +ILSVRC2012_val_00017202.JPEG +ILSVRC2012_val_00025765.JPEG +ILSVRC2012_val_00044167.JPEG +ILSVRC2012_val_00005322.JPEG +ILSVRC2012_val_00026673.JPEG +ILSVRC2012_val_00034666.JPEG +ILSVRC2012_val_00007928.JPEG +ILSVRC2012_val_00005582.JPEG +ILSVRC2012_val_00018837.JPEG +ILSVRC2012_val_00020260.JPEG +ILSVRC2012_val_00017259.JPEG +ILSVRC2012_val_00013300.JPEG +ILSVRC2012_val_00048894.JPEG +ILSVRC2012_val_00015560.JPEG +ILSVRC2012_val_00040716.JPEG +ILSVRC2012_val_00044493.JPEG +ILSVRC2012_val_00007487.JPEG +ILSVRC2012_val_00035705.JPEG +ILSVRC2012_val_00044935.JPEG +ILSVRC2012_val_00031634.JPEG +ILSVRC2012_val_00047489.JPEG +ILSVRC2012_val_00040200.JPEG +ILSVRC2012_val_00000483.JPEG +ILSVRC2012_val_00017477.JPEG +ILSVRC2012_val_00029642.JPEG +ILSVRC2012_val_00026046.JPEG +ILSVRC2012_val_00005654.JPEG +ILSVRC2012_val_00035091.JPEG +ILSVRC2012_val_00022966.JPEG +ILSVRC2012_val_00018549.JPEG +ILSVRC2012_val_00010059.JPEG +ILSVRC2012_val_00041385.JPEG +ILSVRC2012_val_00021942.JPEG +ILSVRC2012_val_00013582.JPEG +ILSVRC2012_val_00024795.JPEG +ILSVRC2012_val_00044685.JPEG +ILSVRC2012_val_00040056.JPEG +ILSVRC2012_val_00010513.JPEG +ILSVRC2012_val_00044296.JPEG +ILSVRC2012_val_00043634.JPEG +ILSVRC2012_val_00033934.JPEG +ILSVRC2012_val_00031648.JPEG +ILSVRC2012_val_00046067.JPEG +ILSVRC2012_val_00026197.JPEG +ILSVRC2012_val_00014372.JPEG +ILSVRC2012_val_00000984.JPEG 
+ILSVRC2012_val_00035826.JPEG +ILSVRC2012_val_00044385.JPEG +ILSVRC2012_val_00039760.JPEG +ILSVRC2012_val_00009603.JPEG +ILSVRC2012_val_00024001.JPEG +ILSVRC2012_val_00000042.JPEG +ILSVRC2012_val_00032729.JPEG +ILSVRC2012_val_00000931.JPEG +ILSVRC2012_val_00042902.JPEG +ILSVRC2012_val_00018763.JPEG +ILSVRC2012_val_00016384.JPEG +ILSVRC2012_val_00041128.JPEG +ILSVRC2012_val_00048671.JPEG +ILSVRC2012_val_00024369.JPEG +ILSVRC2012_val_00024350.JPEG +ILSVRC2012_val_00003938.JPEG +ILSVRC2012_val_00045903.JPEG +ILSVRC2012_val_00008657.JPEG +ILSVRC2012_val_00044281.JPEG +ILSVRC2012_val_00017897.JPEG +ILSVRC2012_val_00009848.JPEG +ILSVRC2012_val_00016422.JPEG +ILSVRC2012_val_00045727.JPEG +ILSVRC2012_val_00046587.JPEG +ILSVRC2012_val_00018257.JPEG +ILSVRC2012_val_00037766.JPEG +ILSVRC2012_val_00041067.JPEG +ILSVRC2012_val_00000078.JPEG +ILSVRC2012_val_00040557.JPEG +ILSVRC2012_val_00011863.JPEG +ILSVRC2012_val_00011541.JPEG +ILSVRC2012_val_00036109.JPEG +ILSVRC2012_val_00013250.JPEG +ILSVRC2012_val_00021852.JPEG +ILSVRC2012_val_00000401.JPEG +ILSVRC2012_val_00003319.JPEG +ILSVRC2012_val_00040612.JPEG +ILSVRC2012_val_00026783.JPEG +ILSVRC2012_val_00020653.JPEG +ILSVRC2012_val_00027324.JPEG +ILSVRC2012_val_00002692.JPEG +ILSVRC2012_val_00024716.JPEG +ILSVRC2012_val_00047650.JPEG +ILSVRC2012_val_00013731.JPEG +ILSVRC2012_val_00033638.JPEG +ILSVRC2012_val_00041408.JPEG +ILSVRC2012_val_00020671.JPEG +ILSVRC2012_val_00047693.JPEG +ILSVRC2012_val_00001018.JPEG +ILSVRC2012_val_00005560.JPEG +ILSVRC2012_val_00034151.JPEG +ILSVRC2012_val_00011407.JPEG +ILSVRC2012_val_00036935.JPEG +ILSVRC2012_val_00017379.JPEG +ILSVRC2012_val_00023053.JPEG +ILSVRC2012_val_00033496.JPEG +ILSVRC2012_val_00024370.JPEG +ILSVRC2012_val_00046935.JPEG +ILSVRC2012_val_00012047.JPEG +ILSVRC2012_val_00019290.JPEG +ILSVRC2012_val_00016836.JPEG +ILSVRC2012_val_00007373.JPEG +ILSVRC2012_val_00048216.JPEG +ILSVRC2012_val_00035872.JPEG +ILSVRC2012_val_00005079.JPEG +ILSVRC2012_val_00034765.JPEG 
+ILSVRC2012_val_00019198.JPEG +ILSVRC2012_val_00017320.JPEG +ILSVRC2012_val_00015370.JPEG +ILSVRC2012_val_00030958.JPEG +ILSVRC2012_val_00016932.JPEG +ILSVRC2012_val_00047111.JPEG +ILSVRC2012_val_00001692.JPEG +ILSVRC2012_val_00023973.JPEG +ILSVRC2012_val_00042837.JPEG +ILSVRC2012_val_00012045.JPEG +ILSVRC2012_val_00028399.JPEG +ILSVRC2012_val_00002582.JPEG +ILSVRC2012_val_00014302.JPEG +ILSVRC2012_val_00021562.JPEG +ILSVRC2012_val_00043360.JPEG +ILSVRC2012_val_00021891.JPEG +ILSVRC2012_val_00023858.JPEG +ILSVRC2012_val_00035066.JPEG +ILSVRC2012_val_00044367.JPEG +ILSVRC2012_val_00040868.JPEG +ILSVRC2012_val_00046854.JPEG +ILSVRC2012_val_00023803.JPEG +ILSVRC2012_val_00049755.JPEG +ILSVRC2012_val_00048462.JPEG +ILSVRC2012_val_00035645.JPEG +ILSVRC2012_val_00004298.JPEG +ILSVRC2012_val_00019052.JPEG +ILSVRC2012_val_00025974.JPEG +ILSVRC2012_val_00018785.JPEG +ILSVRC2012_val_00010249.JPEG +ILSVRC2012_val_00048073.JPEG +ILSVRC2012_val_00019937.JPEG +ILSVRC2012_val_00025837.JPEG +ILSVRC2012_val_00021805.JPEG +ILSVRC2012_val_00024095.JPEG +ILSVRC2012_val_00006125.JPEG +ILSVRC2012_val_00045635.JPEG +ILSVRC2012_val_00018224.JPEG +ILSVRC2012_val_00012662.JPEG +ILSVRC2012_val_00033730.JPEG +ILSVRC2012_val_00027798.JPEG +ILSVRC2012_val_00049845.JPEG +ILSVRC2012_val_00024728.JPEG +ILSVRC2012_val_00049072.JPEG +ILSVRC2012_val_00041076.JPEG +ILSVRC2012_val_00039952.JPEG +ILSVRC2012_val_00017612.JPEG +ILSVRC2012_val_00008459.JPEG +ILSVRC2012_val_00045637.JPEG +ILSVRC2012_val_00045166.JPEG +ILSVRC2012_val_00010016.JPEG +ILSVRC2012_val_00021005.JPEG +ILSVRC2012_val_00005639.JPEG +ILSVRC2012_val_00035976.JPEG +ILSVRC2012_val_00016236.JPEG +ILSVRC2012_val_00021360.JPEG +ILSVRC2012_val_00049646.JPEG +ILSVRC2012_val_00046485.JPEG +ILSVRC2012_val_00004644.JPEG +ILSVRC2012_val_00022674.JPEG +ILSVRC2012_val_00011846.JPEG +ILSVRC2012_val_00014614.JPEG +ILSVRC2012_val_00016866.JPEG +ILSVRC2012_val_00002350.JPEG +ILSVRC2012_val_00018243.JPEG +ILSVRC2012_val_00002693.JPEG 
+ILSVRC2012_val_00010446.JPEG +ILSVRC2012_val_00016388.JPEG +ILSVRC2012_val_00027369.JPEG +ILSVRC2012_val_00017340.JPEG +ILSVRC2012_val_00045786.JPEG +ILSVRC2012_val_00019077.JPEG +ILSVRC2012_val_00044123.JPEG +ILSVRC2012_val_00018880.JPEG +ILSVRC2012_val_00011442.JPEG +ILSVRC2012_val_00037100.JPEG +ILSVRC2012_val_00016942.JPEG +ILSVRC2012_val_00042839.JPEG +ILSVRC2012_val_00031729.JPEG +ILSVRC2012_val_00010934.JPEG +ILSVRC2012_val_00001509.JPEG +ILSVRC2012_val_00006153.JPEG +ILSVRC2012_val_00033347.JPEG +ILSVRC2012_val_00010816.JPEG +ILSVRC2012_val_00034299.JPEG +ILSVRC2012_val_00008552.JPEG +ILSVRC2012_val_00040402.JPEG +ILSVRC2012_val_00049456.JPEG +ILSVRC2012_val_00013751.JPEG +ILSVRC2012_val_00047416.JPEG +ILSVRC2012_val_00029788.JPEG +ILSVRC2012_val_00028413.JPEG +ILSVRC2012_val_00036004.JPEG +ILSVRC2012_val_00007776.JPEG +ILSVRC2012_val_00031108.JPEG +ILSVRC2012_val_00035106.JPEG +ILSVRC2012_val_00002343.JPEG +ILSVRC2012_val_00012451.JPEG +ILSVRC2012_val_00039523.JPEG +ILSVRC2012_val_00019944.JPEG +ILSVRC2012_val_00044326.JPEG +ILSVRC2012_val_00018930.JPEG +ILSVRC2012_val_00043737.JPEG +ILSVRC2012_val_00002448.JPEG +ILSVRC2012_val_00014788.JPEG +ILSVRC2012_val_00010281.JPEG +ILSVRC2012_val_00000751.JPEG +ILSVRC2012_val_00010517.JPEG +ILSVRC2012_val_00019637.JPEG +ILSVRC2012_val_00015971.JPEG +ILSVRC2012_val_00042766.JPEG +ILSVRC2012_val_00024249.JPEG +ILSVRC2012_val_00038935.JPEG +ILSVRC2012_val_00032219.JPEG +ILSVRC2012_val_00036395.JPEG +ILSVRC2012_val_00003658.JPEG +ILSVRC2012_val_00029942.JPEG +ILSVRC2012_val_00008190.JPEG +ILSVRC2012_val_00024775.JPEG +ILSVRC2012_val_00030703.JPEG +ILSVRC2012_val_00005099.JPEG +ILSVRC2012_val_00020857.JPEG +ILSVRC2012_val_00023536.JPEG +ILSVRC2012_val_00006993.JPEG +ILSVRC2012_val_00020711.JPEG +ILSVRC2012_val_00026389.JPEG +ILSVRC2012_val_00027339.JPEG +ILSVRC2012_val_00020559.JPEG +ILSVRC2012_val_00031420.JPEG +ILSVRC2012_val_00005091.JPEG +ILSVRC2012_val_00007940.JPEG +ILSVRC2012_val_00006140.JPEG 
+ILSVRC2012_val_00030766.JPEG +ILSVRC2012_val_00034501.JPEG +ILSVRC2012_val_00048705.JPEG +ILSVRC2012_val_00034242.JPEG +ILSVRC2012_val_00019363.JPEG +ILSVRC2012_val_00024730.JPEG +ILSVRC2012_val_00024449.JPEG +ILSVRC2012_val_00003640.JPEG +ILSVRC2012_val_00016680.JPEG +ILSVRC2012_val_00013874.JPEG +ILSVRC2012_val_00039519.JPEG +ILSVRC2012_val_00000334.JPEG +ILSVRC2012_val_00001712.JPEG +ILSVRC2012_val_00016798.JPEG +ILSVRC2012_val_00013181.JPEG +ILSVRC2012_val_00012945.JPEG +ILSVRC2012_val_00040486.JPEG +ILSVRC2012_val_00023747.JPEG +ILSVRC2012_val_00045894.JPEG +ILSVRC2012_val_00029329.JPEG +ILSVRC2012_val_00037666.JPEG +ILSVRC2012_val_00025894.JPEG +ILSVRC2012_val_00036845.JPEG +ILSVRC2012_val_00049278.JPEG +ILSVRC2012_val_00044929.JPEG +ILSVRC2012_val_00035223.JPEG +ILSVRC2012_val_00042224.JPEG +ILSVRC2012_val_00049303.JPEG +ILSVRC2012_val_00006542.JPEG +ILSVRC2012_val_00023611.JPEG +ILSVRC2012_val_00020086.JPEG +ILSVRC2012_val_00023838.JPEG +ILSVRC2012_val_00034690.JPEG +ILSVRC2012_val_00019501.JPEG +ILSVRC2012_val_00000223.JPEG +ILSVRC2012_val_00029224.JPEG +ILSVRC2012_val_00033257.JPEG +ILSVRC2012_val_00019668.JPEG +ILSVRC2012_val_00038060.JPEG +ILSVRC2012_val_00047758.JPEG +ILSVRC2012_val_00043018.JPEG +ILSVRC2012_val_00039892.JPEG +ILSVRC2012_val_00015450.JPEG +ILSVRC2012_val_00019142.JPEG +ILSVRC2012_val_00032777.JPEG +ILSVRC2012_val_00037343.JPEG +ILSVRC2012_val_00023684.JPEG +ILSVRC2012_val_00047478.JPEG +ILSVRC2012_val_00014506.JPEG +ILSVRC2012_val_00045845.JPEG +ILSVRC2012_val_00044119.JPEG +ILSVRC2012_val_00034303.JPEG +ILSVRC2012_val_00006151.JPEG +ILSVRC2012_val_00028968.JPEG +ILSVRC2012_val_00043814.JPEG +ILSVRC2012_val_00046041.JPEG +ILSVRC2012_val_00041083.JPEG +ILSVRC2012_val_00031286.JPEG +ILSVRC2012_val_00035820.JPEG +ILSVRC2012_val_00031008.JPEG +ILSVRC2012_val_00012847.JPEG +ILSVRC2012_val_00038212.JPEG +ILSVRC2012_val_00026055.JPEG +ILSVRC2012_val_00035938.JPEG +ILSVRC2012_val_00008178.JPEG +ILSVRC2012_val_00005575.JPEG 
+ILSVRC2012_val_00013464.JPEG +ILSVRC2012_val_00005283.JPEG +ILSVRC2012_val_00006829.JPEG +ILSVRC2012_val_00011650.JPEG +ILSVRC2012_val_00032021.JPEG +ILSVRC2012_val_00016760.JPEG +ILSVRC2012_val_00008031.JPEG +ILSVRC2012_val_00011608.JPEG +ILSVRC2012_val_00004519.JPEG +ILSVRC2012_val_00046367.JPEG +ILSVRC2012_val_00026067.JPEG +ILSVRC2012_val_00023130.JPEG +ILSVRC2012_val_00023223.JPEG +ILSVRC2012_val_00023784.JPEG +ILSVRC2012_val_00036833.JPEG +ILSVRC2012_val_00029053.JPEG +ILSVRC2012_val_00006793.JPEG +ILSVRC2012_val_00042254.JPEG +ILSVRC2012_val_00009356.JPEG +ILSVRC2012_val_00002473.JPEG +ILSVRC2012_val_00027879.JPEG +ILSVRC2012_val_00048718.JPEG +ILSVRC2012_val_00046091.JPEG +ILSVRC2012_val_00015715.JPEG +ILSVRC2012_val_00018794.JPEG +ILSVRC2012_val_00028229.JPEG +ILSVRC2012_val_00019795.JPEG +ILSVRC2012_val_00038865.JPEG +ILSVRC2012_val_00044381.JPEG +ILSVRC2012_val_00016168.JPEG +ILSVRC2012_val_00013124.JPEG +ILSVRC2012_val_00026174.JPEG +ILSVRC2012_val_00037277.JPEG +ILSVRC2012_val_00046141.JPEG +ILSVRC2012_val_00009567.JPEG +ILSVRC2012_val_00034937.JPEG +ILSVRC2012_val_00037458.JPEG +ILSVRC2012_val_00014525.JPEG +ILSVRC2012_val_00040774.JPEG +ILSVRC2012_val_00036673.JPEG +ILSVRC2012_val_00018987.JPEG +ILSVRC2012_val_00026225.JPEG +ILSVRC2012_val_00011020.JPEG +ILSVRC2012_val_00010100.JPEG +ILSVRC2012_val_00023262.JPEG +ILSVRC2012_val_00026315.JPEG +ILSVRC2012_val_00001039.JPEG +ILSVRC2012_val_00037045.JPEG +ILSVRC2012_val_00015111.JPEG +ILSVRC2012_val_00008865.JPEG +ILSVRC2012_val_00036727.JPEG +ILSVRC2012_val_00036268.JPEG +ILSVRC2012_val_00026560.JPEG +ILSVRC2012_val_00017152.JPEG +ILSVRC2012_val_00031172.JPEG +ILSVRC2012_val_00025045.JPEG +ILSVRC2012_val_00031796.JPEG +ILSVRC2012_val_00045860.JPEG +ILSVRC2012_val_00028417.JPEG +ILSVRC2012_val_00002328.JPEG +ILSVRC2012_val_00018363.JPEG +ILSVRC2012_val_00047010.JPEG +ILSVRC2012_val_00014914.JPEG +ILSVRC2012_val_00030772.JPEG +ILSVRC2012_val_00028078.JPEG +ILSVRC2012_val_00028582.JPEG 
+ILSVRC2012_val_00019814.JPEG +ILSVRC2012_val_00037321.JPEG +ILSVRC2012_val_00036551.JPEG +ILSVRC2012_val_00047182.JPEG +ILSVRC2012_val_00014861.JPEG +ILSVRC2012_val_00017749.JPEG +ILSVRC2012_val_00011444.JPEG +ILSVRC2012_val_00010877.JPEG +ILSVRC2012_val_00021048.JPEG +ILSVRC2012_val_00004349.JPEG +ILSVRC2012_val_00004562.JPEG +ILSVRC2012_val_00041682.JPEG +ILSVRC2012_val_00009858.JPEG +ILSVRC2012_val_00040891.JPEG +ILSVRC2012_val_00009410.JPEG +ILSVRC2012_val_00019149.JPEG +ILSVRC2012_val_00018914.JPEG +ILSVRC2012_val_00016638.JPEG +ILSVRC2012_val_00023027.JPEG +ILSVRC2012_val_00016078.JPEG +ILSVRC2012_val_00041613.JPEG +ILSVRC2012_val_00011525.JPEG +ILSVRC2012_val_00005947.JPEG +ILSVRC2012_val_00035823.JPEG +ILSVRC2012_val_00048564.JPEG +ILSVRC2012_val_00032145.JPEG +ILSVRC2012_val_00042045.JPEG +ILSVRC2012_val_00041464.JPEG +ILSVRC2012_val_00021754.JPEG +ILSVRC2012_val_00020084.JPEG +ILSVRC2012_val_00013609.JPEG +ILSVRC2012_val_00028043.JPEG +ILSVRC2012_val_00002434.JPEG +ILSVRC2012_val_00048124.JPEG +ILSVRC2012_val_00035879.JPEG +ILSVRC2012_val_00039557.JPEG +ILSVRC2012_val_00032990.JPEG +ILSVRC2012_val_00038669.JPEG +ILSVRC2012_val_00006209.JPEG +ILSVRC2012_val_00032469.JPEG +ILSVRC2012_val_00028767.JPEG +ILSVRC2012_val_00013266.JPEG +ILSVRC2012_val_00009934.JPEG +ILSVRC2012_val_00010660.JPEG +ILSVRC2012_val_00009626.JPEG +ILSVRC2012_val_00022236.JPEG +ILSVRC2012_val_00011070.JPEG +ILSVRC2012_val_00038151.JPEG +ILSVRC2012_val_00007671.JPEG +ILSVRC2012_val_00017380.JPEG +ILSVRC2012_val_00012628.JPEG +ILSVRC2012_val_00034536.JPEG +ILSVRC2012_val_00004713.JPEG +ILSVRC2012_val_00021248.JPEG +ILSVRC2012_val_00006031.JPEG +ILSVRC2012_val_00021438.JPEG +ILSVRC2012_val_00034986.JPEG +ILSVRC2012_val_00019087.JPEG +ILSVRC2012_val_00003229.JPEG +ILSVRC2012_val_00020956.JPEG +ILSVRC2012_val_00040308.JPEG +ILSVRC2012_val_00011527.JPEG +ILSVRC2012_val_00026993.JPEG +ILSVRC2012_val_00027622.JPEG +ILSVRC2012_val_00032538.JPEG +ILSVRC2012_val_00047716.JPEG 
+ILSVRC2012_val_00010392.JPEG +ILSVRC2012_val_00000675.JPEG +ILSVRC2012_val_00025774.JPEG +ILSVRC2012_val_00017628.JPEG +ILSVRC2012_val_00000847.JPEG +ILSVRC2012_val_00048279.JPEG +ILSVRC2012_val_00015269.JPEG +ILSVRC2012_val_00000600.JPEG +ILSVRC2012_val_00010539.JPEG +ILSVRC2012_val_00016607.JPEG +ILSVRC2012_val_00034486.JPEG +ILSVRC2012_val_00033866.JPEG +ILSVRC2012_val_00005448.JPEG +ILSVRC2012_val_00001035.JPEG +ILSVRC2012_val_00048695.JPEG +ILSVRC2012_val_00027965.JPEG +ILSVRC2012_val_00047558.JPEG +ILSVRC2012_val_00007631.JPEG +ILSVRC2012_val_00031215.JPEG +ILSVRC2012_val_00003562.JPEG +ILSVRC2012_val_00033526.JPEG +ILSVRC2012_val_00036195.JPEG +ILSVRC2012_val_00003901.JPEG +ILSVRC2012_val_00043042.JPEG +ILSVRC2012_val_00046995.JPEG +ILSVRC2012_val_00005584.JPEG +ILSVRC2012_val_00030673.JPEG +ILSVRC2012_val_00041244.JPEG +ILSVRC2012_val_00049573.JPEG +ILSVRC2012_val_00010319.JPEG +ILSVRC2012_val_00027271.JPEG +ILSVRC2012_val_00018801.JPEG +ILSVRC2012_val_00044171.JPEG +ILSVRC2012_val_00037763.JPEG +ILSVRC2012_val_00035903.JPEG +ILSVRC2012_val_00004092.JPEG +ILSVRC2012_val_00001179.JPEG +ILSVRC2012_val_00026821.JPEG +ILSVRC2012_val_00012524.JPEG +ILSVRC2012_val_00037323.JPEG +ILSVRC2012_val_00004400.JPEG +ILSVRC2012_val_00045368.JPEG +ILSVRC2012_val_00018779.JPEG +ILSVRC2012_val_00003785.JPEG +ILSVRC2012_val_00027863.JPEG +ILSVRC2012_val_00046646.JPEG +ILSVRC2012_val_00035531.JPEG +ILSVRC2012_val_00009256.JPEG +ILSVRC2012_val_00003749.JPEG +ILSVRC2012_val_00026433.JPEG +ILSVRC2012_val_00029443.JPEG +ILSVRC2012_val_00013335.JPEG +ILSVRC2012_val_00047582.JPEG +ILSVRC2012_val_00049184.JPEG +ILSVRC2012_val_00009762.JPEG +ILSVRC2012_val_00024750.JPEG +ILSVRC2012_val_00015941.JPEG +ILSVRC2012_val_00036039.JPEG +ILSVRC2012_val_00019174.JPEG +ILSVRC2012_val_00045340.JPEG +ILSVRC2012_val_00006763.JPEG +ILSVRC2012_val_00023467.JPEG +ILSVRC2012_val_00031949.JPEG +ILSVRC2012_val_00045609.JPEG +ILSVRC2012_val_00018347.JPEG +ILSVRC2012_val_00021814.JPEG 
+ILSVRC2012_val_00021683.JPEG +ILSVRC2012_val_00042074.JPEG +ILSVRC2012_val_00003410.JPEG +ILSVRC2012_val_00025424.JPEG +ILSVRC2012_val_00046273.JPEG +ILSVRC2012_val_00035630.JPEG +ILSVRC2012_val_00035027.JPEG +ILSVRC2012_val_00011279.JPEG +ILSVRC2012_val_00000500.JPEG +ILSVRC2012_val_00001458.JPEG +ILSVRC2012_val_00030641.JPEG +ILSVRC2012_val_00038449.JPEG +ILSVRC2012_val_00016563.JPEG +ILSVRC2012_val_00007299.JPEG +ILSVRC2012_val_00001163.JPEG +ILSVRC2012_val_00034337.JPEG +ILSVRC2012_val_00004855.JPEG +ILSVRC2012_val_00044486.JPEG +ILSVRC2012_val_00026008.JPEG +ILSVRC2012_val_00025662.JPEG +ILSVRC2012_val_00033107.JPEG +ILSVRC2012_val_00036493.JPEG +ILSVRC2012_val_00044932.JPEG +ILSVRC2012_val_00040527.JPEG +ILSVRC2012_val_00047982.JPEG +ILSVRC2012_val_00031109.JPEG +ILSVRC2012_val_00018482.JPEG +ILSVRC2012_val_00036646.JPEG +ILSVRC2012_val_00027693.JPEG +ILSVRC2012_val_00043979.JPEG +ILSVRC2012_val_00045173.JPEG +ILSVRC2012_val_00035387.JPEG +ILSVRC2012_val_00011409.JPEG +ILSVRC2012_val_00022215.JPEG +ILSVRC2012_val_00001539.JPEG +ILSVRC2012_val_00030937.JPEG +ILSVRC2012_val_00026213.JPEG +ILSVRC2012_val_00038214.JPEG +ILSVRC2012_val_00013273.JPEG +ILSVRC2012_val_00003377.JPEG +ILSVRC2012_val_00004543.JPEG +ILSVRC2012_val_00023561.JPEG +ILSVRC2012_val_00044028.JPEG +ILSVRC2012_val_00041735.JPEG +ILSVRC2012_val_00032670.JPEG +ILSVRC2012_val_00036429.JPEG +ILSVRC2012_val_00040814.JPEG +ILSVRC2012_val_00039316.JPEG +ILSVRC2012_val_00015768.JPEG +ILSVRC2012_val_00025363.JPEG +ILSVRC2012_val_00041085.JPEG +ILSVRC2012_val_00047255.JPEG +ILSVRC2012_val_00000351.JPEG +ILSVRC2012_val_00029602.JPEG +ILSVRC2012_val_00032837.JPEG +ILSVRC2012_val_00042925.JPEG +ILSVRC2012_val_00028758.JPEG +ILSVRC2012_val_00036342.JPEG +ILSVRC2012_val_00016269.JPEG +ILSVRC2012_val_00033975.JPEG +ILSVRC2012_val_00019988.JPEG +ILSVRC2012_val_00037086.JPEG +ILSVRC2012_val_00022973.JPEG +ILSVRC2012_val_00039615.JPEG +ILSVRC2012_val_00016398.JPEG +ILSVRC2012_val_00041008.JPEG 
+ILSVRC2012_val_00032932.JPEG +ILSVRC2012_val_00023332.JPEG +ILSVRC2012_val_00041730.JPEG +ILSVRC2012_val_00024871.JPEG +ILSVRC2012_val_00010468.JPEG +ILSVRC2012_val_00030892.JPEG +ILSVRC2012_val_00036276.JPEG +ILSVRC2012_val_00035396.JPEG +ILSVRC2012_val_00032130.JPEG +ILSVRC2012_val_00026577.JPEG +ILSVRC2012_val_00024694.JPEG +ILSVRC2012_val_00036816.JPEG +ILSVRC2012_val_00002837.JPEG +ILSVRC2012_val_00001904.JPEG +ILSVRC2012_val_00027628.JPEG +ILSVRC2012_val_00000924.JPEG +ILSVRC2012_val_00001419.JPEG +ILSVRC2012_val_00009305.JPEG +ILSVRC2012_val_00026254.JPEG +ILSVRC2012_val_00022314.JPEG +ILSVRC2012_val_00017475.JPEG +ILSVRC2012_val_00037598.JPEG +ILSVRC2012_val_00047049.JPEG +ILSVRC2012_val_00042656.JPEG +ILSVRC2012_val_00042869.JPEG +ILSVRC2012_val_00000935.JPEG +ILSVRC2012_val_00010682.JPEG +ILSVRC2012_val_00039528.JPEG +ILSVRC2012_val_00028668.JPEG +ILSVRC2012_val_00046110.JPEG +ILSVRC2012_val_00031907.JPEG +ILSVRC2012_val_00016108.JPEG +ILSVRC2012_val_00011887.JPEG +ILSVRC2012_val_00023563.JPEG +ILSVRC2012_val_00032793.JPEG +ILSVRC2012_val_00038253.JPEG +ILSVRC2012_val_00035815.JPEG +ILSVRC2012_val_00021843.JPEG +ILSVRC2012_val_00006199.JPEG +ILSVRC2012_val_00048341.JPEG +ILSVRC2012_val_00033478.JPEG +ILSVRC2012_val_00010965.JPEG +ILSVRC2012_val_00005142.JPEG +ILSVRC2012_val_00039507.JPEG +ILSVRC2012_val_00013694.JPEG +ILSVRC2012_val_00009299.JPEG +ILSVRC2012_val_00014336.JPEG +ILSVRC2012_val_00028277.JPEG +ILSVRC2012_val_00039872.JPEG +ILSVRC2012_val_00022328.JPEG +ILSVRC2012_val_00045682.JPEG +ILSVRC2012_val_00019740.JPEG +ILSVRC2012_val_00021287.JPEG +ILSVRC2012_val_00012303.JPEG +ILSVRC2012_val_00025794.JPEG +ILSVRC2012_val_00048074.JPEG +ILSVRC2012_val_00001604.JPEG +ILSVRC2012_val_00032053.JPEG +ILSVRC2012_val_00028745.JPEG +ILSVRC2012_val_00012804.JPEG +ILSVRC2012_val_00029036.JPEG +ILSVRC2012_val_00043931.JPEG +ILSVRC2012_val_00024114.JPEG +ILSVRC2012_val_00031573.JPEG +ILSVRC2012_val_00020872.JPEG +ILSVRC2012_val_00047379.JPEG 
+ILSVRC2012_val_00023841.JPEG +ILSVRC2012_val_00040693.JPEG +ILSVRC2012_val_00002704.JPEG +ILSVRC2012_val_00003079.JPEG +ILSVRC2012_val_00019481.JPEG +ILSVRC2012_val_00030953.JPEG +ILSVRC2012_val_00009530.JPEG +ILSVRC2012_val_00039955.JPEG +ILSVRC2012_val_00043148.JPEG +ILSVRC2012_val_00044733.JPEG +ILSVRC2012_val_00042102.JPEG +ILSVRC2012_val_00049589.JPEG +ILSVRC2012_val_00001706.JPEG +ILSVRC2012_val_00049355.JPEG +ILSVRC2012_val_00035915.JPEG +ILSVRC2012_val_00037883.JPEG +ILSVRC2012_val_00011328.JPEG +ILSVRC2012_val_00001571.JPEG +ILSVRC2012_val_00029964.JPEG +ILSVRC2012_val_00024329.JPEG +ILSVRC2012_val_00040651.JPEG +ILSVRC2012_val_00009469.JPEG +ILSVRC2012_val_00013113.JPEG +ILSVRC2012_val_00020385.JPEG +ILSVRC2012_val_00020780.JPEG +ILSVRC2012_val_00003322.JPEG +ILSVRC2012_val_00008605.JPEG +ILSVRC2012_val_00049741.JPEG +ILSVRC2012_val_00048017.JPEG +ILSVRC2012_val_00014616.JPEG +ILSVRC2012_val_00009455.JPEG +ILSVRC2012_val_00022418.JPEG +ILSVRC2012_val_00033156.JPEG +ILSVRC2012_val_00025460.JPEG +ILSVRC2012_val_00046219.JPEG +ILSVRC2012_val_00012035.JPEG +ILSVRC2012_val_00036371.JPEG +ILSVRC2012_val_00039433.JPEG +ILSVRC2012_val_00000622.JPEG +ILSVRC2012_val_00008420.JPEG +ILSVRC2012_val_00009002.JPEG +ILSVRC2012_val_00002618.JPEG +ILSVRC2012_val_00040966.JPEG +ILSVRC2012_val_00030207.JPEG +ILSVRC2012_val_00019653.JPEG +ILSVRC2012_val_00005188.JPEG +ILSVRC2012_val_00017463.JPEG +ILSVRC2012_val_00032088.JPEG +ILSVRC2012_val_00001262.JPEG +ILSVRC2012_val_00004712.JPEG +ILSVRC2012_val_00049897.JPEG +ILSVRC2012_val_00038794.JPEG +ILSVRC2012_val_00002607.JPEG +ILSVRC2012_val_00020651.JPEG +ILSVRC2012_val_00022499.JPEG +ILSVRC2012_val_00022042.JPEG +ILSVRC2012_val_00031694.JPEG +ILSVRC2012_val_00011077.JPEG +ILSVRC2012_val_00049712.JPEG +ILSVRC2012_val_00016813.JPEG +ILSVRC2012_val_00029212.JPEG +ILSVRC2012_val_00041773.JPEG +ILSVRC2012_val_00016125.JPEG +ILSVRC2012_val_00046274.JPEG +ILSVRC2012_val_00012379.JPEG +ILSVRC2012_val_00031388.JPEG 
+ILSVRC2012_val_00022261.JPEG +ILSVRC2012_val_00028260.JPEG +ILSVRC2012_val_00047585.JPEG +ILSVRC2012_val_00002991.JPEG +ILSVRC2012_val_00021364.JPEG +ILSVRC2012_val_00012807.JPEG +ILSVRC2012_val_00030373.JPEG +ILSVRC2012_val_00029210.JPEG +ILSVRC2012_val_00031100.JPEG +ILSVRC2012_val_00039568.JPEG +ILSVRC2012_val_00045568.JPEG +ILSVRC2012_val_00013382.JPEG +ILSVRC2012_val_00042321.JPEG +ILSVRC2012_val_00044596.JPEG +ILSVRC2012_val_00010608.JPEG +ILSVRC2012_val_00026368.JPEG +ILSVRC2012_val_00006470.JPEG +ILSVRC2012_val_00042660.JPEG +ILSVRC2012_val_00044767.JPEG +ILSVRC2012_val_00002664.JPEG +ILSVRC2012_val_00021490.JPEG +ILSVRC2012_val_00046857.JPEG +ILSVRC2012_val_00012062.JPEG +ILSVRC2012_val_00020654.JPEG +ILSVRC2012_val_00021989.JPEG +ILSVRC2012_val_00043918.JPEG +ILSVRC2012_val_00042353.JPEG +ILSVRC2012_val_00024088.JPEG +ILSVRC2012_val_00035875.JPEG +ILSVRC2012_val_00022087.JPEG +ILSVRC2012_val_00003072.JPEG +ILSVRC2012_val_00045061.JPEG +ILSVRC2012_val_00000834.JPEG +ILSVRC2012_val_00014571.JPEG +ILSVRC2012_val_00027456.JPEG +ILSVRC2012_val_00011553.JPEG +ILSVRC2012_val_00045384.JPEG +ILSVRC2012_val_00028217.JPEG +ILSVRC2012_val_00023812.JPEG +ILSVRC2012_val_00018841.JPEG +ILSVRC2012_val_00019079.JPEG +ILSVRC2012_val_00048455.JPEG +ILSVRC2012_val_00038239.JPEG +ILSVRC2012_val_00037001.JPEG +ILSVRC2012_val_00035746.JPEG +ILSVRC2012_val_00030834.JPEG +ILSVRC2012_val_00014804.JPEG +ILSVRC2012_val_00036891.JPEG +ILSVRC2012_val_00037200.JPEG +ILSVRC2012_val_00008132.JPEG +ILSVRC2012_val_00042991.JPEG +ILSVRC2012_val_00014875.JPEG +ILSVRC2012_val_00007633.JPEG +ILSVRC2012_val_00020547.JPEG +ILSVRC2012_val_00007445.JPEG +ILSVRC2012_val_00031415.JPEG +ILSVRC2012_val_00044098.JPEG +ILSVRC2012_val_00018082.JPEG +ILSVRC2012_val_00049463.JPEG +ILSVRC2012_val_00020149.JPEG +ILSVRC2012_val_00036513.JPEG +ILSVRC2012_val_00023493.JPEG +ILSVRC2012_val_00030823.JPEG +ILSVRC2012_val_00038001.JPEG +ILSVRC2012_val_00048528.JPEG +ILSVRC2012_val_00002725.JPEG 
+ILSVRC2012_val_00027642.JPEG +ILSVRC2012_val_00014085.JPEG +ILSVRC2012_val_00019604.JPEG +ILSVRC2012_val_00007141.JPEG +ILSVRC2012_val_00019907.JPEG +ILSVRC2012_val_00034719.JPEG +ILSVRC2012_val_00048579.JPEG +ILSVRC2012_val_00017533.JPEG +ILSVRC2012_val_00015367.JPEG +ILSVRC2012_val_00038739.JPEG +ILSVRC2012_val_00036110.JPEG +ILSVRC2012_val_00034088.JPEG +ILSVRC2012_val_00000383.JPEG +ILSVRC2012_val_00015615.JPEG +ILSVRC2012_val_00004131.JPEG +ILSVRC2012_val_00012197.JPEG +ILSVRC2012_val_00023067.JPEG +ILSVRC2012_val_00006905.JPEG +ILSVRC2012_val_00038342.JPEG +ILSVRC2012_val_00046025.JPEG +ILSVRC2012_val_00015782.JPEG +ILSVRC2012_val_00028348.JPEG +ILSVRC2012_val_00019674.JPEG +ILSVRC2012_val_00049874.JPEG +ILSVRC2012_val_00048064.JPEG +ILSVRC2012_val_00041566.JPEG +ILSVRC2012_val_00007810.JPEG +ILSVRC2012_val_00018829.JPEG +ILSVRC2012_val_00025170.JPEG +ILSVRC2012_val_00020285.JPEG +ILSVRC2012_val_00030678.JPEG +ILSVRC2012_val_00028030.JPEG +ILSVRC2012_val_00021448.JPEG +ILSVRC2012_val_00031754.JPEG +ILSVRC2012_val_00043480.JPEG +ILSVRC2012_val_00035728.JPEG +ILSVRC2012_val_00014610.JPEG +ILSVRC2012_val_00019669.JPEG +ILSVRC2012_val_00009550.JPEG +ILSVRC2012_val_00049092.JPEG +ILSVRC2012_val_00029199.JPEG +ILSVRC2012_val_00015099.JPEG +ILSVRC2012_val_00007211.JPEG +ILSVRC2012_val_00039867.JPEG +ILSVRC2012_val_00038112.JPEG +ILSVRC2012_val_00032266.JPEG +ILSVRC2012_val_00032186.JPEG +ILSVRC2012_val_00035049.JPEG +ILSVRC2012_val_00021236.JPEG +ILSVRC2012_val_00049483.JPEG +ILSVRC2012_val_00011410.JPEG +ILSVRC2012_val_00042436.JPEG +ILSVRC2012_val_00049329.JPEG +ILSVRC2012_val_00027024.JPEG +ILSVRC2012_val_00036713.JPEG +ILSVRC2012_val_00035245.JPEG +ILSVRC2012_val_00041316.JPEG +ILSVRC2012_val_00045718.JPEG +ILSVRC2012_val_00003469.JPEG +ILSVRC2012_val_00040459.JPEG +ILSVRC2012_val_00000071.JPEG +ILSVRC2012_val_00036701.JPEG +ILSVRC2012_val_00000867.JPEG +ILSVRC2012_val_00021760.JPEG +ILSVRC2012_val_00001055.JPEG +ILSVRC2012_val_00029126.JPEG 
+ILSVRC2012_val_00015773.JPEG +ILSVRC2012_val_00002963.JPEG +ILSVRC2012_val_00019915.JPEG +ILSVRC2012_val_00047617.JPEG +ILSVRC2012_val_00031451.JPEG +ILSVRC2012_val_00012707.JPEG +ILSVRC2012_val_00016024.JPEG +ILSVRC2012_val_00005678.JPEG +ILSVRC2012_val_00025388.JPEG +ILSVRC2012_val_00039697.JPEG +ILSVRC2012_val_00016332.JPEG +ILSVRC2012_val_00028142.JPEG +ILSVRC2012_val_00040030.JPEG +ILSVRC2012_val_00019580.JPEG +ILSVRC2012_val_00027454.JPEG +ILSVRC2012_val_00034066.JPEG +ILSVRC2012_val_00027981.JPEG +ILSVRC2012_val_00022567.JPEG +ILSVRC2012_val_00036680.JPEG +ILSVRC2012_val_00035922.JPEG +ILSVRC2012_val_00027053.JPEG +ILSVRC2012_val_00017893.JPEG +ILSVRC2012_val_00033029.JPEG +ILSVRC2012_val_00046555.JPEG +ILSVRC2012_val_00003154.JPEG +ILSVRC2012_val_00017131.JPEG +ILSVRC2012_val_00030425.JPEG +ILSVRC2012_val_00025025.JPEG +ILSVRC2012_val_00032058.JPEG +ILSVRC2012_val_00037019.JPEG +ILSVRC2012_val_00004399.JPEG +ILSVRC2012_val_00025996.JPEG +ILSVRC2012_val_00005057.JPEG +ILSVRC2012_val_00037511.JPEG +ILSVRC2012_val_00045065.JPEG +ILSVRC2012_val_00011848.JPEG +ILSVRC2012_val_00026292.JPEG +ILSVRC2012_val_00011777.JPEG +ILSVRC2012_val_00033121.JPEG +ILSVRC2012_val_00045263.JPEG +ILSVRC2012_val_00044859.JPEG +ILSVRC2012_val_00006451.JPEG +ILSVRC2012_val_00016791.JPEG +ILSVRC2012_val_00044133.JPEG +ILSVRC2012_val_00043900.JPEG +ILSVRC2012_val_00008607.JPEG +ILSVRC2012_val_00011304.JPEG +ILSVRC2012_val_00046998.JPEG +ILSVRC2012_val_00001313.JPEG +ILSVRC2012_val_00042215.JPEG +ILSVRC2012_val_00049176.JPEG +ILSVRC2012_val_00029742.JPEG +ILSVRC2012_val_00026306.JPEG +ILSVRC2012_val_00045261.JPEG +ILSVRC2012_val_00002681.JPEG +ILSVRC2012_val_00027620.JPEG +ILSVRC2012_val_00014169.JPEG +ILSVRC2012_val_00020432.JPEG +ILSVRC2012_val_00011857.JPEG +ILSVRC2012_val_00009342.JPEG +ILSVRC2012_val_00033411.JPEG +ILSVRC2012_val_00035107.JPEG +ILSVRC2012_val_00039378.JPEG +ILSVRC2012_val_00048565.JPEG +ILSVRC2012_val_00048366.JPEG +ILSVRC2012_val_00044658.JPEG 
+ILSVRC2012_val_00003432.JPEG +ILSVRC2012_val_00033668.JPEG +ILSVRC2012_val_00047533.JPEG +ILSVRC2012_val_00041405.JPEG +ILSVRC2012_val_00032704.JPEG +ILSVRC2012_val_00028306.JPEG +ILSVRC2012_val_00002067.JPEG +ILSVRC2012_val_00001732.JPEG +ILSVRC2012_val_00014504.JPEG +ILSVRC2012_val_00006460.JPEG +ILSVRC2012_val_00012165.JPEG +ILSVRC2012_val_00042325.JPEG +ILSVRC2012_val_00015914.JPEG +ILSVRC2012_val_00021330.JPEG +ILSVRC2012_val_00035838.JPEG +ILSVRC2012_val_00015016.JPEG +ILSVRC2012_val_00002200.JPEG +ILSVRC2012_val_00033169.JPEG +ILSVRC2012_val_00020349.JPEG +ILSVRC2012_val_00034176.JPEG +ILSVRC2012_val_00001208.JPEG +ILSVRC2012_val_00032447.JPEG +ILSVRC2012_val_00001435.JPEG +ILSVRC2012_val_00031013.JPEG +ILSVRC2012_val_00004663.JPEG +ILSVRC2012_val_00018298.JPEG +ILSVRC2012_val_00003457.JPEG +ILSVRC2012_val_00016873.JPEG +ILSVRC2012_val_00035284.JPEG +ILSVRC2012_val_00021406.JPEG +ILSVRC2012_val_00027268.JPEG +ILSVRC2012_val_00007326.JPEG +ILSVRC2012_val_00006127.JPEG +ILSVRC2012_val_00002627.JPEG +ILSVRC2012_val_00042204.JPEG +ILSVRC2012_val_00044683.JPEG +ILSVRC2012_val_00009126.JPEG +ILSVRC2012_val_00009097.JPEG +ILSVRC2012_val_00017564.JPEG +ILSVRC2012_val_00017874.JPEG +ILSVRC2012_val_00022546.JPEG +ILSVRC2012_val_00048002.JPEG +ILSVRC2012_val_00000430.JPEG +ILSVRC2012_val_00038064.JPEG +ILSVRC2012_val_00031153.JPEG +ILSVRC2012_val_00023714.JPEG +ILSVRC2012_val_00015713.JPEG +ILSVRC2012_val_00009700.JPEG +ILSVRC2012_val_00009531.JPEG +ILSVRC2012_val_00034469.JPEG +ILSVRC2012_val_00026486.JPEG +ILSVRC2012_val_00028276.JPEG +ILSVRC2012_val_00006740.JPEG +ILSVRC2012_val_00023673.JPEG +ILSVRC2012_val_00026878.JPEG +ILSVRC2012_val_00003103.JPEG +ILSVRC2012_val_00015208.JPEG +ILSVRC2012_val_00038754.JPEG +ILSVRC2012_val_00030087.JPEG +ILSVRC2012_val_00047068.JPEG +ILSVRC2012_val_00013548.JPEG +ILSVRC2012_val_00025080.JPEG +ILSVRC2012_val_00029470.JPEG +ILSVRC2012_val_00023402.JPEG +ILSVRC2012_val_00030180.JPEG +ILSVRC2012_val_00034568.JPEG 
+ILSVRC2012_val_00014888.JPEG +ILSVRC2012_val_00022289.JPEG +ILSVRC2012_val_00005251.JPEG +ILSVRC2012_val_00003812.JPEG +ILSVRC2012_val_00000868.JPEG +ILSVRC2012_val_00048933.JPEG +ILSVRC2012_val_00031674.JPEG +ILSVRC2012_val_00010365.JPEG +ILSVRC2012_val_00031993.JPEG +ILSVRC2012_val_00020781.JPEG +ILSVRC2012_val_00039298.JPEG +ILSVRC2012_val_00036412.JPEG +ILSVRC2012_val_00020030.JPEG +ILSVRC2012_val_00026160.JPEG +ILSVRC2012_val_00011768.JPEG +ILSVRC2012_val_00016173.JPEG +ILSVRC2012_val_00019705.JPEG +ILSVRC2012_val_00045805.JPEG +ILSVRC2012_val_00029448.JPEG +ILSVRC2012_val_00001705.JPEG +ILSVRC2012_val_00014382.JPEG +ILSVRC2012_val_00041290.JPEG +ILSVRC2012_val_00019811.JPEG +ILSVRC2012_val_00021011.JPEG +ILSVRC2012_val_00028047.JPEG +ILSVRC2012_val_00022124.JPEG +ILSVRC2012_val_00020740.JPEG +ILSVRC2012_val_00001422.JPEG +ILSVRC2012_val_00001736.JPEG +ILSVRC2012_val_00001998.JPEG +ILSVRC2012_val_00002217.JPEG +ILSVRC2012_val_00033399.JPEG +ILSVRC2012_val_00040951.JPEG +ILSVRC2012_val_00007229.JPEG +ILSVRC2012_val_00020069.JPEG +ILSVRC2012_val_00025229.JPEG +ILSVRC2012_val_00021105.JPEG +ILSVRC2012_val_00044105.JPEG +ILSVRC2012_val_00021084.JPEG +ILSVRC2012_val_00004252.JPEG +ILSVRC2012_val_00040577.JPEG +ILSVRC2012_val_00008202.JPEG +ILSVRC2012_val_00004271.JPEG +ILSVRC2012_val_00045627.JPEG +ILSVRC2012_val_00047543.JPEG +ILSVRC2012_val_00008853.JPEG +ILSVRC2012_val_00017189.JPEG +ILSVRC2012_val_00020499.JPEG +ILSVRC2012_val_00040968.JPEG +ILSVRC2012_val_00017833.JPEG +ILSVRC2012_val_00030580.JPEG +ILSVRC2012_val_00034959.JPEG +ILSVRC2012_val_00030030.JPEG +ILSVRC2012_val_00038021.JPEG +ILSVRC2012_val_00040699.JPEG +ILSVRC2012_val_00025476.JPEG +ILSVRC2012_val_00018615.JPEG +ILSVRC2012_val_00025344.JPEG +ILSVRC2012_val_00035646.JPEG +ILSVRC2012_val_00002642.JPEG +ILSVRC2012_val_00017467.JPEG +ILSVRC2012_val_00017082.JPEG +ILSVRC2012_val_00045560.JPEG +ILSVRC2012_val_00023577.JPEG +ILSVRC2012_val_00019005.JPEG +ILSVRC2012_val_00010139.JPEG 
+ILSVRC2012_val_00034485.JPEG +ILSVRC2012_val_00001849.JPEG +ILSVRC2012_val_00042760.JPEG +ILSVRC2012_val_00028538.JPEG +ILSVRC2012_val_00048610.JPEG +ILSVRC2012_val_00023374.JPEG +ILSVRC2012_val_00016619.JPEG +ILSVRC2012_val_00025456.JPEG +ILSVRC2012_val_00003842.JPEG +ILSVRC2012_val_00026801.JPEG +ILSVRC2012_val_00007623.JPEG +ILSVRC2012_val_00043720.JPEG +ILSVRC2012_val_00026810.JPEG +ILSVRC2012_val_00043997.JPEG +ILSVRC2012_val_00040506.JPEG +ILSVRC2012_val_00026404.JPEG +ILSVRC2012_val_00017070.JPEG +ILSVRC2012_val_00037577.JPEG +ILSVRC2012_val_00005862.JPEG +ILSVRC2012_val_00030483.JPEG +ILSVRC2012_val_00012921.JPEG +ILSVRC2012_val_00038308.JPEG +ILSVRC2012_val_00021110.JPEG +ILSVRC2012_val_00022280.JPEG +ILSVRC2012_val_00038074.JPEG +ILSVRC2012_val_00012038.JPEG +ILSVRC2012_val_00026054.JPEG +ILSVRC2012_val_00010289.JPEG +ILSVRC2012_val_00039273.JPEG +ILSVRC2012_val_00007280.JPEG +ILSVRC2012_val_00007907.JPEG +ILSVRC2012_val_00047924.JPEG +ILSVRC2012_val_00029219.JPEG +ILSVRC2012_val_00027075.JPEG +ILSVRC2012_val_00029370.JPEG +ILSVRC2012_val_00046916.JPEG +ILSVRC2012_val_00019353.JPEG +ILSVRC2012_val_00027316.JPEG +ILSVRC2012_val_00001478.JPEG +ILSVRC2012_val_00010743.JPEG +ILSVRC2012_val_00035904.JPEG +ILSVRC2012_val_00009526.JPEG +ILSVRC2012_val_00049213.JPEG +ILSVRC2012_val_00014884.JPEG +ILSVRC2012_val_00010182.JPEG +ILSVRC2012_val_00029679.JPEG +ILSVRC2012_val_00033177.JPEG +ILSVRC2012_val_00022350.JPEG +ILSVRC2012_val_00040667.JPEG +ILSVRC2012_val_00045476.JPEG +ILSVRC2012_val_00016793.JPEG +ILSVRC2012_val_00043880.JPEG +ILSVRC2012_val_00037562.JPEG +ILSVRC2012_val_00045267.JPEG +ILSVRC2012_val_00043714.JPEG +ILSVRC2012_val_00006534.JPEG +ILSVRC2012_val_00016171.JPEG +ILSVRC2012_val_00035693.JPEG +ILSVRC2012_val_00040010.JPEG +ILSVRC2012_val_00002466.JPEG +ILSVRC2012_val_00036176.JPEG +ILSVRC2012_val_00040031.JPEG +ILSVRC2012_val_00002916.JPEG +ILSVRC2012_val_00045910.JPEG +ILSVRC2012_val_00025480.JPEG +ILSVRC2012_val_00047551.JPEG 
+ILSVRC2012_val_00027461.JPEG +ILSVRC2012_val_00038737.JPEG +ILSVRC2012_val_00049361.JPEG +ILSVRC2012_val_00026019.JPEG +ILSVRC2012_val_00026650.JPEG +ILSVRC2012_val_00045950.JPEG +ILSVRC2012_val_00023893.JPEG +ILSVRC2012_val_00025401.JPEG +ILSVRC2012_val_00047339.JPEG +ILSVRC2012_val_00021419.JPEG +ILSVRC2012_val_00019877.JPEG +ILSVRC2012_val_00019919.JPEG +ILSVRC2012_val_00022965.JPEG +ILSVRC2012_val_00029112.JPEG +ILSVRC2012_val_00019201.JPEG +ILSVRC2012_val_00036554.JPEG +ILSVRC2012_val_00026883.JPEG +ILSVRC2012_val_00001607.JPEG +ILSVRC2012_val_00025758.JPEG +ILSVRC2012_val_00014742.JPEG +ILSVRC2012_val_00012660.JPEG +ILSVRC2012_val_00045264.JPEG +ILSVRC2012_val_00013859.JPEG +ILSVRC2012_val_00039053.JPEG +ILSVRC2012_val_00017397.JPEG +ILSVRC2012_val_00001757.JPEG +ILSVRC2012_val_00010218.JPEG +ILSVRC2012_val_00033092.JPEG +ILSVRC2012_val_00004874.JPEG +ILSVRC2012_val_00039156.JPEG +ILSVRC2012_val_00040207.JPEG +ILSVRC2012_val_00039267.JPEG +ILSVRC2012_val_00000720.JPEG +ILSVRC2012_val_00000551.JPEG +ILSVRC2012_val_00006291.JPEG +ILSVRC2012_val_00021894.JPEG +ILSVRC2012_val_00030125.JPEG +ILSVRC2012_val_00003860.JPEG +ILSVRC2012_val_00028883.JPEG +ILSVRC2012_val_00041331.JPEG +ILSVRC2012_val_00034495.JPEG +ILSVRC2012_val_00046778.JPEG +ILSVRC2012_val_00020374.JPEG +ILSVRC2012_val_00039110.JPEG +ILSVRC2012_val_00012537.JPEG +ILSVRC2012_val_00025400.JPEG +ILSVRC2012_val_00005037.JPEG +ILSVRC2012_val_00035958.JPEG +ILSVRC2012_val_00037035.JPEG +ILSVRC2012_val_00036779.JPEG +ILSVRC2012_val_00013777.JPEG +ILSVRC2012_val_00049880.JPEG +ILSVRC2012_val_00040228.JPEG +ILSVRC2012_val_00020453.JPEG +ILSVRC2012_val_00027223.JPEG +ILSVRC2012_val_00025890.JPEG +ILSVRC2012_val_00026552.JPEG +ILSVRC2012_val_00031435.JPEG +ILSVRC2012_val_00037895.JPEG +ILSVRC2012_val_00045272.JPEG +ILSVRC2012_val_00030549.JPEG +ILSVRC2012_val_00033015.JPEG +ILSVRC2012_val_00010567.JPEG +ILSVRC2012_val_00009390.JPEG +ILSVRC2012_val_00026309.JPEG +ILSVRC2012_val_00005315.JPEG 
+ILSVRC2012_val_00027421.JPEG +ILSVRC2012_val_00006401.JPEG +ILSVRC2012_val_00039771.JPEG +ILSVRC2012_val_00001281.JPEG +ILSVRC2012_val_00033007.JPEG +ILSVRC2012_val_00002534.JPEG +ILSVRC2012_val_00013566.JPEG +ILSVRC2012_val_00044082.JPEG +ILSVRC2012_val_00021031.JPEG +ILSVRC2012_val_00021806.JPEG +ILSVRC2012_val_00006411.JPEG +ILSVRC2012_val_00025923.JPEG +ILSVRC2012_val_00041258.JPEG +ILSVRC2012_val_00036558.JPEG +ILSVRC2012_val_00012661.JPEG +ILSVRC2012_val_00017350.JPEG +ILSVRC2012_val_00043580.JPEG +ILSVRC2012_val_00017933.JPEG +ILSVRC2012_val_00029903.JPEG +ILSVRC2012_val_00012530.JPEG +ILSVRC2012_val_00006781.JPEG +ILSVRC2012_val_00038516.JPEG +ILSVRC2012_val_00041992.JPEG +ILSVRC2012_val_00000594.JPEG +ILSVRC2012_val_00049079.JPEG +ILSVRC2012_val_00008421.JPEG +ILSVRC2012_val_00030667.JPEG +ILSVRC2012_val_00038688.JPEG +ILSVRC2012_val_00043084.JPEG +ILSVRC2012_val_00020813.JPEG +ILSVRC2012_val_00040325.JPEG +ILSVRC2012_val_00005966.JPEG +ILSVRC2012_val_00017505.JPEG +ILSVRC2012_val_00022558.JPEG +ILSVRC2012_val_00041355.JPEG +ILSVRC2012_val_00004531.JPEG +ILSVRC2012_val_00047320.JPEG +ILSVRC2012_val_00046089.JPEG +ILSVRC2012_val_00037709.JPEG +ILSVRC2012_val_00002583.JPEG +ILSVRC2012_val_00020620.JPEG +ILSVRC2012_val_00026128.JPEG +ILSVRC2012_val_00008415.JPEG +ILSVRC2012_val_00030194.JPEG +ILSVRC2012_val_00032709.JPEG +ILSVRC2012_val_00008418.JPEG +ILSVRC2012_val_00043840.JPEG +ILSVRC2012_val_00042443.JPEG +ILSVRC2012_val_00046549.JPEG +ILSVRC2012_val_00042280.JPEG +ILSVRC2012_val_00043235.JPEG +ILSVRC2012_val_00020707.JPEG +ILSVRC2012_val_00026248.JPEG +ILSVRC2012_val_00034379.JPEG +ILSVRC2012_val_00019165.JPEG +ILSVRC2012_val_00038083.JPEG +ILSVRC2012_val_00002654.JPEG +ILSVRC2012_val_00007448.JPEG +ILSVRC2012_val_00043743.JPEG +ILSVRC2012_val_00013160.JPEG +ILSVRC2012_val_00043012.JPEG +ILSVRC2012_val_00019424.JPEG +ILSVRC2012_val_00008666.JPEG +ILSVRC2012_val_00016098.JPEG +ILSVRC2012_val_00031671.JPEG +ILSVRC2012_val_00004675.JPEG 
+ILSVRC2012_val_00003936.JPEG +ILSVRC2012_val_00000194.JPEG +ILSVRC2012_val_00044702.JPEG +ILSVRC2012_val_00005834.JPEG +ILSVRC2012_val_00024007.JPEG +ILSVRC2012_val_00025213.JPEG +ILSVRC2012_val_00021774.JPEG +ILSVRC2012_val_00036894.JPEG +ILSVRC2012_val_00008825.JPEG +ILSVRC2012_val_00004879.JPEG +ILSVRC2012_val_00016237.JPEG +ILSVRC2012_val_00034479.JPEG +ILSVRC2012_val_00016665.JPEG +ILSVRC2012_val_00049677.JPEG +ILSVRC2012_val_00032674.JPEG +ILSVRC2012_val_00003794.JPEG +ILSVRC2012_val_00038431.JPEG +ILSVRC2012_val_00034188.JPEG +ILSVRC2012_val_00039175.JPEG +ILSVRC2012_val_00016675.JPEG +ILSVRC2012_val_00017888.JPEG +ILSVRC2012_val_00004296.JPEG +ILSVRC2012_val_00003572.JPEG +ILSVRC2012_val_00046960.JPEG +ILSVRC2012_val_00006079.JPEG +ILSVRC2012_val_00047639.JPEG +ILSVRC2012_val_00030336.JPEG +ILSVRC2012_val_00034735.JPEG +ILSVRC2012_val_00021144.JPEG +ILSVRC2012_val_00010940.JPEG +ILSVRC2012_val_00029580.JPEG +ILSVRC2012_val_00004398.JPEG +ILSVRC2012_val_00002347.JPEG +ILSVRC2012_val_00014998.JPEG +ILSVRC2012_val_00010066.JPEG +ILSVRC2012_val_00033055.JPEG +ILSVRC2012_val_00001678.JPEG +ILSVRC2012_val_00024860.JPEG +ILSVRC2012_val_00046183.JPEG +ILSVRC2012_val_00026106.JPEG +ILSVRC2012_val_00015582.JPEG +ILSVRC2012_val_00000621.JPEG +ILSVRC2012_val_00015928.JPEG +ILSVRC2012_val_00014060.JPEG +ILSVRC2012_val_00024950.JPEG +ILSVRC2012_val_00048990.JPEG +ILSVRC2012_val_00026611.JPEG +ILSVRC2012_val_00003260.JPEG +ILSVRC2012_val_00013866.JPEG +ILSVRC2012_val_00006038.JPEG +ILSVRC2012_val_00013515.JPEG +ILSVRC2012_val_00045316.JPEG +ILSVRC2012_val_00003838.JPEG +ILSVRC2012_val_00002834.JPEG +ILSVRC2012_val_00025847.JPEG +ILSVRC2012_val_00042694.JPEG +ILSVRC2012_val_00003518.JPEG +ILSVRC2012_val_00019492.JPEG +ILSVRC2012_val_00009006.JPEG +ILSVRC2012_val_00032160.JPEG +ILSVRC2012_val_00010635.JPEG +ILSVRC2012_val_00017401.JPEG +ILSVRC2012_val_00008028.JPEG +ILSVRC2012_val_00038388.JPEG +ILSVRC2012_val_00000671.JPEG +ILSVRC2012_val_00009414.JPEG 
+ILSVRC2012_val_00046345.JPEG +ILSVRC2012_val_00048304.JPEG +ILSVRC2012_val_00023359.JPEG +ILSVRC2012_val_00025655.JPEG +ILSVRC2012_val_00014383.JPEG +ILSVRC2012_val_00015905.JPEG +ILSVRC2012_val_00046779.JPEG +ILSVRC2012_val_00022320.JPEG +ILSVRC2012_val_00031960.JPEG +ILSVRC2012_val_00016472.JPEG +ILSVRC2012_val_00042744.JPEG +ILSVRC2012_val_00035859.JPEG +ILSVRC2012_val_00033921.JPEG +ILSVRC2012_val_00009012.JPEG +ILSVRC2012_val_00037473.JPEG +ILSVRC2012_val_00032533.JPEG +ILSVRC2012_val_00030674.JPEG +ILSVRC2012_val_00034402.JPEG +ILSVRC2012_val_00011109.JPEG +ILSVRC2012_val_00028035.JPEG +ILSVRC2012_val_00017985.JPEG +ILSVRC2012_val_00017906.JPEG +ILSVRC2012_val_00026599.JPEG +ILSVRC2012_val_00017673.JPEG +ILSVRC2012_val_00044336.JPEG +ILSVRC2012_val_00013324.JPEG +ILSVRC2012_val_00005466.JPEG +ILSVRC2012_val_00047994.JPEG +ILSVRC2012_val_00047568.JPEG +ILSVRC2012_val_00037362.JPEG +ILSVRC2012_val_00038652.JPEG +ILSVRC2012_val_00040387.JPEG +ILSVRC2012_val_00046350.JPEG +ILSVRC2012_val_00027722.JPEG +ILSVRC2012_val_00048786.JPEG +ILSVRC2012_val_00045722.JPEG +ILSVRC2012_val_00047577.JPEG +ILSVRC2012_val_00049085.JPEG +ILSVRC2012_val_00035252.JPEG +ILSVRC2012_val_00030166.JPEG +ILSVRC2012_val_00026399.JPEG +ILSVRC2012_val_00006970.JPEG +ILSVRC2012_val_00026914.JPEG +ILSVRC2012_val_00007690.JPEG +ILSVRC2012_val_00045050.JPEG +ILSVRC2012_val_00043799.JPEG +ILSVRC2012_val_00042292.JPEG +ILSVRC2012_val_00045218.JPEG +ILSVRC2012_val_00023575.JPEG +ILSVRC2012_val_00035035.JPEG +ILSVRC2012_val_00022705.JPEG +ILSVRC2012_val_00028152.JPEG +ILSVRC2012_val_00042033.JPEG +ILSVRC2012_val_00016307.JPEG +ILSVRC2012_val_00028725.JPEG +ILSVRC2012_val_00032614.JPEG +ILSVRC2012_val_00017839.JPEG +ILSVRC2012_val_00041344.JPEG +ILSVRC2012_val_00012246.JPEG +ILSVRC2012_val_00046579.JPEG +ILSVRC2012_val_00014895.JPEG +ILSVRC2012_val_00026676.JPEG +ILSVRC2012_val_00038291.JPEG +ILSVRC2012_val_00023189.JPEG +ILSVRC2012_val_00001454.JPEG +ILSVRC2012_val_00000276.JPEG 
+ILSVRC2012_val_00047684.JPEG +ILSVRC2012_val_00046599.JPEG +ILSVRC2012_val_00017271.JPEG +ILSVRC2012_val_00030931.JPEG +ILSVRC2012_val_00013587.JPEG +ILSVRC2012_val_00031797.JPEG +ILSVRC2012_val_00028017.JPEG +ILSVRC2012_val_00008566.JPEG +ILSVRC2012_val_00040621.JPEG +ILSVRC2012_val_00001665.JPEG +ILSVRC2012_val_00030461.JPEG +ILSVRC2012_val_00027077.JPEG +ILSVRC2012_val_00039362.JPEG +ILSVRC2012_val_00011856.JPEG +ILSVRC2012_val_00042968.JPEG +ILSVRC2012_val_00038675.JPEG +ILSVRC2012_val_00039919.JPEG +ILSVRC2012_val_00015036.JPEG +ILSVRC2012_val_00012250.JPEG +ILSVRC2012_val_00018173.JPEG +ILSVRC2012_val_00006377.JPEG +ILSVRC2012_val_00047209.JPEG +ILSVRC2012_val_00043311.JPEG +ILSVRC2012_val_00028927.JPEG +ILSVRC2012_val_00032962.JPEG +ILSVRC2012_val_00035117.JPEG +ILSVRC2012_val_00041535.JPEG +ILSVRC2012_val_00045592.JPEG +ILSVRC2012_val_00003652.JPEG +ILSVRC2012_val_00044652.JPEG +ILSVRC2012_val_00037758.JPEG +ILSVRC2012_val_00020475.JPEG +ILSVRC2012_val_00017373.JPEG +ILSVRC2012_val_00031775.JPEG +ILSVRC2012_val_00005359.JPEG +ILSVRC2012_val_00032372.JPEG +ILSVRC2012_val_00000255.JPEG +ILSVRC2012_val_00007995.JPEG +ILSVRC2012_val_00035178.JPEG +ILSVRC2012_val_00018053.JPEG +ILSVRC2012_val_00043328.JPEG +ILSVRC2012_val_00042974.JPEG +ILSVRC2012_val_00024739.JPEG +ILSVRC2012_val_00047100.JPEG +ILSVRC2012_val_00027687.JPEG +ILSVRC2012_val_00042975.JPEG +ILSVRC2012_val_00049038.JPEG +ILSVRC2012_val_00048835.JPEG +ILSVRC2012_val_00011798.JPEG +ILSVRC2012_val_00005218.JPEG +ILSVRC2012_val_00039206.JPEG +ILSVRC2012_val_00021661.JPEG +ILSVRC2012_val_00010998.JPEG +ILSVRC2012_val_00010424.JPEG +ILSVRC2012_val_00027914.JPEG +ILSVRC2012_val_00035868.JPEG +ILSVRC2012_val_00030234.JPEG +ILSVRC2012_val_00007525.JPEG +ILSVRC2012_val_00039990.JPEG +ILSVRC2012_val_00019432.JPEG +ILSVRC2012_val_00011809.JPEG +ILSVRC2012_val_00002160.JPEG +ILSVRC2012_val_00011418.JPEG +ILSVRC2012_val_00011183.JPEG +ILSVRC2012_val_00028361.JPEG +ILSVRC2012_val_00047709.JPEG 
+ILSVRC2012_val_00027364.JPEG +ILSVRC2012_val_00041402.JPEG +ILSVRC2012_val_00040146.JPEG +ILSVRC2012_val_00004000.JPEG +ILSVRC2012_val_00021243.JPEG +ILSVRC2012_val_00047708.JPEG +ILSVRC2012_val_00029821.JPEG +ILSVRC2012_val_00043829.JPEG +ILSVRC2012_val_00049097.JPEG +ILSVRC2012_val_00048374.JPEG +ILSVRC2012_val_00040282.JPEG +ILSVRC2012_val_00018394.JPEG +ILSVRC2012_val_00017128.JPEG +ILSVRC2012_val_00032314.JPEG +ILSVRC2012_val_00046643.JPEG +ILSVRC2012_val_00000213.JPEG +ILSVRC2012_val_00020659.JPEG +ILSVRC2012_val_00048877.JPEG +ILSVRC2012_val_00032803.JPEG +ILSVRC2012_val_00036038.JPEG +ILSVRC2012_val_00035358.JPEG +ILSVRC2012_val_00023956.JPEG +ILSVRC2012_val_00010758.JPEG +ILSVRC2012_val_00045388.JPEG +ILSVRC2012_val_00032973.JPEG +ILSVRC2012_val_00014507.JPEG +ILSVRC2012_val_00035662.JPEG +ILSVRC2012_val_00012136.JPEG +ILSVRC2012_val_00006427.JPEG +ILSVRC2012_val_00005442.JPEG +ILSVRC2012_val_00022651.JPEG +ILSVRC2012_val_00018242.JPEG +ILSVRC2012_val_00012774.JPEG +ILSVRC2012_val_00045432.JPEG +ILSVRC2012_val_00001553.JPEG +ILSVRC2012_val_00034265.JPEG +ILSVRC2012_val_00032038.JPEG +ILSVRC2012_val_00033245.JPEG +ILSVRC2012_val_00009083.JPEG +ILSVRC2012_val_00018547.JPEG +ILSVRC2012_val_00046571.JPEG +ILSVRC2012_val_00012048.JPEG +ILSVRC2012_val_00000564.JPEG +ILSVRC2012_val_00026113.JPEG +ILSVRC2012_val_00043243.JPEG +ILSVRC2012_val_00004632.JPEG +ILSVRC2012_val_00017354.JPEG +ILSVRC2012_val_00044335.JPEG +ILSVRC2012_val_00006712.JPEG +ILSVRC2012_val_00017983.JPEG +ILSVRC2012_val_00021446.JPEG +ILSVRC2012_val_00001672.JPEG +ILSVRC2012_val_00044417.JPEG +ILSVRC2012_val_00009933.JPEG +ILSVRC2012_val_00021202.JPEG +ILSVRC2012_val_00003691.JPEG +ILSVRC2012_val_00018102.JPEG +ILSVRC2012_val_00022554.JPEG +ILSVRC2012_val_00029446.JPEG +ILSVRC2012_val_00004769.JPEG +ILSVRC2012_val_00027482.JPEG +ILSVRC2012_val_00015427.JPEG +ILSVRC2012_val_00029045.JPEG +ILSVRC2012_val_00014228.JPEG +ILSVRC2012_val_00014591.JPEG +ILSVRC2012_val_00036717.JPEG 
+ILSVRC2012_val_00000439.JPEG +ILSVRC2012_val_00025109.JPEG +ILSVRC2012_val_00044018.JPEG +ILSVRC2012_val_00020427.JPEG +ILSVRC2012_val_00037835.JPEG +ILSVRC2012_val_00031746.JPEG +ILSVRC2012_val_00014933.JPEG +ILSVRC2012_val_00045707.JPEG +ILSVRC2012_val_00002314.JPEG +ILSVRC2012_val_00041680.JPEG +ILSVRC2012_val_00013424.JPEG +ILSVRC2012_val_00025297.JPEG +ILSVRC2012_val_00013611.JPEG +ILSVRC2012_val_00002789.JPEG +ILSVRC2012_val_00041296.JPEG +ILSVRC2012_val_00040625.JPEG +ILSVRC2012_val_00048667.JPEG +ILSVRC2012_val_00009279.JPEG +ILSVRC2012_val_00032993.JPEG +ILSVRC2012_val_00004567.JPEG +ILSVRC2012_val_00026273.JPEG +ILSVRC2012_val_00017049.JPEG +ILSVRC2012_val_00023833.JPEG +ILSVRC2012_val_00001574.JPEG +ILSVRC2012_val_00005337.JPEG +ILSVRC2012_val_00015965.JPEG +ILSVRC2012_val_00043666.JPEG +ILSVRC2012_val_00040427.JPEG +ILSVRC2012_val_00006045.JPEG +ILSVRC2012_val_00003892.JPEG +ILSVRC2012_val_00027862.JPEG +ILSVRC2012_val_00049596.JPEG +ILSVRC2012_val_00027869.JPEG +ILSVRC2012_val_00013179.JPEG +ILSVRC2012_val_00032104.JPEG +ILSVRC2012_val_00022515.JPEG +ILSVRC2012_val_00044784.JPEG +ILSVRC2012_val_00030733.JPEG +ILSVRC2012_val_00040417.JPEG +ILSVRC2012_val_00033688.JPEG +ILSVRC2012_val_00006239.JPEG +ILSVRC2012_val_00001064.JPEG +ILSVRC2012_val_00034682.JPEG +ILSVRC2012_val_00005867.JPEG +ILSVRC2012_val_00026177.JPEG +ILSVRC2012_val_00040728.JPEG +ILSVRC2012_val_00037630.JPEG +ILSVRC2012_val_00035657.JPEG +ILSVRC2012_val_00006187.JPEG +ILSVRC2012_val_00024190.JPEG +ILSVRC2012_val_00031854.JPEG +ILSVRC2012_val_00004918.JPEG +ILSVRC2012_val_00016689.JPEG +ILSVRC2012_val_00002844.JPEG +ILSVRC2012_val_00007431.JPEG +ILSVRC2012_val_00038746.JPEG +ILSVRC2012_val_00046715.JPEG +ILSVRC2012_val_00016759.JPEG +ILSVRC2012_val_00015565.JPEG +ILSVRC2012_val_00041383.JPEG +ILSVRC2012_val_00036350.JPEG +ILSVRC2012_val_00034106.JPEG +ILSVRC2012_val_00000289.JPEG +ILSVRC2012_val_00017644.JPEG +ILSVRC2012_val_00025037.JPEG +ILSVRC2012_val_00022503.JPEG 
+ILSVRC2012_val_00004060.JPEG +ILSVRC2012_val_00006701.JPEG +ILSVRC2012_val_00009471.JPEG +ILSVRC2012_val_00031118.JPEG +ILSVRC2012_val_00009752.JPEG +ILSVRC2012_val_00049992.JPEG +ILSVRC2012_val_00015811.JPEG +ILSVRC2012_val_00034812.JPEG +ILSVRC2012_val_00004392.JPEG +ILSVRC2012_val_00005790.JPEG +ILSVRC2012_val_00040407.JPEG +ILSVRC2012_val_00038068.JPEG +ILSVRC2012_val_00015633.JPEG +ILSVRC2012_val_00048827.JPEG +ILSVRC2012_val_00033671.JPEG +ILSVRC2012_val_00010174.JPEG +ILSVRC2012_val_00001016.JPEG +ILSVRC2012_val_00016135.JPEG +ILSVRC2012_val_00044371.JPEG +ILSVRC2012_val_00022935.JPEG +ILSVRC2012_val_00013595.JPEG +ILSVRC2012_val_00047676.JPEG +ILSVRC2012_val_00014916.JPEG +ILSVRC2012_val_00004871.JPEG +ILSVRC2012_val_00049139.JPEG +ILSVRC2012_val_00001752.JPEG +ILSVRC2012_val_00013028.JPEG +ILSVRC2012_val_00021341.JPEG +ILSVRC2012_val_00027188.JPEG +ILSVRC2012_val_00038999.JPEG +ILSVRC2012_val_00039718.JPEG +ILSVRC2012_val_00025069.JPEG +ILSVRC2012_val_00049197.JPEG +ILSVRC2012_val_00040988.JPEG +ILSVRC2012_val_00035813.JPEG +ILSVRC2012_val_00048531.JPEG +ILSVRC2012_val_00035745.JPEG +ILSVRC2012_val_00009786.JPEG +ILSVRC2012_val_00043052.JPEG +ILSVRC2012_val_00004040.JPEG +ILSVRC2012_val_00033388.JPEG +ILSVRC2012_val_00042753.JPEG +ILSVRC2012_val_00036285.JPEG +ILSVRC2012_val_00004709.JPEG +ILSVRC2012_val_00026637.JPEG +ILSVRC2012_val_00025751.JPEG +ILSVRC2012_val_00023488.JPEG +ILSVRC2012_val_00004532.JPEG +ILSVRC2012_val_00042832.JPEG +ILSVRC2012_val_00031785.JPEG +ILSVRC2012_val_00022226.JPEG +ILSVRC2012_val_00001518.JPEG +ILSVRC2012_val_00029511.JPEG +ILSVRC2012_val_00016597.JPEG +ILSVRC2012_val_00006559.JPEG +ILSVRC2012_val_00008652.JPEG +ILSVRC2012_val_00014625.JPEG +ILSVRC2012_val_00001658.JPEG +ILSVRC2012_val_00021524.JPEG +ILSVRC2012_val_00035343.JPEG +ILSVRC2012_val_00038306.JPEG +ILSVRC2012_val_00013614.JPEG +ILSVRC2012_val_00003990.JPEG +ILSVRC2012_val_00020416.JPEG +ILSVRC2012_val_00006988.JPEG +ILSVRC2012_val_00043772.JPEG 
+ILSVRC2012_val_00022426.JPEG +ILSVRC2012_val_00004677.JPEG +ILSVRC2012_val_00003683.JPEG +ILSVRC2012_val_00019070.JPEG +ILSVRC2012_val_00043191.JPEG +ILSVRC2012_val_00039084.JPEG +ILSVRC2012_val_00037669.JPEG +ILSVRC2012_val_00030000.JPEG +ILSVRC2012_val_00006381.JPEG +ILSVRC2012_val_00007617.JPEG +ILSVRC2012_val_00026265.JPEG +ILSVRC2012_val_00044063.JPEG +ILSVRC2012_val_00037232.JPEG +ILSVRC2012_val_00040825.JPEG +ILSVRC2012_val_00007332.JPEG +ILSVRC2012_val_00027074.JPEG +ILSVRC2012_val_00027254.JPEG +ILSVRC2012_val_00002564.JPEG +ILSVRC2012_val_00018788.JPEG +ILSVRC2012_val_00037879.JPEG +ILSVRC2012_val_00033279.JPEG +ILSVRC2012_val_00037225.JPEG +ILSVRC2012_val_00045702.JPEG +ILSVRC2012_val_00014716.JPEG +ILSVRC2012_val_00027125.JPEG +ILSVRC2012_val_00041641.JPEG +ILSVRC2012_val_00030083.JPEG +ILSVRC2012_val_00029526.JPEG +ILSVRC2012_val_00026510.JPEG +ILSVRC2012_val_00041594.JPEG +ILSVRC2012_val_00048818.JPEG +ILSVRC2012_val_00025782.JPEG +ILSVRC2012_val_00011782.JPEG +ILSVRC2012_val_00014813.JPEG +ILSVRC2012_val_00028978.JPEG +ILSVRC2012_val_00008749.JPEG +ILSVRC2012_val_00005626.JPEG +ILSVRC2012_val_00049711.JPEG +ILSVRC2012_val_00035437.JPEG +ILSVRC2012_val_00013014.JPEG +ILSVRC2012_val_00045195.JPEG +ILSVRC2012_val_00017910.JPEG +ILSVRC2012_val_00033312.JPEG +ILSVRC2012_val_00034998.JPEG +ILSVRC2012_val_00021047.JPEG +ILSVRC2012_val_00022800.JPEG +ILSVRC2012_val_00025860.JPEG +ILSVRC2012_val_00014871.JPEG +ILSVRC2012_val_00042687.JPEG +ILSVRC2012_val_00002551.JPEG +ILSVRC2012_val_00035589.JPEG +ILSVRC2012_val_00043424.JPEG +ILSVRC2012_val_00016795.JPEG +ILSVRC2012_val_00025180.JPEG +ILSVRC2012_val_00020809.JPEG +ILSVRC2012_val_00014211.JPEG +ILSVRC2012_val_00035564.JPEG +ILSVRC2012_val_00019911.JPEG +ILSVRC2012_val_00009397.JPEG +ILSVRC2012_val_00037751.JPEG +ILSVRC2012_val_00024839.JPEG +ILSVRC2012_val_00008287.JPEG +ILSVRC2012_val_00020500.JPEG +ILSVRC2012_val_00014904.JPEG +ILSVRC2012_val_00043034.JPEG +ILSVRC2012_val_00027911.JPEG 
+ILSVRC2012_val_00019810.JPEG +ILSVRC2012_val_00026927.JPEG +ILSVRC2012_val_00023173.JPEG +ILSVRC2012_val_00037953.JPEG +ILSVRC2012_val_00025347.JPEG +ILSVRC2012_val_00003298.JPEG +ILSVRC2012_val_00033487.JPEG +ILSVRC2012_val_00033523.JPEG +ILSVRC2012_val_00047736.JPEG +ILSVRC2012_val_00020277.JPEG +ILSVRC2012_val_00041902.JPEG +ILSVRC2012_val_00017251.JPEG +ILSVRC2012_val_00002990.JPEG +ILSVRC2012_val_00043495.JPEG +ILSVRC2012_val_00045231.JPEG +ILSVRC2012_val_00034918.JPEG +ILSVRC2012_val_00038915.JPEG +ILSVRC2012_val_00019257.JPEG +ILSVRC2012_val_00020679.JPEG +ILSVRC2012_val_00003148.JPEG +ILSVRC2012_val_00012503.JPEG +ILSVRC2012_val_00023985.JPEG +ILSVRC2012_val_00020201.JPEG +ILSVRC2012_val_00016819.JPEG +ILSVRC2012_val_00007355.JPEG +ILSVRC2012_val_00024212.JPEG +ILSVRC2012_val_00005145.JPEG +ILSVRC2012_val_00036052.JPEG +ILSVRC2012_val_00048624.JPEG +ILSVRC2012_val_00049889.JPEG +ILSVRC2012_val_00039604.JPEG +ILSVRC2012_val_00032448.JPEG +ILSVRC2012_val_00039880.JPEG +ILSVRC2012_val_00001248.JPEG +ILSVRC2012_val_00023485.JPEG +ILSVRC2012_val_00039962.JPEG +ILSVRC2012_val_00010526.JPEG +ILSVRC2012_val_00025127.JPEG +ILSVRC2012_val_00018624.JPEG +ILSVRC2012_val_00001254.JPEG +ILSVRC2012_val_00013662.JPEG +ILSVRC2012_val_00021299.JPEG +ILSVRC2012_val_00048360.JPEG +ILSVRC2012_val_00007406.JPEG +ILSVRC2012_val_00001622.JPEG +ILSVRC2012_val_00001006.JPEG +ILSVRC2012_val_00019881.JPEG +ILSVRC2012_val_00038584.JPEG +ILSVRC2012_val_00044638.JPEG +ILSVRC2012_val_00017153.JPEG +ILSVRC2012_val_00044671.JPEG +ILSVRC2012_val_00039900.JPEG +ILSVRC2012_val_00020063.JPEG +ILSVRC2012_val_00021454.JPEG +ILSVRC2012_val_00033888.JPEG +ILSVRC2012_val_00018879.JPEG +ILSVRC2012_val_00030299.JPEG +ILSVRC2012_val_00008959.JPEG +ILSVRC2012_val_00005719.JPEG +ILSVRC2012_val_00030348.JPEG +ILSVRC2012_val_00023104.JPEG +ILSVRC2012_val_00036294.JPEG +ILSVRC2012_val_00049163.JPEG +ILSVRC2012_val_00013893.JPEG +ILSVRC2012_val_00035067.JPEG +ILSVRC2012_val_00026058.JPEG 
+ILSVRC2012_val_00048042.JPEG +ILSVRC2012_val_00010671.JPEG +ILSVRC2012_val_00040704.JPEG +ILSVRC2012_val_00039148.JPEG +ILSVRC2012_val_00012153.JPEG +ILSVRC2012_val_00045698.JPEG +ILSVRC2012_val_00028805.JPEG +ILSVRC2012_val_00042067.JPEG +ILSVRC2012_val_00005502.JPEG +ILSVRC2012_val_00001292.JPEG +ILSVRC2012_val_00039874.JPEG +ILSVRC2012_val_00015306.JPEG +ILSVRC2012_val_00042062.JPEG +ILSVRC2012_val_00045812.JPEG +ILSVRC2012_val_00042545.JPEG +ILSVRC2012_val_00021072.JPEG +ILSVRC2012_val_00030936.JPEG +ILSVRC2012_val_00031038.JPEG +ILSVRC2012_val_00035083.JPEG +ILSVRC2012_val_00040079.JPEG +ILSVRC2012_val_00035612.JPEG +ILSVRC2012_val_00047826.JPEG +ILSVRC2012_val_00029259.JPEG +ILSVRC2012_val_00030367.JPEG +ILSVRC2012_val_00032702.JPEG +ILSVRC2012_val_00011919.JPEG +ILSVRC2012_val_00020365.JPEG +ILSVRC2012_val_00047314.JPEG +ILSVRC2012_val_00029545.JPEG +ILSVRC2012_val_00022976.JPEG +ILSVRC2012_val_00009341.JPEG +ILSVRC2012_val_00000031.JPEG +ILSVRC2012_val_00048868.JPEG +ILSVRC2012_val_00025348.JPEG +ILSVRC2012_val_00018848.JPEG +ILSVRC2012_val_00000809.JPEG +ILSVRC2012_val_00013865.JPEG +ILSVRC2012_val_00042191.JPEG +ILSVRC2012_val_00036501.JPEG +ILSVRC2012_val_00025242.JPEG +ILSVRC2012_val_00044006.JPEG +ILSVRC2012_val_00046363.JPEG +ILSVRC2012_val_00042846.JPEG +ILSVRC2012_val_00030179.JPEG +ILSVRC2012_val_00016681.JPEG +ILSVRC2012_val_00045735.JPEG +ILSVRC2012_val_00013346.JPEG +ILSVRC2012_val_00005203.JPEG +ILSVRC2012_val_00036526.JPEG +ILSVRC2012_val_00013002.JPEG +ILSVRC2012_val_00049870.JPEG +ILSVRC2012_val_00039116.JPEG +ILSVRC2012_val_00022195.JPEG +ILSVRC2012_val_00039196.JPEG +ILSVRC2012_val_00009332.JPEG +ILSVRC2012_val_00031190.JPEG +ILSVRC2012_val_00002563.JPEG +ILSVRC2012_val_00003350.JPEG +ILSVRC2012_val_00014301.JPEG +ILSVRC2012_val_00012375.JPEG +ILSVRC2012_val_00021483.JPEG +ILSVRC2012_val_00007112.JPEG +ILSVRC2012_val_00019507.JPEG +ILSVRC2012_val_00046217.JPEG +ILSVRC2012_val_00044860.JPEG +ILSVRC2012_val_00015141.JPEG 
+ILSVRC2012_val_00042923.JPEG +ILSVRC2012_val_00027931.JPEG +ILSVRC2012_val_00031113.JPEG +ILSVRC2012_val_00042885.JPEG +ILSVRC2012_val_00001587.JPEG +ILSVRC2012_val_00035216.JPEG +ILSVRC2012_val_00020848.JPEG +ILSVRC2012_val_00044241.JPEG +ILSVRC2012_val_00035183.JPEG +ILSVRC2012_val_00028059.JPEG +ILSVRC2012_val_00025218.JPEG +ILSVRC2012_val_00016335.JPEG +ILSVRC2012_val_00027891.JPEG +ILSVRC2012_val_00012535.JPEG +ILSVRC2012_val_00024276.JPEG +ILSVRC2012_val_00029574.JPEG +ILSVRC2012_val_00033666.JPEG +ILSVRC2012_val_00030978.JPEG +ILSVRC2012_val_00035435.JPEG +ILSVRC2012_val_00001985.JPEG +ILSVRC2012_val_00043233.JPEG +ILSVRC2012_val_00048478.JPEG +ILSVRC2012_val_00018383.JPEG +ILSVRC2012_val_00038768.JPEG +ILSVRC2012_val_00031880.JPEG +ILSVRC2012_val_00027728.JPEG +ILSVRC2012_val_00000706.JPEG +ILSVRC2012_val_00047914.JPEG +ILSVRC2012_val_00005864.JPEG +ILSVRC2012_val_00029781.JPEG +ILSVRC2012_val_00023213.JPEG +ILSVRC2012_val_00027001.JPEG +ILSVRC2012_val_00003797.JPEG +ILSVRC2012_val_00033681.JPEG +ILSVRC2012_val_00030162.JPEG +ILSVRC2012_val_00001833.JPEG +ILSVRC2012_val_00002035.JPEG +ILSVRC2012_val_00009521.JPEG +ILSVRC2012_val_00044288.JPEG +ILSVRC2012_val_00017166.JPEG +ILSVRC2012_val_00038442.JPEG +ILSVRC2012_val_00037582.JPEG +ILSVRC2012_val_00009754.JPEG +ILSVRC2012_val_00042034.JPEG +ILSVRC2012_val_00014665.JPEG +ILSVRC2012_val_00038261.JPEG +ILSVRC2012_val_00041045.JPEG +ILSVRC2012_val_00006728.JPEG +ILSVRC2012_val_00002676.JPEG +ILSVRC2012_val_00023730.JPEG +ILSVRC2012_val_00000892.JPEG +ILSVRC2012_val_00005785.JPEG +ILSVRC2012_val_00011711.JPEG +ILSVRC2012_val_00002209.JPEG +ILSVRC2012_val_00043615.JPEG +ILSVRC2012_val_00012025.JPEG +ILSVRC2012_val_00006658.JPEG +ILSVRC2012_val_00041994.JPEG +ILSVRC2012_val_00017961.JPEG +ILSVRC2012_val_00007095.JPEG +ILSVRC2012_val_00022708.JPEG +ILSVRC2012_val_00025734.JPEG +ILSVRC2012_val_00040233.JPEG +ILSVRC2012_val_00001563.JPEG +ILSVRC2012_val_00005628.JPEG +ILSVRC2012_val_00026670.JPEG 
+ILSVRC2012_val_00047769.JPEG +ILSVRC2012_val_00029341.JPEG +ILSVRC2012_val_00023848.JPEG +ILSVRC2012_val_00029011.JPEG +ILSVRC2012_val_00019523.JPEG +ILSVRC2012_val_00000011.JPEG +ILSVRC2012_val_00002038.JPEG +ILSVRC2012_val_00012843.JPEG +ILSVRC2012_val_00010507.JPEG +ILSVRC2012_val_00013573.JPEG +ILSVRC2012_val_00020570.JPEG +ILSVRC2012_val_00029231.JPEG +ILSVRC2012_val_00035697.JPEG +ILSVRC2012_val_00004904.JPEG +ILSVRC2012_val_00020445.JPEG +ILSVRC2012_val_00031206.JPEG +ILSVRC2012_val_00015567.JPEG +ILSVRC2012_val_00019929.JPEG +ILSVRC2012_val_00001644.JPEG +ILSVRC2012_val_00009977.JPEG +ILSVRC2012_val_00040095.JPEG +ILSVRC2012_val_00045511.JPEG +ILSVRC2012_val_00034560.JPEG +ILSVRC2012_val_00017329.JPEG +ILSVRC2012_val_00022390.JPEG +ILSVRC2012_val_00044539.JPEG +ILSVRC2012_val_00026711.JPEG +ILSVRC2012_val_00005126.JPEG +ILSVRC2012_val_00035069.JPEG +ILSVRC2012_val_00002150.JPEG +ILSVRC2012_val_00022569.JPEG +ILSVRC2012_val_00031309.JPEG +ILSVRC2012_val_00014787.JPEG +ILSVRC2012_val_00018142.JPEG +ILSVRC2012_val_00038128.JPEG +ILSVRC2012_val_00018474.JPEG +ILSVRC2012_val_00009575.JPEG +ILSVRC2012_val_00017599.JPEG +ILSVRC2012_val_00047158.JPEG +ILSVRC2012_val_00031276.JPEG +ILSVRC2012_val_00029258.JPEG +ILSVRC2012_val_00039169.JPEG +ILSVRC2012_val_00042704.JPEG +ILSVRC2012_val_00035558.JPEG +ILSVRC2012_val_00009367.JPEG +ILSVRC2012_val_00043780.JPEG +ILSVRC2012_val_00039548.JPEG +ILSVRC2012_val_00030618.JPEG +ILSVRC2012_val_00030588.JPEG +ILSVRC2012_val_00047337.JPEG +ILSVRC2012_val_00020135.JPEG +ILSVRC2012_val_00022252.JPEG +ILSVRC2012_val_00024433.JPEG +ILSVRC2012_val_00004808.JPEG +ILSVRC2012_val_00037553.JPEG +ILSVRC2012_val_00021530.JPEG +ILSVRC2012_val_00036534.JPEG +ILSVRC2012_val_00015999.JPEG +ILSVRC2012_val_00028384.JPEG +ILSVRC2012_val_00032067.JPEG +ILSVRC2012_val_00005775.JPEG +ILSVRC2012_val_00023059.JPEG +ILSVRC2012_val_00007300.JPEG +ILSVRC2012_val_00033629.JPEG +ILSVRC2012_val_00030679.JPEG +ILSVRC2012_val_00001680.JPEG 
+ILSVRC2012_val_00037629.JPEG +ILSVRC2012_val_00005631.JPEG +ILSVRC2012_val_00014180.JPEG +ILSVRC2012_val_00004445.JPEG +ILSVRC2012_val_00001682.JPEG +ILSVRC2012_val_00040498.JPEG +ILSVRC2012_val_00023474.JPEG +ILSVRC2012_val_00023180.JPEG +ILSVRC2012_val_00045918.JPEG +ILSVRC2012_val_00020022.JPEG +ILSVRC2012_val_00006606.JPEG +ILSVRC2012_val_00043661.JPEG +ILSVRC2012_val_00013733.JPEG +ILSVRC2012_val_00021645.JPEG +ILSVRC2012_val_00017617.JPEG +ILSVRC2012_val_00006620.JPEG +ILSVRC2012_val_00049348.JPEG +ILSVRC2012_val_00035255.JPEG +ILSVRC2012_val_00018402.JPEG +ILSVRC2012_val_00033456.JPEG +ILSVRC2012_val_00027231.JPEG +ILSVRC2012_val_00047425.JPEG +ILSVRC2012_val_00008808.JPEG +ILSVRC2012_val_00032896.JPEG +ILSVRC2012_val_00002708.JPEG +ILSVRC2012_val_00042909.JPEG +ILSVRC2012_val_00007892.JPEG +ILSVRC2012_val_00044341.JPEG +ILSVRC2012_val_00028901.JPEG +ILSVRC2012_val_00048862.JPEG +ILSVRC2012_val_00039992.JPEG +ILSVRC2012_val_00012935.JPEG +ILSVRC2012_val_00015029.JPEG +ILSVRC2012_val_00013959.JPEG +ILSVRC2012_val_00022751.JPEG +ILSVRC2012_val_00049467.JPEG +ILSVRC2012_val_00030868.JPEG +ILSVRC2012_val_00034313.JPEG +ILSVRC2012_val_00002918.JPEG +ILSVRC2012_val_00027700.JPEG +ILSVRC2012_val_00015226.JPEG +ILSVRC2012_val_00044885.JPEG +ILSVRC2012_val_00045973.JPEG +ILSVRC2012_val_00024202.JPEG +ILSVRC2012_val_00002056.JPEG +ILSVRC2012_val_00010949.JPEG +ILSVRC2012_val_00037173.JPEG +ILSVRC2012_val_00007490.JPEG +ILSVRC2012_val_00006615.JPEG +ILSVRC2012_val_00044867.JPEG +ILSVRC2012_val_00041670.JPEG +ILSVRC2012_val_00002048.JPEG +ILSVRC2012_val_00022341.JPEG +ILSVRC2012_val_00048542.JPEG +ILSVRC2012_val_00030396.JPEG +ILSVRC2012_val_00042943.JPEG +ILSVRC2012_val_00006366.JPEG +ILSVRC2012_val_00027396.JPEG +ILSVRC2012_val_00033434.JPEG +ILSVRC2012_val_00043671.JPEG +ILSVRC2012_val_00046032.JPEG +ILSVRC2012_val_00018019.JPEG +ILSVRC2012_val_00027523.JPEG +ILSVRC2012_val_00039875.JPEG +ILSVRC2012_val_00029834.JPEG +ILSVRC2012_val_00046293.JPEG 
+ILSVRC2012_val_00002716.JPEG +ILSVRC2012_val_00007561.JPEG +ILSVRC2012_val_00016481.JPEG +ILSVRC2012_val_00005610.JPEG +ILSVRC2012_val_00029745.JPEG +ILSVRC2012_val_00005156.JPEG +ILSVRC2012_val_00004911.JPEG +ILSVRC2012_val_00011262.JPEG +ILSVRC2012_val_00041639.JPEG +ILSVRC2012_val_00044405.JPEG +ILSVRC2012_val_00036150.JPEG +ILSVRC2012_val_00013413.JPEG +ILSVRC2012_val_00032418.JPEG +ILSVRC2012_val_00012747.JPEG +ILSVRC2012_val_00047145.JPEG +ILSVRC2012_val_00033503.JPEG +ILSVRC2012_val_00019536.JPEG +ILSVRC2012_val_00049462.JPEG +ILSVRC2012_val_00009912.JPEG +ILSVRC2012_val_00010375.JPEG +ILSVRC2012_val_00018499.JPEG +ILSVRC2012_val_00011773.JPEG +ILSVRC2012_val_00039375.JPEG +ILSVRC2012_val_00005620.JPEG +ILSVRC2012_val_00037453.JPEG +ILSVRC2012_val_00005355.JPEG +ILSVRC2012_val_00040102.JPEG +ILSVRC2012_val_00024159.JPEG +ILSVRC2012_val_00029850.JPEG +ILSVRC2012_val_00017088.JPEG +ILSVRC2012_val_00004767.JPEG +ILSVRC2012_val_00006983.JPEG +ILSVRC2012_val_00010362.JPEG +ILSVRC2012_val_00041889.JPEG +ILSVRC2012_val_00027497.JPEG +ILSVRC2012_val_00020942.JPEG +ILSVRC2012_val_00021520.JPEG +ILSVRC2012_val_00015727.JPEG +ILSVRC2012_val_00040437.JPEG +ILSVRC2012_val_00025366.JPEG +ILSVRC2012_val_00042855.JPEG +ILSVRC2012_val_00040634.JPEG +ILSVRC2012_val_00027302.JPEG +ILSVRC2012_val_00005149.JPEG +ILSVRC2012_val_00039193.JPEG +ILSVRC2012_val_00006857.JPEG +ILSVRC2012_val_00030359.JPEG +ILSVRC2012_val_00017883.JPEG +ILSVRC2012_val_00026423.JPEG +ILSVRC2012_val_00021267.JPEG +ILSVRC2012_val_00010164.JPEG +ILSVRC2012_val_00038887.JPEG +ILSVRC2012_val_00018544.JPEG +ILSVRC2012_val_00013872.JPEG +ILSVRC2012_val_00038034.JPEG +ILSVRC2012_val_00030237.JPEG +ILSVRC2012_val_00038450.JPEG +ILSVRC2012_val_00028161.JPEG +ILSVRC2012_val_00013208.JPEG +ILSVRC2012_val_00049885.JPEG +ILSVRC2012_val_00039603.JPEG +ILSVRC2012_val_00024712.JPEG +ILSVRC2012_val_00027986.JPEG +ILSVRC2012_val_00042604.JPEG +ILSVRC2012_val_00028701.JPEG +ILSVRC2012_val_00018917.JPEG 
+ILSVRC2012_val_00032798.JPEG +ILSVRC2012_val_00024882.JPEG +ILSVRC2012_val_00007836.JPEG +ILSVRC2012_val_00036046.JPEG +ILSVRC2012_val_00031691.JPEG +ILSVRC2012_val_00012156.JPEG +ILSVRC2012_val_00034198.JPEG +ILSVRC2012_val_00024988.JPEG +ILSVRC2012_val_00014181.JPEG +ILSVRC2012_val_00005874.JPEG +ILSVRC2012_val_00021393.JPEG +ILSVRC2012_val_00018761.JPEG +ILSVRC2012_val_00049998.JPEG +ILSVRC2012_val_00043057.JPEG +ILSVRC2012_val_00043063.JPEG +ILSVRC2012_val_00039150.JPEG +ILSVRC2012_val_00006901.JPEG +ILSVRC2012_val_00026149.JPEG +ILSVRC2012_val_00047667.JPEG +ILSVRC2012_val_00029822.JPEG +ILSVRC2012_val_00003226.JPEG +ILSVRC2012_val_00041669.JPEG +ILSVRC2012_val_00011342.JPEG +ILSVRC2012_val_00016400.JPEG +ILSVRC2012_val_00028246.JPEG +ILSVRC2012_val_00022482.JPEG +ILSVRC2012_val_00013153.JPEG +ILSVRC2012_val_00011625.JPEG +ILSVRC2012_val_00047834.JPEG +ILSVRC2012_val_00043415.JPEG +ILSVRC2012_val_00000154.JPEG +ILSVRC2012_val_00024673.JPEG +ILSVRC2012_val_00004148.JPEG +ILSVRC2012_val_00025081.JPEG +ILSVRC2012_val_00044473.JPEG +ILSVRC2012_val_00045650.JPEG +ILSVRC2012_val_00002939.JPEG +ILSVRC2012_val_00041022.JPEG +ILSVRC2012_val_00042730.JPEG +ILSVRC2012_val_00043773.JPEG +ILSVRC2012_val_00032589.JPEG +ILSVRC2012_val_00003509.JPEG +ILSVRC2012_val_00027886.JPEG +ILSVRC2012_val_00015769.JPEG +ILSVRC2012_val_00033978.JPEG +ILSVRC2012_val_00032920.JPEG +ILSVRC2012_val_00017665.JPEG +ILSVRC2012_val_00005577.JPEG +ILSVRC2012_val_00028797.JPEG +ILSVRC2012_val_00034778.JPEG +ILSVRC2012_val_00015555.JPEG +ILSVRC2012_val_00028221.JPEG +ILSVRC2012_val_00041122.JPEG +ILSVRC2012_val_00041168.JPEG +ILSVRC2012_val_00041117.JPEG +ILSVRC2012_val_00049666.JPEG +ILSVRC2012_val_00034527.JPEG +ILSVRC2012_val_00005360.JPEG +ILSVRC2012_val_00038460.JPEG +ILSVRC2012_val_00018327.JPEG +ILSVRC2012_val_00026553.JPEG +ILSVRC2012_val_00018272.JPEG +ILSVRC2012_val_00012993.JPEG +ILSVRC2012_val_00020940.JPEG +ILSVRC2012_val_00025832.JPEG +ILSVRC2012_val_00036339.JPEG 
+ILSVRC2012_val_00026193.JPEG +ILSVRC2012_val_00014475.JPEG +ILSVRC2012_val_00002707.JPEG +ILSVRC2012_val_00025647.JPEG +ILSVRC2012_val_00032519.JPEG +ILSVRC2012_val_00005895.JPEG +ILSVRC2012_val_00016221.JPEG +ILSVRC2012_val_00025592.JPEG +ILSVRC2012_val_00029326.JPEG +ILSVRC2012_val_00000995.JPEG +ILSVRC2012_val_00023689.JPEG +ILSVRC2012_val_00022975.JPEG +ILSVRC2012_val_00046307.JPEG +ILSVRC2012_val_00004043.JPEG +ILSVRC2012_val_00022793.JPEG +ILSVRC2012_val_00031755.JPEG +ILSVRC2012_val_00021433.JPEG +ILSVRC2012_val_00016097.JPEG +ILSVRC2012_val_00033240.JPEG +ILSVRC2012_val_00012559.JPEG +ILSVRC2012_val_00031317.JPEG +ILSVRC2012_val_00003300.JPEG +ILSVRC2012_val_00004569.JPEG +ILSVRC2012_val_00033310.JPEG +ILSVRC2012_val_00014323.JPEG +ILSVRC2012_val_00036169.JPEG +ILSVRC2012_val_00030957.JPEG +ILSVRC2012_val_00033645.JPEG +ILSVRC2012_val_00012453.JPEG +ILSVRC2012_val_00000965.JPEG +ILSVRC2012_val_00027177.JPEG +ILSVRC2012_val_00031602.JPEG +ILSVRC2012_val_00016643.JPEG +ILSVRC2012_val_00041313.JPEG +ILSVRC2012_val_00007450.JPEG +ILSVRC2012_val_00036691.JPEG +ILSVRC2012_val_00013378.JPEG +ILSVRC2012_val_00030121.JPEG +ILSVRC2012_val_00027438.JPEG +ILSVRC2012_val_00026519.JPEG +ILSVRC2012_val_00037504.JPEG +ILSVRC2012_val_00002148.JPEG +ILSVRC2012_val_00038045.JPEG +ILSVRC2012_val_00013412.JPEG +ILSVRC2012_val_00016213.JPEG +ILSVRC2012_val_00012189.JPEG +ILSVRC2012_val_00012220.JPEG +ILSVRC2012_val_00011135.JPEG +ILSVRC2012_val_00018492.JPEG +ILSVRC2012_val_00026390.JPEG +ILSVRC2012_val_00002645.JPEG +ILSVRC2012_val_00047444.JPEG +ILSVRC2012_val_00039270.JPEG +ILSVRC2012_val_00015447.JPEG +ILSVRC2012_val_00048403.JPEG +ILSVRC2012_val_00003295.JPEG +ILSVRC2012_val_00008292.JPEG +ILSVRC2012_val_00005792.JPEG +ILSVRC2012_val_00017291.JPEG +ILSVRC2012_val_00026101.JPEG +ILSVRC2012_val_00003379.JPEG +ILSVRC2012_val_00024192.JPEG +ILSVRC2012_val_00043129.JPEG +ILSVRC2012_val_00013765.JPEG +ILSVRC2012_val_00010938.JPEG +ILSVRC2012_val_00031349.JPEG 
+ILSVRC2012_val_00030398.JPEG +ILSVRC2012_val_00010086.JPEG +ILSVRC2012_val_00009722.JPEG +ILSVRC2012_val_00011812.JPEG +ILSVRC2012_val_00029508.JPEG +ILSVRC2012_val_00011529.JPEG +ILSVRC2012_val_00017006.JPEG +ILSVRC2012_val_00024064.JPEG +ILSVRC2012_val_00030809.JPEG +ILSVRC2012_val_00033897.JPEG +ILSVRC2012_val_00013451.JPEG +ILSVRC2012_val_00006475.JPEG +ILSVRC2012_val_00022694.JPEG +ILSVRC2012_val_00036185.JPEG +ILSVRC2012_val_00044987.JPEG +ILSVRC2012_val_00027211.JPEG +ILSVRC2012_val_00017300.JPEG +ILSVRC2012_val_00040743.JPEG +ILSVRC2012_val_00024552.JPEG +ILSVRC2012_val_00033732.JPEG +ILSVRC2012_val_00016779.JPEG +ILSVRC2012_val_00011829.JPEG +ILSVRC2012_val_00014245.JPEG +ILSVRC2012_val_00029947.JPEG +ILSVRC2012_val_00034320.JPEG +ILSVRC2012_val_00018707.JPEG +ILSVRC2012_val_00014154.JPEG +ILSVRC2012_val_00047162.JPEG +ILSVRC2012_val_00035789.JPEG +ILSVRC2012_val_00044500.JPEG +ILSVRC2012_val_00018197.JPEG +ILSVRC2012_val_00003376.JPEG +ILSVRC2012_val_00044614.JPEG +ILSVRC2012_val_00027662.JPEG +ILSVRC2012_val_00006789.JPEG +ILSVRC2012_val_00040515.JPEG +ILSVRC2012_val_00016093.JPEG +ILSVRC2012_val_00000205.JPEG +ILSVRC2012_val_00000348.JPEG +ILSVRC2012_val_00020894.JPEG +ILSVRC2012_val_00003551.JPEG +ILSVRC2012_val_00008679.JPEG +ILSVRC2012_val_00012040.JPEG +ILSVRC2012_val_00002088.JPEG +ILSVRC2012_val_00032529.JPEG +ILSVRC2012_val_00006428.JPEG +ILSVRC2012_val_00000697.JPEG +ILSVRC2012_val_00045213.JPEG +ILSVRC2012_val_00047569.JPEG +ILSVRC2012_val_00013316.JPEG +ILSVRC2012_val_00027803.JPEG +ILSVRC2012_val_00009761.JPEG +ILSVRC2012_val_00012575.JPEG +ILSVRC2012_val_00022582.JPEG +ILSVRC2012_val_00013170.JPEG +ILSVRC2012_val_00000282.JPEG +ILSVRC2012_val_00024327.JPEG +ILSVRC2012_val_00011471.JPEG +ILSVRC2012_val_00032937.JPEG +ILSVRC2012_val_00003107.JPEG +ILSVRC2012_val_00046278.JPEG +ILSVRC2012_val_00003731.JPEG +ILSVRC2012_val_00013557.JPEG +ILSVRC2012_val_00018979.JPEG +ILSVRC2012_val_00007941.JPEG +ILSVRC2012_val_00027204.JPEG 
+ILSVRC2012_val_00032474.JPEG +ILSVRC2012_val_00026606.JPEG +ILSVRC2012_val_00045900.JPEG +ILSVRC2012_val_00041600.JPEG +ILSVRC2012_val_00002949.JPEG +ILSVRC2012_val_00044418.JPEG +ILSVRC2012_val_00007824.JPEG +ILSVRC2012_val_00034989.JPEG +ILSVRC2012_val_00046215.JPEG +ILSVRC2012_val_00004662.JPEG +ILSVRC2012_val_00037501.JPEG +ILSVRC2012_val_00014330.JPEG +ILSVRC2012_val_00031469.JPEG +ILSVRC2012_val_00038741.JPEG +ILSVRC2012_val_00006737.JPEG +ILSVRC2012_val_00015220.JPEG +ILSVRC2012_val_00018921.JPEG +ILSVRC2012_val_00022528.JPEG +ILSVRC2012_val_00014043.JPEG +ILSVRC2012_val_00012247.JPEG +ILSVRC2012_val_00018789.JPEG +ILSVRC2012_val_00043803.JPEG +ILSVRC2012_val_00045383.JPEG +ILSVRC2012_val_00005026.JPEG +ILSVRC2012_val_00037393.JPEG +ILSVRC2012_val_00001432.JPEG +ILSVRC2012_val_00002377.JPEG +ILSVRC2012_val_00028298.JPEG +ILSVRC2012_val_00037492.JPEG +ILSVRC2012_val_00028886.JPEG +ILSVRC2012_val_00017512.JPEG +ILSVRC2012_val_00039318.JPEG +ILSVRC2012_val_00003690.JPEG +ILSVRC2012_val_00024577.JPEG +ILSVRC2012_val_00042017.JPEG +ILSVRC2012_val_00025249.JPEG +ILSVRC2012_val_00004513.JPEG +ILSVRC2012_val_00000330.JPEG +ILSVRC2012_val_00039787.JPEG +ILSVRC2012_val_00015902.JPEG +ILSVRC2012_val_00042634.JPEG +ILSVRC2012_val_00022389.JPEG +ILSVRC2012_val_00041893.JPEG +ILSVRC2012_val_00032078.JPEG +ILSVRC2012_val_00043892.JPEG +ILSVRC2012_val_00000162.JPEG +ILSVRC2012_val_00020612.JPEG +ILSVRC2012_val_00003626.JPEG +ILSVRC2012_val_00005454.JPEG +ILSVRC2012_val_00047537.JPEG +ILSVRC2012_val_00044345.JPEG +ILSVRC2012_val_00002854.JPEG +ILSVRC2012_val_00037783.JPEG +ILSVRC2012_val_00019841.JPEG +ILSVRC2012_val_00047739.JPEG +ILSVRC2012_val_00008621.JPEG +ILSVRC2012_val_00034003.JPEG +ILSVRC2012_val_00024701.JPEG +ILSVRC2012_val_00031050.JPEG +ILSVRC2012_val_00007288.JPEG +ILSVRC2012_val_00016022.JPEG +ILSVRC2012_val_00049663.JPEG +ILSVRC2012_val_00003775.JPEG +ILSVRC2012_val_00044090.JPEG +ILSVRC2012_val_00023990.JPEG +ILSVRC2012_val_00041754.JPEG 
+ILSVRC2012_val_00029240.JPEG +ILSVRC2012_val_00041705.JPEG +ILSVRC2012_val_00001101.JPEG +ILSVRC2012_val_00006800.JPEG +ILSVRC2012_val_00008715.JPEG +ILSVRC2012_val_00041113.JPEG +ILSVRC2012_val_00035537.JPEG +ILSVRC2012_val_00046328.JPEG +ILSVRC2012_val_00026025.JPEG +ILSVRC2012_val_00015797.JPEG +ILSVRC2012_val_00041220.JPEG +ILSVRC2012_val_00033362.JPEG +ILSVRC2012_val_00023500.JPEG +ILSVRC2012_val_00031146.JPEG +ILSVRC2012_val_00004220.JPEG +ILSVRC2012_val_00023321.JPEG +ILSVRC2012_val_00042607.JPEG +ILSVRC2012_val_00028226.JPEG +ILSVRC2012_val_00040101.JPEG +ILSVRC2012_val_00000910.JPEG +ILSVRC2012_val_00010329.JPEG +ILSVRC2012_val_00017753.JPEG +ILSVRC2012_val_00026816.JPEG +ILSVRC2012_val_00008065.JPEG +ILSVRC2012_val_00037404.JPEG +ILSVRC2012_val_00033457.JPEG +ILSVRC2012_val_00020621.JPEG +ILSVRC2012_val_00049323.JPEG +ILSVRC2012_val_00009781.JPEG +ILSVRC2012_val_00006752.JPEG +ILSVRC2012_val_00039470.JPEG +ILSVRC2012_val_00031908.JPEG +ILSVRC2012_val_00041043.JPEG +ILSVRC2012_val_00040901.JPEG +ILSVRC2012_val_00045577.JPEG +ILSVRC2012_val_00000521.JPEG +ILSVRC2012_val_00036913.JPEG +ILSVRC2012_val_00023141.JPEG +ILSVRC2012_val_00014080.JPEG +ILSVRC2012_val_00028712.JPEG +ILSVRC2012_val_00009924.JPEG +ILSVRC2012_val_00019103.JPEG +ILSVRC2012_val_00038037.JPEG +ILSVRC2012_val_00009662.JPEG +ILSVRC2012_val_00001448.JPEG +ILSVRC2012_val_00041540.JPEG +ILSVRC2012_val_00042726.JPEG +ILSVRC2012_val_00043350.JPEG +ILSVRC2012_val_00037326.JPEG +ILSVRC2012_val_00041006.JPEG +ILSVRC2012_val_00045197.JPEG +ILSVRC2012_val_00024342.JPEG +ILSVRC2012_val_00019265.JPEG +ILSVRC2012_val_00035706.JPEG +ILSVRC2012_val_00021115.JPEG +ILSVRC2012_val_00012601.JPEG +ILSVRC2012_val_00000128.JPEG +ILSVRC2012_val_00004860.JPEG +ILSVRC2012_val_00041919.JPEG +ILSVRC2012_val_00000143.JPEG +ILSVRC2012_val_00038194.JPEG +ILSVRC2012_val_00022918.JPEG +ILSVRC2012_val_00041779.JPEG +ILSVRC2012_val_00009870.JPEG +ILSVRC2012_val_00011524.JPEG +ILSVRC2012_val_00042757.JPEG 
+ILSVRC2012_val_00001762.JPEG +ILSVRC2012_val_00000306.JPEG +ILSVRC2012_val_00044373.JPEG +ILSVRC2012_val_00017266.JPEG +ILSVRC2012_val_00010592.JPEG +ILSVRC2012_val_00011482.JPEG +ILSVRC2012_val_00029628.JPEG +ILSVRC2012_val_00008282.JPEG +ILSVRC2012_val_00048258.JPEG +ILSVRC2012_val_00020240.JPEG +ILSVRC2012_val_00028670.JPEG +ILSVRC2012_val_00032620.JPEG +ILSVRC2012_val_00038155.JPEG +ILSVRC2012_val_00035454.JPEG +ILSVRC2012_val_00046911.JPEG +ILSVRC2012_val_00024977.JPEG +ILSVRC2012_val_00010341.JPEG +ILSVRC2012_val_00034555.JPEG +ILSVRC2012_val_00018882.JPEG +ILSVRC2012_val_00035406.JPEG +ILSVRC2012_val_00034863.JPEG +ILSVRC2012_val_00017986.JPEG +ILSVRC2012_val_00033992.JPEG +ILSVRC2012_val_00041150.JPEG +ILSVRC2012_val_00042305.JPEG +ILSVRC2012_val_00002215.JPEG +ILSVRC2012_val_00018129.JPEG +ILSVRC2012_val_00003993.JPEG +ILSVRC2012_val_00011014.JPEG +ILSVRC2012_val_00039085.JPEG +ILSVRC2012_val_00035013.JPEG +ILSVRC2012_val_00005521.JPEG +ILSVRC2012_val_00043316.JPEG +ILSVRC2012_val_00012359.JPEG +ILSVRC2012_val_00004535.JPEG +ILSVRC2012_val_00002896.JPEG +ILSVRC2012_val_00015113.JPEG +ILSVRC2012_val_00042125.JPEG +ILSVRC2012_val_00032800.JPEG +ILSVRC2012_val_00040397.JPEG +ILSVRC2012_val_00023303.JPEG +ILSVRC2012_val_00039964.JPEG +ILSVRC2012_val_00022563.JPEG +ILSVRC2012_val_00032142.JPEG +ILSVRC2012_val_00017992.JPEG +ILSVRC2012_val_00041283.JPEG +ILSVRC2012_val_00020120.JPEG +ILSVRC2012_val_00001766.JPEG +ILSVRC2012_val_00002602.JPEG +ILSVRC2012_val_00018066.JPEG +ILSVRC2012_val_00040345.JPEG +ILSVRC2012_val_00017720.JPEG +ILSVRC2012_val_00021802.JPEG +ILSVRC2012_val_00032725.JPEG +ILSVRC2012_val_00019564.JPEG +ILSVRC2012_val_00039692.JPEG +ILSVRC2012_val_00004533.JPEG +ILSVRC2012_val_00003182.JPEG +ILSVRC2012_val_00008524.JPEG +ILSVRC2012_val_00030531.JPEG +ILSVRC2012_val_00035403.JPEG +ILSVRC2012_val_00010356.JPEG +ILSVRC2012_val_00044538.JPEG +ILSVRC2012_val_00018974.JPEG +ILSVRC2012_val_00040818.JPEG +ILSVRC2012_val_00002570.JPEG 
+ILSVRC2012_val_00003442.JPEG +ILSVRC2012_val_00047033.JPEG +ILSVRC2012_val_00004165.JPEG +ILSVRC2012_val_00014811.JPEG +ILSVRC2012_val_00036742.JPEG +ILSVRC2012_val_00030646.JPEG +ILSVRC2012_val_00023380.JPEG +ILSVRC2012_val_00011250.JPEG +ILSVRC2012_val_00011864.JPEG +ILSVRC2012_val_00012014.JPEG +ILSVRC2012_val_00021600.JPEG +ILSVRC2012_val_00010979.JPEG +ILSVRC2012_val_00017764.JPEG +ILSVRC2012_val_00016702.JPEG +ILSVRC2012_val_00016197.JPEG +ILSVRC2012_val_00017035.JPEG +ILSVRC2012_val_00044214.JPEG +ILSVRC2012_val_00024763.JPEG +ILSVRC2012_val_00045762.JPEG +ILSVRC2012_val_00033644.JPEG +ILSVRC2012_val_00027530.JPEG +ILSVRC2012_val_00011379.JPEG +ILSVRC2012_val_00045015.JPEG +ILSVRC2012_val_00012819.JPEG +ILSVRC2012_val_00047060.JPEG +ILSVRC2012_val_00044059.JPEG +ILSVRC2012_val_00044957.JPEG +ILSVRC2012_val_00036609.JPEG +ILSVRC2012_val_00039290.JPEG +ILSVRC2012_val_00044426.JPEG +ILSVRC2012_val_00037676.JPEG +ILSVRC2012_val_00001356.JPEG +ILSVRC2012_val_00030906.JPEG +ILSVRC2012_val_00003128.JPEG +ILSVRC2012_val_00042177.JPEG +ILSVRC2012_val_00022806.JPEG +ILSVRC2012_val_00022611.JPEG +ILSVRC2012_val_00022101.JPEG +ILSVRC2012_val_00018827.JPEG +ILSVRC2012_val_00039278.JPEG +ILSVRC2012_val_00043778.JPEG +ILSVRC2012_val_00010670.JPEG +ILSVRC2012_val_00026388.JPEG +ILSVRC2012_val_00004668.JPEG +ILSVRC2012_val_00040556.JPEG +ILSVRC2012_val_00035482.JPEG +ILSVRC2012_val_00007882.JPEG +ILSVRC2012_val_00043852.JPEG +ILSVRC2012_val_00002454.JPEG +ILSVRC2012_val_00030805.JPEG +ILSVRC2012_val_00013701.JPEG +ILSVRC2012_val_00002671.JPEG +ILSVRC2012_val_00016060.JPEG +ILSVRC2012_val_00028634.JPEG +ILSVRC2012_val_00049726.JPEG +ILSVRC2012_val_00016732.JPEG +ILSVRC2012_val_00001411.JPEG +ILSVRC2012_val_00042834.JPEG +ILSVRC2012_val_00041925.JPEG +ILSVRC2012_val_00020939.JPEG +ILSVRC2012_val_00015097.JPEG +ILSVRC2012_val_00028420.JPEG +ILSVRC2012_val_00037946.JPEG +ILSVRC2012_val_00004246.JPEG +ILSVRC2012_val_00027299.JPEG +ILSVRC2012_val_00009048.JPEG 
+ILSVRC2012_val_00014522.JPEG +ILSVRC2012_val_00026400.JPEG +ILSVRC2012_val_00027841.JPEG +ILSVRC2012_val_00020641.JPEG +ILSVRC2012_val_00006240.JPEG +ILSVRC2012_val_00007603.JPEG +ILSVRC2012_val_00045301.JPEG +ILSVRC2012_val_00000772.JPEG +ILSVRC2012_val_00018275.JPEG +ILSVRC2012_val_00034080.JPEG +ILSVRC2012_val_00006705.JPEG +ILSVRC2012_val_00028230.JPEG +ILSVRC2012_val_00009457.JPEG +ILSVRC2012_val_00026175.JPEG +ILSVRC2012_val_00022811.JPEG +ILSVRC2012_val_00039183.JPEG +ILSVRC2012_val_00017158.JPEG +ILSVRC2012_val_00007013.JPEG +ILSVRC2012_val_00010298.JPEG +ILSVRC2012_val_00020317.JPEG +ILSVRC2012_val_00023878.JPEG +ILSVRC2012_val_00038489.JPEG +ILSVRC2012_val_00027393.JPEG +ILSVRC2012_val_00029869.JPEG +ILSVRC2012_val_00043668.JPEG +ILSVRC2012_val_00027391.JPEG +ILSVRC2012_val_00011763.JPEG +ILSVRC2012_val_00019050.JPEG +ILSVRC2012_val_00032472.JPEG +ILSVRC2012_val_00008868.JPEG +ILSVRC2012_val_00048126.JPEG +ILSVRC2012_val_00004930.JPEG +ILSVRC2012_val_00007361.JPEG +ILSVRC2012_val_00017668.JPEG +ILSVRC2012_val_00028092.JPEG +ILSVRC2012_val_00009286.JPEG +ILSVRC2012_val_00006548.JPEG +ILSVRC2012_val_00038548.JPEG +ILSVRC2012_val_00023793.JPEG +ILSVRC2012_val_00047389.JPEG +ILSVRC2012_val_00029817.JPEG +ILSVRC2012_val_00001707.JPEG +ILSVRC2012_val_00001934.JPEG +ILSVRC2012_val_00013376.JPEG +ILSVRC2012_val_00004358.JPEG +ILSVRC2012_val_00021807.JPEG +ILSVRC2012_val_00002366.JPEG +ILSVRC2012_val_00028453.JPEG +ILSVRC2012_val_00031776.JPEG +ILSVRC2012_val_00020943.JPEG +ILSVRC2012_val_00022738.JPEG +ILSVRC2012_val_00004013.JPEG +ILSVRC2012_val_00041070.JPEG +ILSVRC2012_val_00048832.JPEG +ILSVRC2012_val_00024418.JPEG +ILSVRC2012_val_00049429.JPEG +ILSVRC2012_val_00004011.JPEG +ILSVRC2012_val_00008107.JPEG +ILSVRC2012_val_00034276.JPEG +ILSVRC2012_val_00023395.JPEG +ILSVRC2012_val_00026084.JPEG +ILSVRC2012_val_00024520.JPEG +ILSVRC2012_val_00035168.JPEG +ILSVRC2012_val_00028200.JPEG +ILSVRC2012_val_00046559.JPEG +ILSVRC2012_val_00027549.JPEG 
+ILSVRC2012_val_00039498.JPEG +ILSVRC2012_val_00045935.JPEG +ILSVRC2012_val_00015181.JPEG +ILSVRC2012_val_00034969.JPEG +ILSVRC2012_val_00001485.JPEG +ILSVRC2012_val_00049208.JPEG +ILSVRC2012_val_00008870.JPEG +ILSVRC2012_val_00013963.JPEG +ILSVRC2012_val_00026284.JPEG +ILSVRC2012_val_00041943.JPEG +ILSVRC2012_val_00001136.JPEG +ILSVRC2012_val_00047927.JPEG +ILSVRC2012_val_00006354.JPEG +ILSVRC2012_val_00019935.JPEG +ILSVRC2012_val_00044783.JPEG +ILSVRC2012_val_00015920.JPEG +ILSVRC2012_val_00002934.JPEG +ILSVRC2012_val_00019552.JPEG +ILSVRC2012_val_00042965.JPEG +ILSVRC2012_val_00006514.JPEG +ILSVRC2012_val_00043766.JPEG +ILSVRC2012_val_00047238.JPEG +ILSVRC2012_val_00038505.JPEG +ILSVRC2012_val_00039719.JPEG +ILSVRC2012_val_00027309.JPEG +ILSVRC2012_val_00028722.JPEG +ILSVRC2012_val_00018962.JPEG +ILSVRC2012_val_00015552.JPEG +ILSVRC2012_val_00037464.JPEG +ILSVRC2012_val_00041612.JPEG +ILSVRC2012_val_00045914.JPEG +ILSVRC2012_val_00009529.JPEG +ILSVRC2012_val_00025521.JPEG +ILSVRC2012_val_00027893.JPEG +ILSVRC2012_val_00039957.JPEG +ILSVRC2012_val_00020538.JPEG +ILSVRC2012_val_00038657.JPEG +ILSVRC2012_val_00019539.JPEG +ILSVRC2012_val_00023966.JPEG +ILSVRC2012_val_00026754.JPEG +ILSVRC2012_val_00030147.JPEG +ILSVRC2012_val_00045961.JPEG +ILSVRC2012_val_00031495.JPEG +ILSVRC2012_val_00040424.JPEG +ILSVRC2012_val_00011341.JPEG +ILSVRC2012_val_00017136.JPEG +ILSVRC2012_val_00015082.JPEG +ILSVRC2012_val_00003013.JPEG +ILSVRC2012_val_00005164.JPEG +ILSVRC2012_val_00041905.JPEG +ILSVRC2012_val_00036901.JPEG +ILSVRC2012_val_00040732.JPEG +ILSVRC2012_val_00047078.JPEG +ILSVRC2012_val_00044744.JPEG +ILSVRC2012_val_00010881.JPEG +ILSVRC2012_val_00024246.JPEG +ILSVRC2012_val_00008644.JPEG +ILSVRC2012_val_00024365.JPEG +ILSVRC2012_val_00033597.JPEG +ILSVRC2012_val_00040218.JPEG +ILSVRC2012_val_00024820.JPEG +ILSVRC2012_val_00032475.JPEG +ILSVRC2012_val_00002042.JPEG +ILSVRC2012_val_00020665.JPEG +ILSVRC2012_val_00048826.JPEG +ILSVRC2012_val_00020391.JPEG 
+ILSVRC2012_val_00036851.JPEG +ILSVRC2012_val_00046218.JPEG +ILSVRC2012_val_00027203.JPEG +ILSVRC2012_val_00010602.JPEG +ILSVRC2012_val_00047908.JPEG +ILSVRC2012_val_00017566.JPEG +ILSVRC2012_val_00008908.JPEG +ILSVRC2012_val_00009378.JPEG +ILSVRC2012_val_00024524.JPEG +ILSVRC2012_val_00048249.JPEG +ILSVRC2012_val_00037333.JPEG +ILSVRC2012_val_00043595.JPEG +ILSVRC2012_val_00015265.JPEG +ILSVRC2012_val_00018530.JPEG +ILSVRC2012_val_00003524.JPEG +ILSVRC2012_val_00046207.JPEG +ILSVRC2012_val_00020531.JPEG +ILSVRC2012_val_00036591.JPEG +ILSVRC2012_val_00049981.JPEG +ILSVRC2012_val_00018011.JPEG +ILSVRC2012_val_00022238.JPEG +ILSVRC2012_val_00009205.JPEG +ILSVRC2012_val_00048316.JPEG +ILSVRC2012_val_00036420.JPEG +ILSVRC2012_val_00049731.JPEG +ILSVRC2012_val_00020692.JPEG +ILSVRC2012_val_00020834.JPEG +ILSVRC2012_val_00006822.JPEG +ILSVRC2012_val_00012187.JPEG +ILSVRC2012_val_00043542.JPEG +ILSVRC2012_val_00014455.JPEG +ILSVRC2012_val_00023862.JPEG +ILSVRC2012_val_00037834.JPEG +ILSVRC2012_val_00003105.JPEG +ILSVRC2012_val_00043289.JPEG +ILSVRC2012_val_00031173.JPEG +ILSVRC2012_val_00044380.JPEG +ILSVRC2012_val_00042489.JPEG +ILSVRC2012_val_00033270.JPEG +ILSVRC2012_val_00012606.JPEG +ILSVRC2012_val_00023281.JPEG +ILSVRC2012_val_00028866.JPEG +ILSVRC2012_val_00005032.JPEG +ILSVRC2012_val_00014370.JPEG +ILSVRC2012_val_00029132.JPEG +ILSVRC2012_val_00010115.JPEG +ILSVRC2012_val_00021298.JPEG +ILSVRC2012_val_00042172.JPEG +ILSVRC2012_val_00027345.JPEG +ILSVRC2012_val_00009836.JPEG +ILSVRC2012_val_00037681.JPEG +ILSVRC2012_val_00030939.JPEG +ILSVRC2012_val_00049576.JPEG +ILSVRC2012_val_00024107.JPEG +ILSVRC2012_val_00029567.JPEG +ILSVRC2012_val_00005268.JPEG +ILSVRC2012_val_00031188.JPEG +ILSVRC2012_val_00011050.JPEG +ILSVRC2012_val_00015741.JPEG +ILSVRC2012_val_00023722.JPEG +ILSVRC2012_val_00027023.JPEG +ILSVRC2012_val_00030562.JPEG +ILSVRC2012_val_00044343.JPEG +ILSVRC2012_val_00036980.JPEG +ILSVRC2012_val_00007958.JPEG +ILSVRC2012_val_00012229.JPEG 
+ILSVRC2012_val_00023772.JPEG +ILSVRC2012_val_00024091.JPEG +ILSVRC2012_val_00022685.JPEG +ILSVRC2012_val_00028470.JPEG +ILSVRC2012_val_00038485.JPEG +ILSVRC2012_val_00046507.JPEG +ILSVRC2012_val_00002658.JPEG +ILSVRC2012_val_00034078.JPEG +ILSVRC2012_val_00009285.JPEG +ILSVRC2012_val_00047674.JPEG +ILSVRC2012_val_00039164.JPEG +ILSVRC2012_val_00046051.JPEG +ILSVRC2012_val_00042712.JPEG +ILSVRC2012_val_00049243.JPEG +ILSVRC2012_val_00014108.JPEG +ILSVRC2012_val_00016031.JPEG +ILSVRC2012_val_00040808.JPEG +ILSVRC2012_val_00042875.JPEG +ILSVRC2012_val_00049164.JPEG +ILSVRC2012_val_00006560.JPEG +ILSVRC2012_val_00049924.JPEG +ILSVRC2012_val_00031866.JPEG +ILSVRC2012_val_00040311.JPEG +ILSVRC2012_val_00047714.JPEG +ILSVRC2012_val_00020736.JPEG +ILSVRC2012_val_00013176.JPEG +ILSVRC2012_val_00048334.JPEG +ILSVRC2012_val_00041987.JPEG +ILSVRC2012_val_00011658.JPEG +ILSVRC2012_val_00011445.JPEG +ILSVRC2012_val_00049191.JPEG +ILSVRC2012_val_00017785.JPEG +ILSVRC2012_val_00027141.JPEG +ILSVRC2012_val_00029909.JPEG +ILSVRC2012_val_00029419.JPEG +ILSVRC2012_val_00037762.JPEG +ILSVRC2012_val_00000414.JPEG +ILSVRC2012_val_00027793.JPEG +ILSVRC2012_val_00019619.JPEG +ILSVRC2012_val_00034791.JPEG +ILSVRC2012_val_00017077.JPEG +ILSVRC2012_val_00001829.JPEG +ILSVRC2012_val_00009850.JPEG +ILSVRC2012_val_00014624.JPEG +ILSVRC2012_val_00043739.JPEG +ILSVRC2012_val_00010187.JPEG +ILSVRC2012_val_00019618.JPEG +ILSVRC2012_val_00008918.JPEG +ILSVRC2012_val_00022306.JPEG +ILSVRC2012_val_00001950.JPEG +ILSVRC2012_val_00022834.JPEG +ILSVRC2012_val_00003541.JPEG +ILSVRC2012_val_00024533.JPEG +ILSVRC2012_val_00008970.JPEG +ILSVRC2012_val_00043505.JPEG +ILSVRC2012_val_00040961.JPEG +ILSVRC2012_val_00000934.JPEG +ILSVRC2012_val_00018168.JPEG +ILSVRC2012_val_00034804.JPEG +ILSVRC2012_val_00014430.JPEG +ILSVRC2012_val_00036373.JPEG +ILSVRC2012_val_00034431.JPEG +ILSVRC2012_val_00038424.JPEG +ILSVRC2012_val_00019974.JPEG +ILSVRC2012_val_00008018.JPEG +ILSVRC2012_val_00020192.JPEG 
+ILSVRC2012_val_00003947.JPEG +ILSVRC2012_val_00033370.JPEG +ILSVRC2012_val_00001535.JPEG +ILSVRC2012_val_00000026.JPEG +ILSVRC2012_val_00029866.JPEG +ILSVRC2012_val_00013991.JPEG +ILSVRC2012_val_00017956.JPEG +ILSVRC2012_val_00014481.JPEG +ILSVRC2012_val_00031792.JPEG +ILSVRC2012_val_00017422.JPEG +ILSVRC2012_val_00007354.JPEG +ILSVRC2012_val_00046776.JPEG +ILSVRC2012_val_00014515.JPEG +ILSVRC2012_val_00031346.JPEG +ILSVRC2012_val_00014493.JPEG +ILSVRC2012_val_00002744.JPEG +ILSVRC2012_val_00011028.JPEG +ILSVRC2012_val_00014943.JPEG +ILSVRC2012_val_00010944.JPEG +ILSVRC2012_val_00045350.JPEG +ILSVRC2012_val_00010935.JPEG +ILSVRC2012_val_00010418.JPEG +ILSVRC2012_val_00014041.JPEG +ILSVRC2012_val_00033913.JPEG +ILSVRC2012_val_00030100.JPEG +ILSVRC2012_val_00042592.JPEG +ILSVRC2012_val_00045929.JPEG +ILSVRC2012_val_00039303.JPEG +ILSVRC2012_val_00048538.JPEG +ILSVRC2012_val_00016464.JPEG +ILSVRC2012_val_00008570.JPEG +ILSVRC2012_val_00013984.JPEG +ILSVRC2012_val_00044712.JPEG +ILSVRC2012_val_00006068.JPEG +ILSVRC2012_val_00000117.JPEG +ILSVRC2012_val_00020007.JPEG +ILSVRC2012_val_00024153.JPEG +ILSVRC2012_val_00021002.JPEG +ILSVRC2012_val_00023117.JPEG +ILSVRC2012_val_00048345.JPEG +ILSVRC2012_val_00032253.JPEG +ILSVRC2012_val_00011559.JPEG +ILSVRC2012_val_00020562.JPEG +ILSVRC2012_val_00032893.JPEG +ILSVRC2012_val_00010224.JPEG +ILSVRC2012_val_00038834.JPEG +ILSVRC2012_val_00032530.JPEG +ILSVRC2012_val_00004460.JPEG +ILSVRC2012_val_00042349.JPEG +ILSVRC2012_val_00004666.JPEG +ILSVRC2012_val_00016822.JPEG +ILSVRC2012_val_00034474.JPEG +ILSVRC2012_val_00006365.JPEG +ILSVRC2012_val_00023740.JPEG +ILSVRC2012_val_00003485.JPEG +ILSVRC2012_val_00043030.JPEG +ILSVRC2012_val_00040256.JPEG +ILSVRC2012_val_00002841.JPEG +ILSVRC2012_val_00017133.JPEG +ILSVRC2012_val_00036251.JPEG +ILSVRC2012_val_00041214.JPEG +ILSVRC2012_val_00025983.JPEG +ILSVRC2012_val_00026738.JPEG +ILSVRC2012_val_00041591.JPEG +ILSVRC2012_val_00044640.JPEG +ILSVRC2012_val_00004953.JPEG 
+ILSVRC2012_val_00037606.JPEG +ILSVRC2012_val_00009917.JPEG +ILSVRC2012_val_00016290.JPEG +ILSVRC2012_val_00030638.JPEG +ILSVRC2012_val_00022645.JPEG +ILSVRC2012_val_00046095.JPEG +ILSVRC2012_val_00047746.JPEG +ILSVRC2012_val_00042568.JPEG +ILSVRC2012_val_00027195.JPEG +ILSVRC2012_val_00017239.JPEG +ILSVRC2012_val_00038529.JPEG +ILSVRC2012_val_00001178.JPEG +ILSVRC2012_val_00009625.JPEG +ILSVRC2012_val_00045924.JPEG +ILSVRC2012_val_00028646.JPEG +ILSVRC2012_val_00033723.JPEG +ILSVRC2012_val_00034289.JPEG +ILSVRC2012_val_00031737.JPEG +ILSVRC2012_val_00042843.JPEG +ILSVRC2012_val_00025415.JPEG +ILSVRC2012_val_00036443.JPEG +ILSVRC2012_val_00022753.JPEG +ILSVRC2012_val_00019802.JPEG +ILSVRC2012_val_00027248.JPEG +ILSVRC2012_val_00042659.JPEG +ILSVRC2012_val_00029306.JPEG +ILSVRC2012_val_00039594.JPEG +ILSVRC2012_val_00003307.JPEG +ILSVRC2012_val_00017786.JPEG +ILSVRC2012_val_00028291.JPEG +ILSVRC2012_val_00029202.JPEG +ILSVRC2012_val_00034285.JPEG +ILSVRC2012_val_00028714.JPEG +ILSVRC2012_val_00001667.JPEG +ILSVRC2012_val_00021208.JPEG +ILSVRC2012_val_00013306.JPEG +ILSVRC2012_val_00040849.JPEG +ILSVRC2012_val_00029498.JPEG +ILSVRC2012_val_00019493.JPEG +ILSVRC2012_val_00018695.JPEG +ILSVRC2012_val_00016110.JPEG +ILSVRC2012_val_00049783.JPEG +ILSVRC2012_val_00043053.JPEG +ILSVRC2012_val_00037382.JPEG +ILSVRC2012_val_00040754.JPEG +ILSVRC2012_val_00007383.JPEG +ILSVRC2012_val_00011474.JPEG +ILSVRC2012_val_00029746.JPEG +ILSVRC2012_val_00039621.JPEG +ILSVRC2012_val_00012336.JPEG +ILSVRC2012_val_00001488.JPEG +ILSVRC2012_val_00041937.JPEG +ILSVRC2012_val_00042491.JPEG +ILSVRC2012_val_00036947.JPEG +ILSVRC2012_val_00019704.JPEG +ILSVRC2012_val_00007968.JPEG +ILSVRC2012_val_00020712.JPEG +ILSVRC2012_val_00003439.JPEG +ILSVRC2012_val_00002435.JPEG +ILSVRC2012_val_00013133.JPEG +ILSVRC2012_val_00038853.JPEG +ILSVRC2012_val_00029897.JPEG +ILSVRC2012_val_00020485.JPEG +ILSVRC2012_val_00034764.JPEG +ILSVRC2012_val_00007569.JPEG +ILSVRC2012_val_00022450.JPEG 
+ILSVRC2012_val_00026531.JPEG +ILSVRC2012_val_00037272.JPEG +ILSVRC2012_val_00022644.JPEG +ILSVRC2012_val_00048348.JPEG +ILSVRC2012_val_00040440.JPEG +ILSVRC2012_val_00003044.JPEG +ILSVRC2012_val_00014379.JPEG +ILSVRC2012_val_00016943.JPEG +ILSVRC2012_val_00046055.JPEG +ILSVRC2012_val_00014024.JPEG +ILSVRC2012_val_00016183.JPEG +ILSVRC2012_val_00029099.JPEG +ILSVRC2012_val_00006593.JPEG +ILSVRC2012_val_00042031.JPEG +ILSVRC2012_val_00017843.JPEG +ILSVRC2012_val_00003430.JPEG +ILSVRC2012_val_00031679.JPEG +ILSVRC2012_val_00011490.JPEG +ILSVRC2012_val_00036769.JPEG +ILSVRC2012_val_00010126.JPEG +ILSVRC2012_val_00046775.JPEG +ILSVRC2012_val_00018950.JPEG +ILSVRC2012_val_00015380.JPEG +ILSVRC2012_val_00019494.JPEG +ILSVRC2012_val_00027058.JPEG +ILSVRC2012_val_00039391.JPEG +ILSVRC2012_val_00013617.JPEG +ILSVRC2012_val_00026009.JPEG +ILSVRC2012_val_00020070.JPEG +ILSVRC2012_val_00010993.JPEG +ILSVRC2012_val_00014634.JPEG +ILSVRC2012_val_00015692.JPEG +ILSVRC2012_val_00024000.JPEG +ILSVRC2012_val_00013802.JPEG +ILSVRC2012_val_00047285.JPEG +ILSVRC2012_val_00048052.JPEG +ILSVRC2012_val_00033719.JPEG +ILSVRC2012_val_00045600.JPEG +ILSVRC2012_val_00046774.JPEG +ILSVRC2012_val_00020721.JPEG +ILSVRC2012_val_00029044.JPEG +ILSVRC2012_val_00022492.JPEG +ILSVRC2012_val_00031715.JPEG +ILSVRC2012_val_00031139.JPEG +ILSVRC2012_val_00026827.JPEG +ILSVRC2012_val_00030815.JPEG +ILSVRC2012_val_00013303.JPEG +ILSVRC2012_val_00040571.JPEG +ILSVRC2012_val_00031583.JPEG +ILSVRC2012_val_00024647.JPEG +ILSVRC2012_val_00028596.JPEG +ILSVRC2012_val_00035261.JPEG +ILSVRC2012_val_00018992.JPEG +ILSVRC2012_val_00034445.JPEG +ILSVRC2012_val_00036095.JPEG +ILSVRC2012_val_00023063.JPEG +ILSVRC2012_val_00005581.JPEG +ILSVRC2012_val_00046120.JPEG +ILSVRC2012_val_00035161.JPEG +ILSVRC2012_val_00032666.JPEG +ILSVRC2012_val_00038974.JPEG +ILSVRC2012_val_00015701.JPEG +ILSVRC2012_val_00019389.JPEG +ILSVRC2012_val_00037038.JPEG +ILSVRC2012_val_00010773.JPEG +ILSVRC2012_val_00009888.JPEG 
+ILSVRC2012_val_00031497.JPEG +ILSVRC2012_val_00000616.JPEG +ILSVRC2012_val_00045044.JPEG +ILSVRC2012_val_00034012.JPEG +ILSVRC2012_val_00012254.JPEG +ILSVRC2012_val_00044561.JPEG +ILSVRC2012_val_00017836.JPEG +ILSVRC2012_val_00049913.JPEG +ILSVRC2012_val_00037198.JPEG +ILSVRC2012_val_00018065.JPEG +ILSVRC2012_val_00048237.JPEG +ILSVRC2012_val_00011019.JPEG +ILSVRC2012_val_00005197.JPEG +ILSVRC2012_val_00037812.JPEG +ILSVRC2012_val_00026570.JPEG +ILSVRC2012_val_00005363.JPEG +ILSVRC2012_val_00021672.JPEG +ILSVRC2012_val_00028879.JPEG +ILSVRC2012_val_00016167.JPEG +ILSVRC2012_val_00004553.JPEG +ILSVRC2012_val_00031493.JPEG +ILSVRC2012_val_00021861.JPEG +ILSVRC2012_val_00010360.JPEG +ILSVRC2012_val_00037719.JPEG +ILSVRC2012_val_00016895.JPEG +ILSVRC2012_val_00025871.JPEG +ILSVRC2012_val_00048498.JPEG +ILSVRC2012_val_00036476.JPEG +ILSVRC2012_val_00008741.JPEG +ILSVRC2012_val_00010889.JPEG +ILSVRC2012_val_00016246.JPEG +ILSVRC2012_val_00031277.JPEG +ILSVRC2012_val_00027575.JPEG +ILSVRC2012_val_00000963.JPEG +ILSVRC2012_val_00011871.JPEG +ILSVRC2012_val_00048483.JPEG +ILSVRC2012_val_00044299.JPEG +ILSVRC2012_val_00039814.JPEG +ILSVRC2012_val_00025226.JPEG +ILSVRC2012_val_00033999.JPEG +ILSVRC2012_val_00010927.JPEG +ILSVRC2012_val_00022971.JPEG +ILSVRC2012_val_00029204.JPEG +ILSVRC2012_val_00007098.JPEG +ILSVRC2012_val_00048941.JPEG +ILSVRC2012_val_00022477.JPEG +ILSVRC2012_val_00014290.JPEG +ILSVRC2012_val_00010930.JPEG +ILSVRC2012_val_00005333.JPEG +ILSVRC2012_val_00016284.JPEG +ILSVRC2012_val_00021463.JPEG +ILSVRC2012_val_00011647.JPEG +ILSVRC2012_val_00005110.JPEG +ILSVRC2012_val_00039374.JPEG +ILSVRC2012_val_00007027.JPEG +ILSVRC2012_val_00015022.JPEG +ILSVRC2012_val_00030870.JPEG +ILSVRC2012_val_00036275.JPEG +ILSVRC2012_val_00049810.JPEG +ILSVRC2012_val_00029007.JPEG +ILSVRC2012_val_00047880.JPEG +ILSVRC2012_val_00029201.JPEG +ILSVRC2012_val_00002260.JPEG +ILSVRC2012_val_00004825.JPEG +ILSVRC2012_val_00049368.JPEG +ILSVRC2012_val_00011064.JPEG 
+ILSVRC2012_val_00019962.JPEG +ILSVRC2012_val_00020107.JPEG +ILSVRC2012_val_00034596.JPEG +ILSVRC2012_val_00031995.JPEG +ILSVRC2012_val_00026021.JPEG +ILSVRC2012_val_00022157.JPEG +ILSVRC2012_val_00033290.JPEG +ILSVRC2012_val_00028205.JPEG +ILSVRC2012_val_00026066.JPEG +ILSVRC2012_val_00032885.JPEG +ILSVRC2012_val_00023036.JPEG +ILSVRC2012_val_00038029.JPEG +ILSVRC2012_val_00006408.JPEG +ILSVRC2012_val_00008746.JPEG +ILSVRC2012_val_00025172.JPEG +ILSVRC2012_val_00036431.JPEG +ILSVRC2012_val_00024641.JPEG +ILSVRC2012_val_00040857.JPEG +ILSVRC2012_val_00015339.JPEG +ILSVRC2012_val_00013270.JPEG +ILSVRC2012_val_00023779.JPEG +ILSVRC2012_val_00043115.JPEG +ILSVRC2012_val_00022363.JPEG +ILSVRC2012_val_00006088.JPEG +ILSVRC2012_val_00043210.JPEG +ILSVRC2012_val_00015596.JPEG +ILSVRC2012_val_00006724.JPEG +ILSVRC2012_val_00013292.JPEG +ILSVRC2012_val_00024101.JPEG +ILSVRC2012_val_00013419.JPEG +ILSVRC2012_val_00040948.JPEG +ILSVRC2012_val_00029692.JPEG +ILSVRC2012_val_00039341.JPEG +ILSVRC2012_val_00003086.JPEG +ILSVRC2012_val_00007980.JPEG +ILSVRC2012_val_00017108.JPEG +ILSVRC2012_val_00018194.JPEG +ILSVRC2012_val_00034179.JPEG +ILSVRC2012_val_00010669.JPEG +ILSVRC2012_val_00046963.JPEG +ILSVRC2012_val_00039431.JPEG +ILSVRC2012_val_00017044.JPEG +ILSVRC2012_val_00025284.JPEG +ILSVRC2012_val_00031808.JPEG +ILSVRC2012_val_00039018.JPEG +ILSVRC2012_val_00040646.JPEG +ILSVRC2012_val_00015532.JPEG +ILSVRC2012_val_00043496.JPEG +ILSVRC2012_val_00018681.JPEG +ILSVRC2012_val_00002804.JPEG +ILSVRC2012_val_00014117.JPEG +ILSVRC2012_val_00033949.JPEG +ILSVRC2012_val_00043431.JPEG +ILSVRC2012_val_00021070.JPEG +ILSVRC2012_val_00039389.JPEG +ILSVRC2012_val_00020060.JPEG +ILSVRC2012_val_00013111.JPEG +ILSVRC2012_val_00039712.JPEG +ILSVRC2012_val_00037344.JPEG +ILSVRC2012_val_00026736.JPEG +ILSVRC2012_val_00048004.JPEG +ILSVRC2012_val_00039932.JPEG +ILSVRC2012_val_00004853.JPEG +ILSVRC2012_val_00026014.JPEG +ILSVRC2012_val_00003453.JPEG +ILSVRC2012_val_00003382.JPEG 
+ILSVRC2012_val_00016743.JPEG +ILSVRC2012_val_00042445.JPEG +ILSVRC2012_val_00047349.JPEG +ILSVRC2012_val_00030902.JPEG +ILSVRC2012_val_00004175.JPEG +ILSVRC2012_val_00032850.JPEG +ILSVRC2012_val_00005821.JPEG +ILSVRC2012_val_00020058.JPEG +ILSVRC2012_val_00023328.JPEG +ILSVRC2012_val_00040355.JPEG +ILSVRC2012_val_00001147.JPEG +ILSVRC2012_val_00037460.JPEG +ILSVRC2012_val_00042724.JPEG +ILSVRC2012_val_00011156.JPEG +ILSVRC2012_val_00004985.JPEG +ILSVRC2012_val_00035489.JPEG +ILSVRC2012_val_00018393.JPEG +ILSVRC2012_val_00014268.JPEG +ILSVRC2012_val_00036338.JPEG +ILSVRC2012_val_00034053.JPEG +ILSVRC2012_val_00013050.JPEG +ILSVRC2012_val_00048388.JPEG +ILSVRC2012_val_00004609.JPEG +ILSVRC2012_val_00004294.JPEG +ILSVRC2012_val_00024467.JPEG +ILSVRC2012_val_00031169.JPEG +ILSVRC2012_val_00001642.JPEG +ILSVRC2012_val_00042949.JPEG +ILSVRC2012_val_00039809.JPEG +ILSVRC2012_val_00031501.JPEG +ILSVRC2012_val_00025981.JPEG +ILSVRC2012_val_00017769.JPEG +ILSVRC2012_val_00027896.JPEG +ILSVRC2012_val_00022287.JPEG +ILSVRC2012_val_00022258.JPEG +ILSVRC2012_val_00040191.JPEG +ILSVRC2012_val_00012526.JPEG +ILSVRC2012_val_00034634.JPEG +ILSVRC2012_val_00040526.JPEG +ILSVRC2012_val_00016396.JPEG +ILSVRC2012_val_00048728.JPEG +ILSVRC2012_val_00034191.JPEG +ILSVRC2012_val_00036005.JPEG +ILSVRC2012_val_00034387.JPEG +ILSVRC2012_val_00032922.JPEG +ILSVRC2012_val_00012954.JPEG +ILSVRC2012_val_00004876.JPEG +ILSVRC2012_val_00044651.JPEG +ILSVRC2012_val_00037398.JPEG +ILSVRC2012_val_00010883.JPEG +ILSVRC2012_val_00026632.JPEG +ILSVRC2012_val_00020242.JPEG +ILSVRC2012_val_00001954.JPEG +ILSVRC2012_val_00005694.JPEG +ILSVRC2012_val_00013681.JPEG +ILSVRC2012_val_00014356.JPEG +ILSVRC2012_val_00033863.JPEG +ILSVRC2012_val_00035654.JPEG +ILSVRC2012_val_00019886.JPEG +ILSVRC2012_val_00047866.JPEG +ILSVRC2012_val_00021351.JPEG +ILSVRC2012_val_00045139.JPEG +ILSVRC2012_val_00023909.JPEG +ILSVRC2012_val_00010162.JPEG +ILSVRC2012_val_00049165.JPEG +ILSVRC2012_val_00020491.JPEG 
+ILSVRC2012_val_00031208.JPEG +ILSVRC2012_val_00044420.JPEG +ILSVRC2012_val_00023363.JPEG +ILSVRC2012_val_00003908.JPEG +ILSVRC2012_val_00018056.JPEG +ILSVRC2012_val_00036644.JPEG +ILSVRC2012_val_00000796.JPEG +ILSVRC2012_val_00011249.JPEG +ILSVRC2012_val_00012399.JPEG +ILSVRC2012_val_00001413.JPEG +ILSVRC2012_val_00022891.JPEG +ILSVRC2012_val_00049015.JPEG +ILSVRC2012_val_00047456.JPEG +ILSVRC2012_val_00011597.JPEG +ILSVRC2012_val_00015557.JPEG +ILSVRC2012_val_00027782.JPEG +ILSVRC2012_val_00037493.JPEG +ILSVRC2012_val_00000504.JPEG +ILSVRC2012_val_00021016.JPEG +ILSVRC2012_val_00002053.JPEG +ILSVRC2012_val_00027616.JPEG +ILSVRC2012_val_00045588.JPEG +ILSVRC2012_val_00044091.JPEG +ILSVRC2012_val_00039946.JPEG +ILSVRC2012_val_00011038.JPEG +ILSVRC2012_val_00019127.JPEG +ILSVRC2012_val_00023071.JPEG +ILSVRC2012_val_00040757.JPEG +ILSVRC2012_val_00020164.JPEG +ILSVRC2012_val_00013643.JPEG +ILSVRC2012_val_00030387.JPEG +ILSVRC2012_val_00018912.JPEG +ILSVRC2012_val_00035687.JPEG +ILSVRC2012_val_00040278.JPEG +ILSVRC2012_val_00009163.JPEG +ILSVRC2012_val_00023857.JPEG +ILSVRC2012_val_00037991.JPEG +ILSVRC2012_val_00025057.JPEG +ILSVRC2012_val_00028908.JPEG +ILSVRC2012_val_00010737.JPEG +ILSVRC2012_val_00015519.JPEG +ILSVRC2012_val_00010983.JPEG +ILSVRC2012_val_00013517.JPEG +ILSVRC2012_val_00033350.JPEG +ILSVRC2012_val_00010611.JPEG +ILSVRC2012_val_00031130.JPEG +ILSVRC2012_val_00024581.JPEG +ILSVRC2012_val_00023266.JPEG +ILSVRC2012_val_00023056.JPEG +ILSVRC2012_val_00031333.JPEG +ILSVRC2012_val_00026264.JPEG +ILSVRC2012_val_00037355.JPEG +ILSVRC2012_val_00015767.JPEG \ No newline at end of file diff --git a/big_vision/evaluators/proj/uvim/common.py b/big_vision/evaluators/proj/uvim/common.py new file mode 100644 index 0000000000000000000000000000000000000000..937fde88ea416c87346e2dd9dfbd3de8940e6867 --- /dev/null +++ b/big_vision/evaluators/proj/uvim/common.py @@ -0,0 +1,64 @@ +# Copyright 2022 Big Vision Authors. 
def get_jax_process_dataset(dataset, split, global_batch_size, pp_fn,
                            dataset_dir=None, cache=True, add_tfds_id=False):
  """Returns dataset to be processed by current jax host.

  The dataset is sharded and padded with zeros such that all processes
  have equal number of batches. The first 2 dimensions of the dataset
  elements are: [local_device_count, device_batch_size].

  Args:
    dataset: dataset name.
    split: dataset split.
    global_batch_size: batch size to be process per iteration on the dataset.
      Must be divisible by `jax.device_count()`.
    pp_fn: preprocessing function to apply per example.
    dataset_dir: path for tfds to find the prepared data.
    cache: whether to cache the dataset after batching.
    add_tfds_id: whether to add the unique `tfds_id` string to each example.

  Returns:
    A `tf.data.Dataset` holding exactly
    `ceil(total_examples / global_batch_size)` batches for this process.
  """
  assert global_batch_size % jax.device_count() == 0
  # The number of batches is derived from the *global* split size so that every
  # process iterates the same number of times (needed for collectives).
  total_examples = tfds.load(
      dataset, split=split, data_dir=dataset_dir).cardinality()
  num_batches = math.ceil(total_examples / global_batch_size)

  process_split = tfds.even_splits(
      split, n=jax.process_count(), drop_remainder=False)[jax.process_index()]
  data = tfds.load(
      dataset,
      split=process_split,
      data_dir=dataset_dir,
      read_config=tfds.ReadConfig(add_tfds_id=add_tfds_id)).map(pp_fn)

  # Infinite stream of all-zero examples used to fill up the last batches.
  # `jax.tree_util.tree_map` is used instead of the top-level `jax.tree_map`
  # alias, which is deprecated/removed in newer JAX releases.
  pad_data = tf.data.Dataset.from_tensors(
      jax.tree_util.tree_map(
          lambda x: tf.zeros(x.shape, x.dtype), data.element_spec)
  ).repeat()

  data = data.concatenate(pad_data)
  # Inner batch: examples per device; outer batch: devices of this process.
  data = data.batch(global_batch_size // jax.device_count())
  data = data.batch(jax.local_device_count())
  data = data.take(num_batches)
  if cache:
    # Eval datasets are often used many times and caching the dataset after
    # batching allows one to have the buffers ready to be used and not have
    # to wait for preprocessing to be done over and over.
    data = data.cache()
  return data
# Note: global to avoid jax re-compiling across different evaluator instances.
@functools.partial(jax.pmap, static_broadcasted_argnums=0, axis_name='batch')
def _run_predict_fn(predict_fn, params, batch):
  """Sum per-example metrics weighted by `_mask`.

  Args:
    predict_fn: callable `(params, batch) -> dict` of per-example metrics; it
      is a static (broadcast) argument of the pmap.
    params: parameters forwarded to `predict_fn`.
    batch: dict holding at least `_mask`; every metric must have the same
      shape as `batch['_mask']`.

  Returns:
    A dict with one entry per metric plus `_mask`, each being the inner
    product with the mask, summed over the `batch` pmap axis.

  Raises:
    ValueError: if a metric does not have the shape of the mask.
  """
  mask = batch['_mask']
  metrics = predict_fn(params, batch)
  # Sanity check output format of predict_fn.
  assert isinstance(metrics, Mapping), 'predict_fn must return a dict'
  # `jax.tree_util.*` replaces the removed top-level `jax.tree_leaves` /
  # `jax.tree_map` aliases.
  for y in jax.tree_util.tree_leaves(metrics):
    if y.shape != mask.shape:
      shapes = jax.tree_util.tree_map(lambda x: x.shape, metrics)
      raise ValueError(
          f'Expected per-example metrics of shape {mask.shape} found '
          f'{shapes}.')
  metrics = {**metrics, '_mask': mask}
  metrics = jax.tree_util.tree_map(lambda x: jnp.inner(x, mask), metrics)
  return jax.lax.psum(metrics, axis_name='batch')


class Evaluator:
  """Report the mean of per-example metrics computed by predict_fn.

  `predict_fn(params, batch)` must return a dict from metric name to
  per-example metrics of shape [batch_size].
  """

  def __init__(self, predict_fn, data, pp_fn, batch_size,
               cache_final=True, cache_raw=False, prefetch=1):
    data = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_input_pipeline(self.dataset, prefetch)
    self.predict_fn = predict_fn

  def run(self, params):
    """Computes all metrics.

    Args:
      params: parameters forwarded to `predict_fn`.

    Yields:
      `(metric_name, mean_value)` pairs, where the mean is the mask-weighted
      sum divided by the summed mask. Nothing is yielded when no batch was
      processed.
    """
    metrics = []

    # Compute batch metrics without blocking.
    for _, batch in zip(range(self.steps), self.data_iter):
      batch_metrics = _run_predict_fn(self.predict_fn, params, batch)
      metrics.append(batch_metrics)

    if not metrics:  # Nothing was evaluated, hence nothing to report.
      return

    # Transfer metrics from device 0 to host (blocking).
    metrics = jax.device_get(
        jax.tree_util.tree_map(lambda x: x[0], metrics))

    metrics_sum = jax.tree_util.tree_map(lambda *x: np.sum(x), *metrics)
    mask_sum = metrics_sum.pop('_mask')
    for key, value_sum in metrics_sum.items():
      yield (key, value_sum / mask_sum)
EVAL_CROP_H = 426
EVAL_CROP_W = 560


class Evaluator:
  """Evaluator for NYU depth.

  Ground truth values outside of `(min_depth, max_depth)` are excluded from
  the metrics. The ground truth is assumed to be already cropped to
  `(EVAL_CROP_H, EVAL_CROP_W)`; predictions are resized to that shape.
  """

  def __init__(self,
               predict_fn,
               pp_fn,
               batch_size,
               dataset,
               split,
               min_depth=1e-3,
               max_depth=10,
               dataset_dir=None,
               predict_kwargs=None):
    self.min_depth = min_depth
    self.max_depth = max_depth

    def predict(params, batch):
      pred = predict_fn(params, batch, **(predict_kwargs or {}))

      return jax.lax.all_gather({
          "mask": batch["mask"],
          "gt": jnp.squeeze(batch["ground_truth"], axis=-1),
          "y": pred["depth"],
      }, axis_name="data", axis=0)

    self.predict_fn = jax.pmap(predict, axis_name="data")

    # Prepare data for each process and pad with zeros so all processes have the
    # same number of batches.
    def preprocess(example):
      return {
          "mask": tf.constant(1),
          **pp_builder.get_preprocess_fn(pp_fn)(example),
      }

    self.process_batch_size = batch_size // jax.process_count()

    self.data = common.get_jax_process_dataset(
        dataset=dataset,
        dataset_dir=dataset_dir,
        split=split,
        global_batch_size=batch_size,
        pp_fn=preprocess)

  def run(self, params):
    """Run eval.

    Args:
      params: parameters forwarded to the pmapped predict function.

    Yields:
      `(name, value)` pairs for RMSE, abs_RE, log10 and delta1..3, each the
      mean over all evaluated images. Only process 0 yields anything.
    """
    # Assumes that the ground truth is processed by the eval crop.
    eval_mask = np.ones((EVAL_CROP_H, EVAL_CROP_W), dtype=np.bool_)
    rmses = []
    abs_res = []
    abs_logs = []
    d1s = []
    d2s = []
    d3s = []
    for batch in self.data.as_numpy_iterator():
      # Outputs is a dict with values shaped (gather/same, devices, batch, ...)
      out = self.predict_fn(params, batch)

      if jax.process_index():  # Host0 gets all preds and does eval.
        continue

      # First, we remove the "gather" dim and transfer the result to host,
      # leading to numpy arrays of (devices, device_batch, ...)
      # (`jax.tree_util.tree_map` replaces the removed `jax.tree_map` alias.)
      out = jax.tree_util.tree_map(lambda x: jax.device_get(x[0]), out)
      # Then the bool-indexing with mask resulting in flat (global_batch, ...)
      keep = out["mask"] == 1
      out = jax.tree_util.tree_map(lambda x: x[keep], out)  # pylint:disable=cell-var-from-loop

      for gt, pred in zip(out["gt"], out["y"]):
        pred = _resize_nearest(pred, (EVAL_CROP_H, EVAL_CROP_W))
        valid_mask = np.logical_and(gt > self.min_depth, gt < self.max_depth)
        valid_mask = np.logical_and(valid_mask, eval_mask)

        # Select the valid pixels once instead of once per metric.
        gt_valid = gt[valid_mask]
        pred_valid = pred[valid_mask]

        rmses.append(_compute_rmse(gt_valid, pred_valid))
        abs_res.append(_compute_abs_re(gt_valid, pred_valid))
        abs_logs.append(_compute_abs_log(gt_valid, pred_valid))
        d1s.append(_compute_delta(gt_valid, pred_valid, order=1))
        d2s.append(_compute_delta(gt_valid, pred_valid, order=2))
        d3s.append(_compute_delta(gt_valid, pred_valid, order=3))

    if jax.process_index():  # Host0 gets all preds and does eval.
      return

    yield "RMSE", np.mean(rmses)
    yield "abs_RE", np.mean(abs_res)
    yield "log10", np.mean(abs_logs)
    yield "delta1", np.mean(d1s)
    yield "delta2", np.mean(d2s)
    yield "delta3", np.mean(d3s)


@functools.partial(jax.jit, static_argnums=(1,), backend="cpu")
def _resize_nearest(image, shape):
  """Resizes `image` to `shape` with nearest-neighbour sampling (on CPU)."""
  return jax.image.resize(image, shape, "nearest")


def _compute_rmse(gt, pred):
  """Root mean squared error between `gt` and `pred`."""
  diff = gt - pred
  return np.sqrt(np.mean(np.power(diff, 2)))


def _compute_abs_re(gt, pred):
  """Mean absolute error relative to the ground truth."""
  diff = np.abs(gt - pred)
  return np.mean(diff / gt)


def _compute_abs_log(gt, pred):
  """Mean absolute difference of the base-10 logarithms."""
  diff = np.abs(np.log10(gt) - np.log10(pred))
  return np.mean(diff)


def _compute_delta(gt, pred, order):
  """Fraction of elements with max(gt/pred, pred/gt) < 1.25**order."""
  rel_diff = np.maximum(gt / pred, pred / gt)
  return np.sum(rel_diff < 1.25**order) / rel_diff.size
class Evaluator:
  """PSNR evaluator.

  `predict_fn` accepts arbitrary dictionaries of parameters and data, where
  the data dictionary is produced by the `pp_fn` op. It is expected to output a
  single-key dict containing an RGB image with intensities in [-1,1].
  """

  def __init__(self,
               predict_fn,
               pp_fn,
               batch_size,
               dataset="imagenet2012",
               split="validation",
               predict_kwargs=None):

    def predict(params, batch):

      def _f(x):
        y = predict_fn(params, x, **(predict_kwargs or {}))
        # Assume image intensities are in [-1,1].
        # Evaluator expects a dict with a single item.
        pred, = y.values()
        return _psnr(pred, x["labels"], 2.)
      return jax.lax.all_gather({
          "mask": batch["mask"],
          "psnr": _f(batch["input"]),
      }, axis_name="data", axis=0)

    self.predict_fn = jax.pmap(predict, axis_name="data")

    # Prepare data for each process and pad with zeros so all processes have the
    # same number of batches.
    def preprocess(example):
      return {
          "mask": tf.constant(1),
          "input": pp_builder.get_preprocess_fn(pp_fn)(example),
      }

    self.data = common.get_jax_process_dataset(
        dataset,
        split,
        global_batch_size=batch_size,
        add_tfds_id=True,
        pp_fn=preprocess)

  def run(self, params):
    """Run eval.

    Args:
      params: parameters forwarded to the pmapped predict function.

    Yields:
      A single `("PSNR", mean_psnr)` pair, on process 0 only.
    """
    psnrs = []

    for batch in self.data.as_numpy_iterator():
      # Outputs is a dict with values shaped (gather/same, devices, batch, ...)
      out = self.predict_fn(params, batch)

      if jax.process_index():  # Host0 gets all preds and does eval.
        continue

      # First, we remove the "gather" dim and transfer the result to host,
      # leading to numpy arrays of (devices, device_batch, ...)
      # (`jax.tree_util.tree_map` replaces the removed `jax.tree_map` alias.)
      out = jax.tree_util.tree_map(lambda x: jax.device_get(x[0]), out)
      mask = out["mask"]
      # Drop the zero-padded examples before accumulating.
      batch_psnrs = out["psnr"][mask != 0]
      psnrs.extend(batch_psnrs)

    if jax.process_index():  # Host0 gets all preds and does eval.
      return

    yield "PSNR", np.mean(psnrs)


@functools.partial(jax.vmap, in_axes=(0, 0, None))
def _psnr(img0, img1, dynamic_range):
  """Per-example PSNR (in dB) of `img0` vs `img1` for the given dynamic range."""
  mse = jnp.mean(jnp.power(img0 - img1, 2))
  return 20. * jnp.log10(dynamic_range) - 10. * jnp.log10(mse)
class Evaluator:
  """Save predictions in "{FLAGS.workdir}/{outfile}".

  Results can then be easily inspected in a notebook such as:

  ```
    results = utils.load_checkpoint(None, "")
    inputs, outputs = (results["inputs"], results["outputs"])
  ```
  """

  def __init__(self, predict_fn, pp_fn, dataset, split, batch_size, outfile,
               predict_kwargs=None, dataset_dir=None):
    # Prepare to run predict on all processes and gather predictions on all
    # devices. Note: if needed consider only gather across processes.
    def predict(params, batch):
      y = predict_fn(params, batch['inputs'], **(predict_kwargs or {}))
      res = {'inputs': batch['inputs'], 'outputs': y, 'mask': batch['mask']}
      return jax.lax.all_gather(res, axis_name='data', axis=0, tiled=True)

    self.predict_fn = jax.pmap(predict, axis_name='data')

    # Prepare data for each process and pad with zeros so all processes have the
    # same number of batches.
    def preprocess(example):
      return {
          'mask': tf.constant(1),
          'inputs': pp_builder.get_preprocess_fn(pp_fn)(example),
      }
    self.data = common.get_jax_process_dataset(
        dataset=dataset, split=split,
        dataset_dir=dataset_dir,
        global_batch_size=batch_size,
        pp_fn=preprocess)

    self.path = os.path.join(flags.FLAGS.workdir, outfile)

  def run(self, params):
    """Compute all predictions, gather in main host and save in outfile.

    Args:
      params: parameters forwarded to the pmapped predict function.

    Yields:
      Nothing; this is a generator only to match the evaluator interface.
    """
    count = 0
    outputs = []
    for batch in self.data.as_numpy_iterator():
      out = self.predict_fn(params, batch)
      if jax.process_index():
        continue

      # `jax.tree_util.tree_map` replaces the removed `jax.tree_map` alias.
      out = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], out))
      # Keep only real (non-padded) examples.
      keep = out['mask'] == 1
      out = jax.tree_util.tree_map(lambda x: x[keep], out)  # pylint: disable=cell-var-from-loop
      count += out['mask'].shape[0]
      out.pop('mask')
      outputs.append(out)

      logging.log_every_n_seconds(
          logging.INFO, 'Save predictions: processed %i examples so far.', 30,
          count)

    if jax.process_index():
      return

    logging.info('Save predictions: processed %d examples.', count)

    # Actually save in filesystem.
    outputs = jax.tree_util.tree_map(
        lambda *x: np.concatenate(x, axis=0), *outputs)
    u.save_checkpoint(outputs, self.path, compressed=True)
    return

    yield None  # pylint: disable=unreachable
class Evaluator:
  """Evaluator that saves the inputs and outputs of a prediction function.

  Example configuration:

  ```
  config.evals.save_pred = {
    'type': 'save',
    'pred': 'inference',
    'outfile': '{workdir}/inference-{step:09d}.npz',
    'data': ..., 'pp_fn': ..., 'log_steps': ...,
  }
  ```

  Results can then be easily inspected in a notebook such as:

  ```
    results = utils.load_checkpoint("")
    inputs, outputs = (results["inputs"], results["outputs"])
  ```
  """

  def __init__(self, predict_fn, data, pp_fn, batch_size, outfile,
               cache_final=True, cache_raw=False, prefetch=1, *, devices):
    # Outputs are replicated (empty PartitionSpec) over a 1-D device mesh.
    replicate = jax.sharding.NamedSharding(
        jax.sharding.Mesh(devices, ('devices',)),
        jax.sharding.PartitionSpec()
    )
    self.predict_fn = functools.partial(
        jax.jit(_run_predict_fn, static_argnums=0, out_shardings=replicate),
        predict_fn,
    )

    data = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True),
        batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final,
        cache_raw=cache_raw,
    )
    self.data_iter = input_pipeline.start_global(
        self.dataset, devices, prefetch
    )

    self.outfile = outfile

  def run(self, train_state):
    """Compute all predictions, gather in main host and save in outfile.

    Args:
      train_state: dict-like state; `train_state['opt']` is used to read the
        current step, and the whole state is passed to the predict function.

    Yields:
      Nothing; this is a generator only to match the evaluator interface.
    """
    step = jax.device_get(bv_optax.get_count(train_state['opt'], jittable=True))
    outfile = self.outfile.format(workdir=flags.FLAGS.workdir, step=step)

    count = 0
    outputs = []
    for _, batch in zip(range(self.steps), self.data_iter):
      out = self.predict_fn(train_state, batch)
      if jax.process_index():
        continue

      out = jax.device_get(out)
      mask = out['inputs']['_mask']
      keep = mask == 1
      out = jax.tree.map(lambda x: x[keep], out)  # pylint: disable=cell-var-from-loop
      # Count only the real examples: `mask.shape[0]` would also include the
      # padded ones that were just filtered out.
      count += int(np.count_nonzero(keep))
      out['inputs'].pop('_mask')
      outputs.append(out)

      logging.log_every_n_seconds(
          logging.INFO, 'Processed %i examples so far.', 60,
          count)

    if jax.process_index():
      return

    logging.info('Saving %d examples in %s', count, outfile)
    outputs = jax.tree.map(lambda *x: np.concatenate(x, axis=0), *outputs)
    utils.save_checkpoint(outputs, outfile, compressed=True)
    return

    yield None  # pylint: disable=unreachable
def make_for_train(
    data, preprocess_fn, batch_size,
    shuffle_buffer_size=None, cache_raw=False,
    num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=2,
    *,
    pre_filter_fn=None, post_filter_fn=None,
    pack=None, skip_errors=False,
):
  """Makes an input pipeline for training.

  Args:
    data: the raw `tf.data.Dataset`.
    preprocess_fn: function mapped over every example.
    batch_size: global batch size (divided by `jax.process_count()`); a falsy
      value leaves the dataset unbatched.
    shuffle_buffer_size: shuffle buffer size; falsy disables shuffling.
    cache_raw: whether to cache the raw data before shuffling.
    num_parallel_calls: parallelism of the preprocessing `map`.
    prefetch: number of elements to prefetch; falsy disables prefetching.
    pre_filter_fn: optional filter applied before preprocessing.
    post_filter_fn: optional filter applied after preprocessing.
    pack: optional configuration passed to `sequence_packing.pack_dataset`.
    skip_errors: whether to drop (and log) examples that raise errors.

  Returns:
    The infinitely repeating training `tf.data.Dataset`.
  """
  # Use data filtering at your own risk: the actual split sizes won't be known
  # in advance, so epoch-based things won't work correctly.

  data = _add_tpu_host_options(data)

  data = data.filter(pre_filter_fn) if pre_filter_fn else data
  data = data.cache() if cache_raw else data

  # First shuffle and then repeat (each with a different shuffle). This way
  # the data for one epoch is all seen before the next one is processed and
  # significantly affects the number of times each example is seen when
  # processing for small number of epochs.
  if shuffle_buffer_size:
    data = data.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
  data = data.repeat(None)

  data = data.map(preprocess_fn, num_parallel_calls=num_parallel_calls)
  data = data.filter(post_filter_fn) if post_filter_fn else data

  data = data.ignore_errors(log_warning=True) if skip_errors else data

  data = sequence_packing.pack_dataset(data, pack) if pack else data

  # Drop remainder makes shape fully static, so we can later use it if needed.
  if batch_size:
    data = data.batch(batch_size // jax.process_count(), drop_remainder=True)
  if prefetch:  # None means autotune, but we never want that.
    data = data.prefetch(prefetch)
  return data


def training(input_config):
  """Reads the data from a single dataset, or mixes it from multiple.

  The data is read either from one or mixed from multiple datasets, depending
  on the `input_config`.

  Args:
    input_config: Configures the input pipeline. See input_pipeline_test for
      examples.

  Returns:
    A tuple containing (possibly mixed) tf.data.Dataset and a total number of
    training examples.
  """
  per_pipeline_configs = (
      "shuffle_buffer_size", "cache_raw", "num_parallel_calls",
      "pre_filter_fn", "post_filter_fn", "pack", "skip_errors")
  def config_to_kw(config):
    assert "filter_fn" not in config, "Deprecated; use `pre_filter_fn` instead."
    return {k: config[k] for k in per_pipeline_configs if k in config}

  batch_size = input_config.batch_size
  # Handle separately the common case when no mixing happens.
  if isinstance(input_config.data.get("name"), str):
    train_data = ds_core.get(**input_config.data)
    train_ds = make_for_train(
        data=train_data.get_tfdata(ordered=False),
        batch_size=batch_size,
        preprocess_fn=pp_builder.get_preprocess_fn(input_config.get("pp")),
        prefetch=input_config.get("prefetch", 2),  # Default 2 for bwd compat.
        **config_to_kw(input_config)
    )
    return train_ds, train_data.total_examples

  # A helpful error instead of silent ignore:
  for k in per_pipeline_configs:
    assert k not in input_config, f"{k} is per-dataset in multi-input."

  # Parallelize the loading of datasets when doing data mixture.
  # For larger mixes, we sometimes spend >5min when doing sequentially.
  # NOTE: functools.cache is thread-safe.
  def _make(name_and_weight):
    name, weight = name_and_weight
    dataset = input_config[name]
    train_data = ds_core.get(**dataset.data)
    dataset = make_for_train(
        data=train_data.get_tfdata(ordered=False),
        # Don't batch the data just yet, it will be done after
        # mixing the different datasets below.
        batch_size=None,
        preprocess_fn=pp_builder.get_preprocess_fn(dataset.get("pp"), name),
        prefetch=0,  # Prefetching each pipeline leads to huge OOMs.
        **config_to_kw(dataset)
    )
    if keys := input_config.get("keep_only"):
      dataset = dataset.map(lambda d, keys=keys: {k: d[k] for k in keys})
    return name, dataset, weight, train_data.total_examples

  names, datasets, weights, totals = [], [], [], []
  # Skip weight=0 datasets as a convenient optimization in sweeps.
  mixture = [(name, w) for name, w in input_config.data.items() if w]
  # The context manager tears the worker threads down once loading is done;
  # `pool.map` blocks until all results are available.
  with multiprocessing.pool.ThreadPool(len(input_config.data)) as pool:
    results = pool.map(_make, mixture)
  for name, dataset, weight, total in results:
    names.append(name)
    datasets.append(dataset)
    weights.append(weight)
    totals.append(total)

  # Normalize the weights such that they sum up to 1.
  total_weight = sum(weights)
  weights = [x / total_weight for x in weights]

  logging.info(
      "NOTE: Total dataset mix size: %d\nContributions:\n%s", sum(totals),
      "\n".join(f"{ds}: {n} ({w * 100:.1g}%)"
                for ds, n, w in zip(names, totals, weights))
  )

  train_ds = tf.data.Dataset.sample_from_datasets(
      datasets, weights, stop_on_empty_dataset=True)
  # NOTE(review): "pack" is also listed in `per_pipeline_configs`, which the
  # assertion above forbids at the top level; confirm the intended semantics.
  if input_config.get("pack"):
    train_ds = sequence_packing.pack_dataset(train_ds, input_config.get("pack"))
  train_ds = train_ds.batch(
      input_config["batch_size"] // jax.process_count(), drop_remainder=True)
  if (pf := input_config.get("prefetch", 2)):
    train_ds = train_ds.prefetch(pf)

  return train_ds, sum(totals)
def make_for_inference(
    data, preprocess_fn, batch_size, num_ex_per_process,
    cache_raw=False, cache_final=False,
    num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=1,
):
  """Builds the (padded, repeating) inference pipeline.

  Returns:
    A `(dataset, num_batches)` tuple; `num_batches` is derived from the
    largest entry of `num_ex_per_process` so it is identical on all hosts.
  """
  ds = _add_tpu_host_options(data)
  if cache_raw:
    ds = ds.cache()
  ds = ds.map(_add_internal_fields(preprocess_fn),
              num_parallel_calls=num_parallel_calls)
  # Append an endless stream of zero examples to pad the last batches.
  ds = ds.concatenate(_get_pad_data(ds))

  per_process_batch = batch_size // jax.process_count()
  # Like `batch`, but elements of differing shapes become a tf.RaggedTensor
  # while fixed-shape ones stay tf.Tensors. Thanks to the 'infinite' padding
  # dropping the remainder loses nothing.
  ds = ds.ragged_batch(batch_size=per_process_batch, drop_remainder=True)

  # Every host must see the same number of batches, so the count is based on
  # the maximum per-host number of examples.
  num_batches = math.ceil(max(num_ex_per_process) / per_process_batch)
  ds = ds.take(num_batches)

  # Caching happens only after the dataset was made finite.
  if cache_final:
    ds = ds.cache()
  ds = ds.repeat()
  if prefetch:
    ds = ds.prefetch(prefetch)
  return ds, num_batches


def _get_pad_data(data):
  """Returns an infinite dataset of all-zero elements shaped like `data`'s."""
  def zeros_like_spec(spec):
    # Unknown/flexible dimensions (None) are replaced by size 0.
    shape = [dim or 0 for dim in spec.shape]
    return tf.zeros(shape, spec.dtype)

  padding_example = jax.tree.map(zeros_like_spec, data.element_spec)
  return tf.data.Dataset.from_tensors(padding_example).repeat()


def _add_internal_fields(pp_fn):
  """Wraps pp_fn to add _mask and _id keys."""
  # For each internal key, in this order of preference, we:
  # 1. keep what pp_fn produced,
  # 2. carry it over from the raw (not pp_fn'd) example, or
  # 3. add it, if that makes sense.
  def _pp_fn(example):
    processed = pp_fn(example)
    # _mask will be False on padded examples (see _get_pad_data).
    fallback_mask = example.get("_mask", tf.constant(True))
    processed.setdefault("_mask", fallback_mask)
    # Not every data source has an ID; only carry it over when present.
    if "_id" not in processed and "_id" in example:
      processed["_id"] = example["_id"]
    return processed
  return _pp_fn


def _add_tpu_host_options(data):
  """Applies threading/optimization options to `data` and returns it."""
  opts = tf.data.Options()
  opts.threading.private_threadpool_size = 48
  opts.threading.max_intra_op_parallelism = 1

  # Stop a whole bunch of magic stuff that eats up all RAM:
  opts.experimental_optimization.inject_prefetch = False

  return data.with_options(opts)
Adapted from flax.""" + if not n: + yield from it + return + queue = collections.deque() + + def enqueue(n_steps): # Enqueues *up to* `n` elements from the iterator. + for data in itertools.islice(it, n_steps): + # Prefetching will parallelize any processing that happens in a different + # thread (like `jax.device_put()`), but it will be of no use for + # processing that happens in the same thread. + queue.append(data) + + enqueue(n) # Fill up the buffer. + while queue: + yield queue.popleft() + enqueue(1) + + +def threadstart_iterator(it): + """Starts an iterator right away in a background thread.""" + # We already want to "start" the iterator in order to start the underlying + # dataset prefetch mechanisms, so here we get the first element. But we don't + # want to lose it from training, so we yield that one afterwards. + # (internal link) + pool = multiprocessing.pool.ThreadPool(processes=1) + first_ex_promise = pool.apply_async(lambda: next(it)) + + yield first_ex_promise.get() + yield from it + + +def tf_to_numpy(x): + """Convert any TF types to numpy.""" + if isinstance(x, tf.Tensor): + if x.dtype != tf.string: # Dense, non-string tensor? Easy! + return x.numpy() + else: # A dense string tensor? Turn into actual strings, not bytes. + return np.vectorize(bytes.decode, otypes=[str])(x.numpy()) + + # The rest deals with RaggedTensors, for two main reasons: + # - For strings, recursively apply the above conversion + # - For common cases (eg batch of images), return more reasonable shapes. + + # Replace all None's in the shape by a fixed number, in the (somewhat common) + # case that they are marked ragged, but really all have the same shape. 
+ real_shape = list(x.shape) + for i, s in enumerate(real_shape[1:]): + if s is not None: continue + rowlens = np.diff(x.nested_row_splits[i]) + if len(set(rowlens)) == 1: + real_shape[i + 1] = rowlens[0] + + if None not in real_shape: + return tf_to_numpy(x.flat_values).reshape(real_shape) + + # It's actually ragged, reconstruct the array from the variable length pieces. + splits = x.row_splits.numpy() + rows = [tf_to_numpy(x.values[splits[i]:splits[i + 1]]) + for i in range(len(splits) - 1)] + return np.fromiter(rows, dtype=object) + + +# Note that the order of global devices for sharding data is important and +# should be compatible with device order used for models params, state, etc. +def start_global( + data, global_devices, n_prefetch=1, keep_on_cpu=frozenset(), warmup=False): + """Starts the global input pipeline.""" + def maybe_shard(name, x): + if name in keep_on_cpu: + return tf_to_numpy(x) + return u.make_fsarray_from_local_slice(x, global_devices) + + it = iter(data) + if warmup: # actually pre-fill shuffle buffers etc. + it = threadstart_iterator(it) + + it = (u.tree_map_with_names(maybe_shard, elem) for elem in it) + return prefetch_iterator(it, n_prefetch) + + +########################################################################## +# The code below is pmap-specific and is deprecated, please switch to jit. +########################################################################## + + +def shard_and_put(x, shard=True, put=True): + x = np.asarray(memoryview(x)) # No-copy conversion: http://(internal link) + if shard: + x = einops.rearrange(x, "(d l) ... -> d l ...", d=jax.local_device_count()) + if shard and put: # Only works for pmap (for now). 
+ x = jax.device_put_sharded(list(x), jax.local_devices()) + return x + + +def start_input_pipeline(data, n_prefetch=1, shard=True): + fn = functools.partial(shard_and_put, shard=shard, put=n_prefetch) + it = (jax.tree.map(fn, elem) for elem in iter(data)) + return prefetch_iterator(it, n_prefetch) + + +def start_ragged_input_pipeline(data, n_prefetch=1, shard=True, ragged=None): + def maybe_shard_and_put(name, x): + return x if name in (ragged or {}) else shard_and_put(x, shard) + + it = (u.tree_map_with_names(maybe_shard_and_put, elem) for elem in iter(data)) + return prefetch_iterator(it, n_prefetch) diff --git a/big_vision/models/__init__.py b/big_vision/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/models/__pycache__/__init__.cpython-310.pyc b/big_vision/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..603ea115ab9b47e6c081a2c0c9467c3628afdb90 Binary files /dev/null and b/big_vision/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/big_vision/models/__pycache__/common.cpython-310.pyc b/big_vision/models/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9480224c591001a2758b2e3250169385e9199d45 Binary files /dev/null and b/big_vision/models/__pycache__/common.cpython-310.pyc differ diff --git a/big_vision/models/__pycache__/vit.cpython-310.pyc b/big_vision/models/__pycache__/vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5004d0b5784f7011c901383a5a15959bff98576 Binary files /dev/null and b/big_vision/models/__pycache__/vit.cpython-310.pyc differ diff --git a/big_vision/models/bit.py b/big_vision/models/bit.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4235df9ec87549590396deb69739e310e6770d --- /dev/null +++ b/big_vision/models/bit.py @@ -0,0 +1,162 @@ +# 
Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ResNet V1 with GroupNorm.""" + +from typing import Optional, Sequence, Union + +from big_vision import utils +from big_vision.models import common +import flax +import flax.linen as nn +import flax.training.checkpoints +import jax.numpy as jnp +import numpy as np + + +def weight_standardize(w, axis, eps): + w = w - jnp.mean(w, axis=axis) + w = w / (jnp.std(w, axis=axis) + eps) + return w + + +class StdConv(nn.Conv): + + def param(self, name, *a, **kw): + param = super().param(name, *a, **kw) + if name == "kernel": + param = weight_standardize(param, axis=[0, 1, 2], eps=1e-5) + return param + + +class ResidualUnit(nn.Module): + """Bottleneck ResNet block.""" + nmid: Optional[int] = None + strides: Sequence[int] = (1, 1) + + @nn.compact + def __call__(self, x): + nmid = self.nmid or x.shape[-1] // 4 + nout = nmid * 4 + + residual = x + if x.shape[-1] != nout or self.strides != (1, 1): + residual = StdConv(nout, (1, 1), self.strides, use_bias=False, + name="conv_proj")(residual) + residual = nn.GroupNorm(name="gn_proj")(residual) + + y = StdConv(nmid, (1, 1), use_bias=False, name="conv1")(x) + y = nn.GroupNorm(name="gn1")(y) + y = nn.relu(y) + y = StdConv(nmid, (3, 3), self.strides, use_bias=False, name="conv2")(y) + y = nn.GroupNorm(name="gn2")(y) + y = nn.relu(y) + y = StdConv(nout, (1, 1), use_bias=False, name="conv3")(y) + + y = nn.GroupNorm(name="gn3", 
scale_init=nn.initializers.zeros)(y) + y = nn.relu(residual + y) + return y + + +class ResNetStage(nn.Module): + """One stage of ResNet.""" + block_size: int + first_stride: Sequence[int] = (1, 1) + nmid: Optional[int] = None + + @nn.compact + def __call__(self, x): + x = ResidualUnit(self.nmid, strides=self.first_stride, name="unit1")(x) + for i in range(1, self.block_size): + x = ResidualUnit(self.nmid, name=f"unit{i + 1}")(x) + return x + + +class Model(nn.Module): + """ResNetV1.""" + num_classes: Optional[int] = None + width: float = 1 + depth: Union[int, Sequence[int]] = 50 + + @nn.compact + def __call__(self, image, *, train=False): + del train # Unused + blocks = get_block_desc(self.depth) + width = int(64 * self.width) + + out = {} + + # Root block + x = StdConv(width, (7, 7), (2, 2), use_bias=False, name="conv_root")(image) + x = nn.GroupNorm(name="gn_root")(x) + x = nn.relu(x) + x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME") + out["stem"] = x + + # Stages + x = ResNetStage(blocks[0], nmid=width, name="block1")(x) + out["stage1"] = x + for i, block_size in enumerate(blocks[1:], 1): + x = ResNetStage(block_size, nmid=width * 2 ** i, + first_stride=(2, 2), name=f"block{i + 1}")(x) + out[f"stage{i + 1}"] = x + out["pre_logits_2d"] = x + + # Head + x = out["pre_logits"] = jnp.mean(x, axis=(1, 2)) + + if self.num_classes: + head = nn.Dense(self.num_classes, name="head", + kernel_init=nn.initializers.zeros) + out["logits_2d"] = head(out["pre_logits_2d"]) + x = out["logits"] = head(out["pre_logits"]) + + return x, out + + +# A dictionary mapping the number of layers in a resnet to the number of +# blocks in each stage of the model. +# NOTE: Does not include 18/34 as they also need non-bottleneck block! +def get_block_desc(depth): + if isinstance(depth, list): # Be robust to silly mistakes. + depth = tuple(depth) + return { + 26: [2, 2, 2, 2], # From timm, gets ~75% on ImageNet. 
+ 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + 200: [3, 24, 36, 3] + }.get(depth, depth) + + +def fix_old_checkpoints(params): + """Modifies params from old checkpoints to run with current implementation.""" + params = flax.core.unfreeze( + flax.training.checkpoints.convert_pre_linen(params)) + # Old linen used to store non-squeezed GN params. + params = flax.traverse_util.unflatten_dict({ + k: np.squeeze(v) if (set(k) + & {"gn_root", "gn_proj", "gn1", "gn2", "gn3"}) else v + for k, v in flax.traverse_util.flatten_dict(params).items() + }) + return params + + +def load(init_params, init_file, model_cfg, dont_load=()): + """Load init from checkpoint.""" + del model_cfg # Unused + params = utils.load_params(init_file) + params = common.merge_params(params, init_params, dont_load) + params = fix_old_checkpoints(params) + return params diff --git a/big_vision/models/bit_paper.py b/big_vision/models/bit_paper.py new file mode 100644 index 0000000000000000000000000000000000000000..26e5ba83616ce046a78d1a9b3fa32f8b4cbc1000 --- /dev/null +++ b/big_vision/models/bit_paper.py @@ -0,0 +1,260 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BiT models as in the paper (ResNet V2) w/ loading of public weights. 
+ +See reproduction proof: http://(internal link)/qY70qs6j944 +""" + +import functools +import re +from typing import Optional, Sequence, Union + +from big_vision import utils as u +from big_vision.models import bit +from big_vision.models import common +import flax.linen as nn +import jax.numpy as jnp + + +def standardize(x, axis, eps): + x = x - jnp.mean(x, axis=axis, keepdims=True) + x = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps) + return x + + +# Defined our own, because we compute normalizing variance slightly differently, +# which does affect performance when loading pre-trained weights! +class GroupNorm(nn.Module): + """Group normalization (arxiv.org/abs/1803.08494).""" + ngroups: int = 32 + + @nn.compact + def __call__(self, x): + + input_shape = x.shape + group_shape = x.shape[:-1] + (self.ngroups, x.shape[-1] // self.ngroups) + + x = x.reshape(group_shape) + + # Standardize along spatial and group dimensions + x = standardize(x, axis=[1, 2, 4], eps=1e-5) + x = x.reshape(input_shape) + + bias_scale_shape = tuple([1, 1, 1] + [input_shape[-1]]) + x = x * self.param('scale', nn.initializers.ones, bias_scale_shape) + x = x + self.param('bias', nn.initializers.zeros, bias_scale_shape) + return x + + +class StdConv(nn.Conv): + + def param(self, name, *a, **kw): + param = super().param(name, *a, **kw) + if name == 'kernel': + param = standardize(param, axis=[0, 1, 2], eps=1e-10) + return param + + +class RootBlock(nn.Module): + """Root block of ResNet.""" + width: int + + @nn.compact + def __call__(self, x): + x = StdConv(self.width, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], + use_bias=False, name='conv_root')(x) + x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=[(1, 1), (1, 1)]) + return x + + +class ResidualUnit(nn.Module): + """Bottleneck ResNet block.""" + nmid: Optional[int] = None + strides: Sequence[int] = (1, 1) + + @nn.compact + def __call__(self, x): + nmid = self.nmid or x.shape[-1] // 4 + nout = nmid * 4 + conv = 
functools.partial(StdConv, use_bias=False) + + residual = x + x = GroupNorm(name='gn1')(x) + x = nn.relu(x) + + if x.shape[-1] != nout or self.strides != (1, 1): + residual = conv(nout, (1, 1), self.strides, name='conv_proj')(x) + + x = conv(nmid, (1, 1), name='conv1')(x) + x = GroupNorm(name='gn2')(x) + x = nn.relu(x) + x = conv(nmid, (3, 3), self.strides, padding=[(1, 1), (1, 1)], + name='conv2')(x) + x = GroupNorm(name='gn3')(x) + x = nn.relu(x) + x = conv(nout, (1, 1), name='conv3')(x) + + return x + residual + + +class ResNetStage(nn.Module): + """A stage (sequence of same-resolution blocks).""" + block_size: int + nmid: Optional[int] = None + first_stride: Sequence[int] = (1, 1) + + @nn.compact + def __call__(self, x): + out = {} + x = out['unit01'] = ResidualUnit( + self.nmid, strides=self.first_stride, name='unit01')(x) + for i in range(1, self.block_size): + x = out[f'unit{i+1:02d}'] = ResidualUnit( + self.nmid, name=f'unit{i+1:02d}')(x) + return x, out + + +class Model(nn.Module): + """ResNetV2.""" + num_classes: Optional[int] = None + width: int = 1 + depth: Union[int, Sequence[int]] = 50 # 50/101/152, or list of block depths. 
+ head_zeroinit: bool = True + + @nn.compact + def __call__(self, image, *, train=False): + blocks = bit.get_block_desc(self.depth) + width = int(64 * self.width) + out = {} + + x = out['stem'] = RootBlock(width=width, name='root_block')(image) + + # Blocks + x, out['stage1'] = ResNetStage(blocks[0], nmid=width, name='block1')(x) + for i, block_size in enumerate(blocks[1:], 1): + x, out[f'stage{i + 1}'] = ResNetStage( + block_size, width * 2 ** i, + first_stride=(2, 2), name=f'block{i + 1}')(x) + + # Pre-head + x = out['norm_pre_head'] = GroupNorm(name='norm-pre-head')(x) + x = out['pre_logits_2d'] = nn.relu(x) + x = out['pre_logits'] = jnp.mean(x, axis=(1, 2)) + + # Head + if self.num_classes: + kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {} + head = nn.Dense(self.num_classes, name='head', **kw) + out['logits_2d'] = head(out['pre_logits_2d']) + x = out['logits'] = head(out['pre_logits']) + + return x, out + + +def load(init_params, init_file, model_cfg, dont_load=()): + """Loads the TF-dumped NumPy or big_vision checkpoint. + + Args: + init_params: random init params from which the new head is taken. + init_file: comes from `config.model_init`, can either be an absolute + path (ie starts with /) to the checkpoint, or a string like + "L-imagenet2012" describing one of the variants from the paper. + model_cfg: the model configuration. + dont_load: list of param names to be reset to init. + + Returns: + The loaded parameters. + """ + + # Support for vanity model names from the paper. + vanity = { + 'FunMatch-224px-i1k82.8': 'gs://bit_models/distill/R50x1_224.npz', + 'FunMatch-160px-i1k80.5': 'gs://bit_models/distill/R50x1_160.npz', + } + if init_file[0] in ('L', 'M', 'S'): # The models from the original paper. + # Supported names are of the following type: + # - 'M' or 'S': the original "upstream" model without fine-tuning. + # - 'M-ILSVRC2012': i21k model fine-tuned on i1k. 
+ # - 'M-run0-caltech101': i21k model fine-tuned on VTAB's caltech101. + # each VTAB fine-tuning was run 3x, so there's run0, run1, run2. + if '-' in init_file: + up, down = init_file[0], init_file[1:] + else: + up, down = init_file, '' + down = {'-imagenet2012': '-ILSVRC2012'}.get(down, down) # normalize + fname = f'BiT-{up}-R{model_cfg.depth}x{model_cfg.width}{down}.npz' + fname = f'gs://bit_models/{fname}' + else: + fname = vanity.get(init_file, init_file) + + params = u.load_params(fname) + params = maybe_convert_big_transfer_format(params) + return common.merge_params(params, init_params, dont_load) + + +def maybe_convert_big_transfer_format(params_tf): + """If the checkpoint comes from legacy codebase, convert it.""" + + # Only do anything at all if we recognize the format. + if 'resnet' not in params_tf: + return params_tf + + # For ease of processing and backwards compatibility, flatten again: + params_tf = dict(u.tree_flatten_with_names(params_tf)[0]) + + # Works around some files containing weird naming of variables: + for k in list(params_tf): + k2 = re.sub('/standardized_conv2d_\\d+/', '/standardized_conv2d/', k) + if k2 != k: + params_tf[k2] = params_tf[k] + del params_tf[k] + + params = { + 'root_block': {'conv_root': {'kernel': params_tf[ + 'resnet/root_block/standardized_conv2d/kernel']}}, + 'norm-pre-head': { + 'bias': params_tf['resnet/group_norm/beta'][None, None, None], + 'scale': params_tf['resnet/group_norm/gamma'][None, None, None], + }, + 'head': { + 'kernel': params_tf['resnet/head/conv2d/kernel'][0, 0], + 'bias': params_tf['resnet/head/conv2d/bias'], + } + } + + for block in ('block1', 'block2', 'block3', 'block4'): + params[block] = {} + units = set([re.findall(r'unit\d+', p)[0] for p in params_tf.keys() + if p.find(block) >= 0]) + for unit in units: + params[block][unit] = {} + for i, group in enumerate('abc', 1): + params[block][unit][f'conv{i}'] = { + 'kernel': params_tf[f'resnet/{block}/{unit}/{group}/standardized_conv2d/kernel'] # 
pylint: disable=line-too-long + } + params[block][unit][f'gn{i}'] = { + 'bias': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/beta'][None, None, None], # pylint: disable=line-too-long + 'scale': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/gamma'][None, None, None], # pylint: disable=line-too-long + } + + projs = [p for p in params_tf.keys() + if p.find(f'{block}/{unit}/a/proj') >= 0] + assert len(projs) <= 1 + if projs: + params[block][unit]['conv_proj'] = { + 'kernel': params_tf[projs[0]] + } + + return params diff --git a/big_vision/models/common.py b/big_vision/models/common.py new file mode 100644 index 0000000000000000000000000000000000000000..175dfa77a1360bc2a0276fa12245c8d357b39406 --- /dev/null +++ b/big_vision/models/common.py @@ -0,0 +1,133 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities shared across models.""" + +from absl import logging +import big_vision.utils as u +import flax.linen as nn +import jax +import jax.numpy as jnp + + +def merge_params(loaded, inited, dont_load=(), match_dtype=False): + """Makes `loaded` pytree match `init`, warning or failing on mismatch. + + Args: + loaded: pytree of parameters, typically loaded from a checkpoint. + inited: pytree of parameter, typically coming from model init. 
+ dont_load: List of regexes for parameters which shall not be taken + from `loaded`, either because they should remain at their init value, + or because they are missing on either side. + match_dtype: returned pytree as leaves converted to dtype from `inited`. + + Returns: + If successful, a new pytree which matches the structure of `init` + but contains values from `loaded`, except for `dont_load`. + + If structures don't match and mismatches are not covered by regexes in + `dont_load` argument, then raises an exception with more information. + """ + if inited is None: # A useful shortcut for example for colabs. + return loaded + + dont_load = u.check_and_compile_patterns(dont_load) + + def should_merge(name): + return not any(pattern.fullmatch(name) for pattern in dont_load) + + loaded_flat, _ = u.tree_flatten_with_names(loaded) + inited_flat, _ = u.tree_flatten_with_names(inited) + loaded_flat = {k: v for k, v in loaded_flat} + inited_flat = {k: v for k, v in inited_flat} + + # Let's first build the pytree from all common keys. + merged = {} + for name, init_val in inited_flat.items(): + # param is present in both. Load or ignore it! 
+ if name in loaded_flat and should_merge(name): + merged[name] = loaded_flat[name] + if match_dtype: + merged[name] = loaded_flat[name].astype(init_val.dtype) + else: + logging.info("Ignoring checkpoint and using init value for %s", name) + merged[name] = init_val + + def pp(title, names, indent=" "): # Just pretty-printing + if names: + return f"{title}:\n" + "\n".join(f"{indent}{k}" for k in sorted(names)) + else: + return "" + + # Now, if there are keys that only exist in inited or loaded, be helpful: + not_in_loaded = inited_flat.keys() - loaded_flat.keys() + not_in_inited = loaded_flat.keys() - inited_flat.keys() + logging.info(pp("Parameters in model but not in checkpoint", not_in_loaded)) + logging.info(pp("Parameters in checkpoint but not in model", not_in_inited)) + + # And now see if any of them are not explicitly ignored => an error + not_in_loaded = {k for k in not_in_loaded if should_merge(k)} + not_in_inited = {k for k in not_in_inited if should_merge(k)} + + if not_in_loaded or not_in_inited: + raise ValueError( + pp("Params in checkpoint", loaded_flat.keys()) + "\n" + + pp("Params in model (code)", inited_flat.keys()) + "\n" + + pp("Params in model (code) but not in checkpoint and not `dont_load`ed", + not_in_loaded, indent=" - ") + "\n" + # Special indent for tests. + pp("Params in checkpoint but not in model (code) and not `dont_load`ed", + not_in_inited, indent=" + ")) # Special indent for tests. + + return u.recover_tree(merged.keys(), merged.values()) + + +class AddPositionEmbs(nn.Module): + """Adds positional embeddings to the inputs, supports caching for decode. + + Attributes: + decode: whether to run in single-position autoregressive mode. + """ + decode: bool = False + + @nn.compact + def __call__(self, inputs, posemb): + """Applies AddPositionEmbs module. + + Adds posemb to the inputs, supports single-position autoregressive mode. + + Args: + inputs: input data [batch_size, seq_len, emb_dim]. + posemb: positional embeddings. 
+ + Returns: + output: inputs modulated by pos-embeddings [batch_size, seq_len, emb_dim]. + """ + assert inputs.ndim == 3, f"Unexpected inputs shape: {inputs.shape}" + _, seq_len, emb_dim = inputs.shape + pe = posemb[:, :seq_len, :] + + if self.decode: + is_initialized = self.has_variable("cache", "cache_index") + # We use a cache position index for tracking decoding position. + cache_index = self.variable("cache", "cache_index", + lambda: jnp.array(0, dtype=jnp.uint32)) + if is_initialized: + i = cache_index.value + cache_index.value = i + 1 + # Returns posemb[0, i, :], the positional embedding for the + # current decoding position. + pe = jax.lax.dynamic_slice(posemb, + start_indices=jnp.array((0, i, 0)), + slice_sizes=(1, 1, emb_dim)) + return inputs + pe diff --git a/big_vision/models/mlp_mixer.py b/big_vision/models/mlp_mixer.py new file mode 100644 index 0000000000000000000000000000000000000000..58bd4b99d21f061693da007b26dd24013e341851 --- /dev/null +++ b/big_vision/models/mlp_mixer.py @@ -0,0 +1,177 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""MLP-Mixer model.""" + +from typing import Optional, Tuple +from absl import logging + +from big_vision import utils +from big_vision.models import common + +import einops +import flax.linen as nn +import flax.training.checkpoints +import jax +import jax.numpy as jnp + + +class MlpBlock(nn.Module): + mlp_dim: int + + @nn.compact + def __call__(self, x): + y = nn.Dense(self.mlp_dim)(x) + y = nn.gelu(y) + return nn.Dense(x.shape[-1])(y) + + +class MixerBlock(nn.Module): + """Mixer block layer.""" + tokens_mlp_dim: int + channels_mlp_dim: int + drop_p: float + + @nn.compact + def __call__(self, x, *, train=False): + y = nn.LayerNorm()(x) + y = jnp.swapaxes(y, 1, 2) + y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y) + y = jnp.swapaxes(y, 1, 2) + x = x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng) + y = nn.LayerNorm()(x) + y = MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y) + return x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng) + + +class MlpMixer(nn.Module): + """Mixer architecture.""" + patch_size: Tuple[int, int] + num_classes: Optional[int] + num_blocks: int + hidden_dim: int + tokens_mlp_dim: int + channels_mlp_dim: int + model_name: Optional[str] = None + stoch_depth: float = 0.0 + + @nn.compact + def __call__(self, image, *, train=False): + out = {} + x = out["stem"] = nn.Conv(self.hidden_dim, self.patch_size, + strides=self.patch_size, name="stem")(image) + x = out["input_tokens"] = einops.rearrange(x, "n h w c -> n (h w) c") + for i in range(self.num_blocks): + drop_p = (i / max(self.num_blocks - 1, 1)) * self.stoch_depth + x = out[f"block_{i}"] = MixerBlock( + self.tokens_mlp_dim, self.channels_mlp_dim, drop_p)(x, train=train) + x = nn.LayerNorm(name="pre_head_layer_norm")(x) + x = out["pre_logits"] = jnp.mean(x, axis=1) + if self.num_classes: + x = out["logits"] = nn.Dense( + self.num_classes, kernel_init=nn.initializers.zeros, name="head")(x) + return x, out + + +def Model(num_classes=None, *, 
variant=None, **kw): # pylint: disable=invalid-name + """Factory function to easily create a Model variant like "L/16".""" + + if variant is not None: + model_size, patch = variant.split("/") + kw.setdefault("patch_size", (int(patch), int(patch))) + config = { + "S": { + "hidden_dim": 512, + "num_blocks": 8, + "channels_mlp_dim": 2048, + "tokens_mlp_dim": 256 + }, + "B": { + "hidden_dim": 768, + "num_blocks": 12, + "channels_mlp_dim": 3072, + "tokens_mlp_dim": 384 + }, + "L": { + "hidden_dim": 1024, + "num_blocks": 24, + "channels_mlp_dim": 4096, + "tokens_mlp_dim": 512 + }, + "H": { + "hidden_dim": 1280, + "num_blocks": 32, + "channels_mlp_dim": 5120, + "tokens_mlp_dim": 640 + }, + }[model_size] + + for k, v in config.items(): + kw.setdefault(k, v) + + logging.info("Mixer config: %s", kw) + return MlpMixer(num_classes=num_classes, **kw) + + +def load(init_params, init_file, model_cfg, dont_load=()): + """Load checkpoint.""" + + del model_cfg + # Shortcut names for some canonical paper checkpoints: + init_file = { + # pylint: disable=line-too-long + # Pretrained models from the MLP-Mixer paper: https://arxiv.org/abs/2105.01601. 
+ "B-i1k/16": "gs://mixer_models/imagenet1k/Mixer-B_16.npz", + "L-i1k/16": "gs://mixer_models/imagenet1k/Mixer-L_16.npz", + "B-i21k/16": "gs://mixer_models/imagenet21k/Mixer-B_16.npz", + "L-i21k/16": "gs://mixer_models/imagenet21k/Mixer-L_16.npz", + # pylint: enable=line-too-long + }.get(init_file, init_file) + restored_params = utils.load_params(init_file) + restored_params = flax.training.checkpoints.convert_pre_linen(restored_params) + + if "Mixer" in restored_params: + restored_params["pre_head_layer_norm"] = restored_params["Mixer"].pop( + "encoder_norm" + ) + restored_params["stem"] = restored_params.pop("embedding") + def unflatten_dense(d): + return { + "Dense_0": { + "bias": d["bias1"].squeeze(), + "kernel": d["kernel1"].squeeze(), + }, + "Dense_1": { + "bias": d["bias2"].squeeze(), + "kernel": d["kernel2"].squeeze(), + }, + } + for k, v in restored_params["Mixer"].items(): + assert k.startswith("encoderblock_"), k + v["token_mixing"] = unflatten_dense(v.pop("token_mixing_phase_0")) + v["channel_mixing"] = unflatten_dense(v.pop("channel_mixing_phase_0")) + restored_params["MixerBlock_" + k[len("encoderblock_"):]] = v + del restored_params["Mixer"] + + # possibly use the random init for some of the params (such as, the head). 
+ restored_params = common.merge_params(restored_params, init_params, dont_load) + + return restored_params + + +def _stoch_depth_mask(x, drop_p, deterministic, make_rng): + if not deterministic and drop_p: + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + return 1.0 - jax.random.bernoulli(make_rng("dropout"), drop_p, shape) + return 1.0 diff --git a/big_vision/models/ppp/__init__.py b/big_vision/models/ppp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/models/ppp/__pycache__/__init__.cpython-310.pyc b/big_vision/models/ppp/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba798b42460e2246acce9722abb04d642e957d53 Binary files /dev/null and b/big_vision/models/ppp/__pycache__/__init__.cpython-310.pyc differ diff --git a/big_vision/models/ppp/__pycache__/gemma.cpython-310.pyc b/big_vision/models/ppp/__pycache__/gemma.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15bffc57a3184b00bcdebcd48e8ba5874ea78233 Binary files /dev/null and b/big_vision/models/ppp/__pycache__/gemma.cpython-310.pyc differ diff --git a/big_vision/models/ppp/gemma.py b/big_vision/models/ppp/gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..e11606fba89dedb988ea541f39f608bc7d0cd515 --- /dev/null +++ b/big_vision/models/ppp/gemma.py @@ -0,0 +1,533 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""gemma reimplementation for big_vision. + +We follow this einsum axis naming convention: + B: batch + T: query length + S: k/v length + N: num query heads + K: num k/v heads + G: num query heads per k/v head + H: head dim + D: d_model ("features") + +Example Colab using the models via the PaliGemma decoding logic: +(internal link) + +Doc locating the variable initializers in the original code and validating them: +(internal link) +""" + + +from big_vision.models import common +import big_vision.utils as u +import einops +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np +import orbax.checkpoint + + +def get_config(variant): + """Returns config for specified gemma variant.""" + if variant == "gemma_2b": + return ml_collections.ConfigDict( + dict( + variant=variant, + width=2048, + depth=18, + mlp_dim=16_384, + num_heads=8, + num_kv_heads=1, + head_dim=256, + norm_eps=1e-6, + vocab_size=256_128, + scan=True, + remat_policy="nothing_saveable", + ) + ) + if variant == "gemma_7b": + return ml_collections.ConfigDict( + dict( + variant=variant, + width=3072, + depth=28, + mlp_dim=24_576, + num_heads=16, + num_kv_heads=16, + head_dim=256, + norm_eps=1e-6, + vocab_size=256_128, + scan=True, + remat_policy="nothing_saveable", + ) + ) + raise ValueError(f"Unknown variant: {variant}") + + +def _apply_rope(x, *, positions, max_wavelength=10_000): + """Applies RoPE positions [B, L] to x [B, L, H, D].""" + freq_exponents = (2. 
/ x.shape[-1]) * jnp.arange(x.shape[-1] // 2) + timescale = (max_wavelength ** freq_exponents) + radians = positions[..., None] / timescale[None, None, :] + radians = radians[..., None, :] + # radians.shape = [...,L,1,d=D/2] + sin, cos = jnp.sin(radians), jnp.cos(radians) + x1, x2 = jnp.split(x, 2, axis=-1) + res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) + return res + + +def _update_kv_cache(module, k, v, cache_size, cache_dtype): + """Updates KV cache and returns its current contents.""" + initialized = module.has_variable("cache", "idx") + batch_size, update_len, num_heads, head_dim = k.shape + cache_dtype = cache_dtype or k.dtype + + # Idx of which cache row to update next is the same for all examples, so that + # it allows to update with dynamic_update_slice. But in order to keep things + # nicely partitioned we store it with leading batch dimension and use only + # the first entry. + idx = module.variable("cache", "idx", jnp.zeros, (batch_size,), jnp.int32) + + kv_shape = (batch_size, cache_size, num_heads, head_dim) + k_cache = module.variable( + "cache", "k_cache", jnp.zeros, kv_shape, cache_dtype) + v_cache = module.variable( + "cache", "v_cache", jnp.zeros, kv_shape, cache_dtype) + + if initialized: # write k, v in the next cache position. + assert update_len == 1, update_len + # Note: idx is the same for all examples. Use value from example 0. + indices = (0, idx.value[0], 0, 0) + k_cache.value = jax.lax.dynamic_update_slice( + k_cache.value, k.astype(cache_dtype), indices) + v_cache.value = jax.lax.dynamic_update_slice( + v_cache.value, v.astype(cache_dtype), indices) + idx.value = idx.value + 1 + else: # init cache with k, v after padding to cache_size. 
class RMSNorm(nn.Module):
  """Root-mean-square normalisation over the last axis with a learned gain.

  The input is divided by `sqrt(mean(x**2) + eps)` along its last axis and
  then multiplied by `1 + scale`, where `scale` is a learned per-feature
  parameter initialised to zeros.
  """

  # Added to the mean square before the square root. This used to be a literal
  # in the body; the default keeps the previous numerics.
  eps: float = 1e-6

  @nn.compact
  def __call__(self, x):
    # Explicit 1-tuple: `(n)` without the trailing comma is just the int `n`.
    scale = self.param(
        "scale", nn.initializers.zeros_init(), (x.shape[-1],))
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + self.eps)))
    # With the zero-initialised `scale` the gain starts out as exactly 1.
    normed_inputs = normed_inputs * (1 + scale)
    return normed_inputs
in_axis=(2,), out_axis=(0, 1, 3), batch_axis=()), + ) + else: + # MQA + self.q_einsum = Einsum( + shape=(self.num_heads, self.features, self.head_dim), + w_init=trunc_norm_init(in_axis=(1,), out_axis=(0, 2), batch_axis=()), + ) + self.kv_einsum = Einsum( + shape=(2, self.num_kv_heads, self.features, self.head_dim), + w_init=trunc_norm_init( + in_axis=(2,), out_axis=(0, 1, 3), batch_axis=()), + ) + self.attn_vec_einsum = Einsum( + shape=(self.num_heads, self.head_dim, self.features), + w_init=trunc_norm_init(in_axis=(0, 1), out_axis=(2,), batch_axis=()), + ) + + @nn.compact + def __call__(self, x, positions, attn_mask, decode, deterministic=True): + if self.num_kv_heads == self.num_heads: + q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x) + else: + q = self.q_einsum("BTD,NDH->BTNH", x) + k, v = self.kv_einsum("BSD,2KDH->2BSKH", x) + + q = _apply_rope(q, positions=positions) + q *= self.head_dim**-0.5 + + k = _apply_rope(k, positions=positions) + if decode: + k, v = _update_kv_cache(self, k, v, + cache_size=attn_mask.shape[-1], + cache_dtype=self.cache_dtype) + + q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads) + logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k) + logits = logits.astype(jnp.float32) + + if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]): + raise ValueError( + f"Attention mask with shape {attn_mask.shape} but shapes for q and k " + f"are: {q.shape} and {k.shape}" + ) + + # big_neg = jnp.finfo(logits.dtype).min + big_neg = -2.3819763e38 # See gemma/modules.py + masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg) + + probs = jax.nn.softmax(masked_logits, axis=-1).astype(k.dtype) + + encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v) + encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H") + attn_output = self.attn_vec_einsum("BTNH,NHD->BTD", encoded) + + return attn_output + + +class FeedForward(nn.Module): + """Feed forward module.""" + + features: int + hidden_dim: int + + @nn.compact + def 
__call__(self, x): + w_gating = self.param( + "gating_einsum", + trunc_norm_init(in_axis=(1,), out_axis=(0, 2), batch_axis=()), + ((2, self.features, self.hidden_dim)), + ) + ff_gate = jnp.dot(x, w_gating[0]) + gate_value = nn.gelu(ff_gate) + + ff1 = jnp.dot(x, w_gating[1]) + activations = gate_value * ff1 + + w_linear = self.param( + "linear", + trunc_norm_init(in_axis=(0,), out_axis=(1,), batch_axis=()), + (self.hidden_dim, self.features), + ) + outputs = jnp.dot(activations, w_linear) + + return outputs + + +class Block(nn.Module): + """Transformer block.""" + + num_heads: int + num_kv_heads: int + embed_dim: int + head_dim: int + hidden_dim: int + + dropout: float = 0.0 + dropout_bdims: tuple[int, ...] = () + cache_dtype: str | None = None + + def setup(self): + self.pre_attention_norm = RMSNorm() + self.attn = Attention( + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + features=self.embed_dim, + head_dim=self.head_dim, + cache_dtype=self.cache_dtype, + ) + self.pre_ffw_norm = RMSNorm() + self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim) + if self.dropout: + self.drop = nn.Dropout(self.dropout, self.dropout_bdims) + else: + self.drop = lambda x, _: x + + def __call__(self, x, unused_scan_arg, positions, attn_mask, + decode, deterministic=True): + x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb")) + inputs_normalized = self.pre_attention_norm(x) + attn_output = self.attn(inputs_normalized, positions, attn_mask, + decode, deterministic) + attn_output = self.drop(attn_output, deterministic) + attn_output += x + residual = attn_output + attn_output = self.pre_ffw_norm(attn_output) + outputs = self.mlp(attn_output) + outputs = self.drop(outputs, deterministic) + outputs = residual + outputs + return outputs, unused_scan_arg + + +class Model(nn.Module): + """gemma model.""" + + variant: str + + width: int + depth: int + mlp_dim: int + num_heads: int + num_kv_heads: int + head_dim: int + norm_eps: float + 
vocab_size: int + + dropout: float = 0.0 + dropout_bdims: tuple[int, ...] = () # Every float is dropped independently. + cache_dtype: str | None = None + + # TODO: Wire this in all places needed so that the model can be + # run with different activation dtype. For now only float32 runs. + embed_dtype: str = "float32" + + scan: bool = False + remat_policy: str = "none" + + @nn.compact + def __call__( + self, tokens, *, + embedded_prefix=None, + embed_only=False, + pre_logits=None, + positions=None, mask=None, + decode=False, deterministic=True, + ): + """Embed only, or complete forward pass. + + Args: + tokens: Embedded, then and appended to `embedded_prefix`. Can be None. + embedded_prefix: Optional prefix that is already embedded. + embed_only: Whether to compute embeddings only. + pre_logits: If present computes logits from pre_logits and returns. + positions: Optional `[B, T]` allows to specify the absolute position of + the tokens. + mask: Optional attention mask `[B, T, S]`. + decode: Whether to use kv-cache. Caller must pass masks and positions. + deterministic: Forwarded to all dropout layers. + + Returns: + If `embed_only=False`, then `(logits, out)` will be returned. + If `embed_only=True`, then the embeddings will be returned. 
+ """ + out = {} + + embedder = Embedder( + vocab_size=self.vocab_size, + embed_dim=self.width, + name="embedder") + + if pre_logits is not None: + x = out["pre_logits"] = pre_logits + logits = out["logits"] = embedder.decode(x) + return logits, out + + x = [] + if embedded_prefix is not None: + x.append(embedded_prefix) + if tokens is not None: + x.append(embedder.encode(tokens)) + + x = jnp.concatenate(x, axis=-2) + x = x.astype(self.embed_dtype) + batch_size, seq_len, width = x.shape + + if embed_only: + return x + + if decode: + assert positions is not None and mask is not None, ( + "Must explicitly pass positions and mask for decoding.") + + if positions is None: + positions = jnp.arange(seq_len).astype(jnp.int32)[None, :] + assert positions.shape[1] == x.shape[1], (positions.shape, x.shape) + + if mask is None: + mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len])) + if mask.ndim == 3: + mask = mask[:, None, :, :] + cache_size = max(seq_len, mask.shape[-1]) + assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape + + if self.remat_policy == "none": + block_cls = Block + else: + block_cls = nn.remat( + Block, + prevent_cse=not self.scan, + static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic + policy=getattr(jax.checkpoint_policies, self.remat_policy), + ) + + block_kw = dict( + num_heads=self.num_heads, + head_dim=self.head_dim, + num_kv_heads=self.num_kv_heads, + embed_dim=width, + hidden_dim=self.mlp_dim, + dropout=self.dropout, + dropout_bdims=self.dropout_bdims, + cache_dtype=self.cache_dtype, + ) + layers = self.scope.push("layers") + if self.scan: + blocks = [nn.scan( + block_cls, + # cache has axis 1 since we want leading dimension to be batch size. 
+ variable_axes={"params": 0, "cache": 1}, + split_rngs={"params": True, "dropout": True}, + in_axes=nn.broadcast, + length=self.depth, + )( + parent=layers, **block_kw + )] + else: + blocks = [ + block_cls( + parent=layers.push(str(layer)), + **block_kw, + ) + for layer in range(self.depth) + ] + unused_scan_arg = () + for block in blocks: + x, unused_scan_arg = block( + x, unused_scan_arg, positions, mask, decode, deterministic) + + assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check. + out["encoded"] = x + + x = RMSNorm(name="final_norm")(x) + out["pre_logits"] = x + + x = embedder.decode(x) + out["logits"] = x + + return x, out + + +_ORBAX_INITS = {} +_BV_INITS = {} + + +def _load_orbax(path): + """Loads and coverts Orbax gemma checkpoint.""" + checkpointer = orbax.checkpoint.PyTreeCheckpointer() + params = checkpointer.restore(path) + params = flax.traverse_util.unflatten_dict(params, sep="/")["transformer"] + n = sum(1 for k in params if k.startswith("layer_")) + params["layers"] = jax.tree.map( + lambda *xs: np.stack(xs), *(params.pop(f"layer_{i}") for i in range(n)) + ) + mlp = params["layers"]["mlp"] + mlp["gating_einsum"] = mlp["gating_einsum"].pop("w") + mlp["linear"] = mlp["linear"].pop("w") + return params + + +def _del_pad_rows(params): + """Some checkpoints have 128 unused padding tokens.""" + emb = params["embedder"]["input_embedding"] + assert emb.shape[0] == 256_128 + params["embedder"]["input_embedding"] = np.asarray(emb)[:256_000] + return params + + +def load(init_params, init_file, model_cfg=None, dont_load=()): + """Loads existing weights.""" + model_cfg = model_cfg or {} + variant = model_cfg.get("variant", "gemma_2b") + init_variant = f"{init_file} {variant}" + if init_variant in _ORBAX_INITS: + params = _del_pad_rows(_load_orbax(_ORBAX_INITS[init_variant])) + elif init_variant in _BV_INITS: + params = _del_pad_rows(u.load_params(_BV_INITS[init_variant])) + else: + params = u.load_params(init_file) + + def extend_rows(emb1, 
target_rows): + if (missing_rows := target_rows - emb1.shape[0]) == 0: + return emb1 + assert missing_rows > 0, "You're asking to shrink vocab?!" + new_rows = np.random.randn(missing_rows, emb1.shape[1]) + new_rows = (new_rows * 0.02).astype(emb1.dtype) + return np.r_[np.asarray(emb1), new_rows] + + if "vocab_size" in model_cfg: + params["embedder"]["input_embedding"] = extend_rows( + params["embedder"]["input_embedding"], + model_cfg["vocab_size"], + ) + + return common.merge_params(params, init_params, dont_load) diff --git a/big_vision/models/proj/cappa/cappa.py b/big_vision/models/proj/cappa/cappa.py new file mode 100644 index 0000000000000000000000000000000000000000..8c20b1b78f06d3eece7f2e5e8707eef995466529 --- /dev/null +++ b/big_vision/models/proj/cappa/cappa.py @@ -0,0 +1,428 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model definitions for CapPa (https://arxiv.org/abs/2306.07915). + +Used abbreviations for dimension annotations: + B: batch size. + H: image height. + W: image width. + P: number of patches (PH/PW: number of patches in height/width dimensions). + E: embedding size. + L: sequence length of text tokens. + V: vocab size. 
def shift_right(x, axis=1, constant_values=0):
  """Shifts `x` one step to the right along `axis`, padding on the left.

  Args:
    x: Array to shift.
    axis: Dimension to shift along. Negative values count from the last
      dimension.
    constant_values: Value written into the slot that is freed at the start of
      `axis`.

  Returns:
    An array with the same shape as `x` in which entry `i` along `axis` holds
    entry `i - 1` of `x`, and entry `0` holds `constant_values`.

  Raises:
    IndexError: If `axis` is out of range for `x`.
  """
  ndim = x.ndim
  # Normalise negative axes. Without this the pad below would accept e.g.
  # `axis=-1`, but the crop would never match it against the non-negative
  # loop index and the result would be one element too long.
  norm_axis = axis + ndim if axis < 0 else axis
  if not 0 <= norm_axis < ndim:
    raise IndexError(
        f"axis {axis} is out of range for an array with {ndim} dimensions")

  pad_widths = [(0, 0)] * ndim
  pad_widths[norm_axis] = (1, 0)
  padded = jnp.pad(x, pad_widths, constant_values=constant_values)
  # Cuts off the rightmost slice of size 1 along the `axis` dimension.
  # Note that `list[:-1]` is the same as `list[slice(-1)]`.
  return padded[
      tuple(slice(-1 if i == norm_axis else None) for i in range(ndim))]
+ + Returns: + output after transformer encoder-decoder block [B, L, E]. + """ + def wlc(f): + dim_names = ("act_batch", "act_len", "act_emb") + return nn.with_logical_constraint(f, dim_names) + + # Decoder block. + x = wlc(nn.LayerNorm(name="LayerNorm1", use_bias=self.use_bias)(targets)) + x = wlc(nn.SelfAttention( + num_heads=self.num_heads, use_bias=False, broadcast_dropout=False, + dropout_rate=self.dropout_rate, decode=self.decode, name="SelfAttn")( + x, decoder_mask, deterministic=deterministic)) + x = wlc(nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)) + x = wlc(x + targets) + + if encoded is not None: + # Encoder-Decoder block. + y = wlc(nn.LayerNorm(name="LayerNorm2", use_bias=self.use_bias)(x)) + y = wlc(nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, use_bias=False, broadcast_dropout=False, + dropout_rate=self.dropout_rate, name="CrossAttn")( + y, encoded, deterministic=deterministic)) + y = wlc( + nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)) + y = wlc(y + x) + else: + y = x + + # MLP block. + z = wlc(nn.LayerNorm(name="LayerNorm3", use_bias=self.use_bias)(y)) + z = wlc(MlpBlock( + mlp_dim=self.mlp_dim, dropout=self.dropout_rate, use_bias=self.use_bias, + name="MLP")(z, deterministic=deterministic)) + + return wlc(y + z), None + + +class Decoder(nn.Module): + """Transformer decoder with parallel prediction.""" + emb_dim: int + mlp_dim: int + num_heads: int + num_layers: int + dropout_rate: float = 0. + output_vocab_size: int = 32_000 + + # Masked prediction training mode + masked_pred_prob: float = 0. + masking_ratio: float = 0. + + # Whether to use bias in MLP blocks and LN + use_bias: bool = True + + scan: bool = False + remat_policy: str = "nothing_saveable" + + @nn.compact + def __call__(self, + encoded, + targets, + pos_emb, + decoder_mask=None, + decode=False, + deterministic=True, + max_decode_length=None): + """Applies Transformer model on the inputs. 
+ + Args: + encoded: encoded image patches from encoder [B, P, E]. + targets: target text tokens [B, L]. + pos_emb: positional embeddings. + decoder_mask: decoder self-attention mask. + decode: bool, whether to perform fast autoregressive decoding with cache. + deterministic: bool, deterministic or not (to apply dropout). + max_decode_length: optional max length for positional embeddings. + + Returns: + output of a transformer decoder [B, L, V]. + """ + y = targets.astype("int32") + if not decode: + if self.masked_pred_prob > 0.0 and not deterministic: + # Binary random variable indicating whether to do masked prediction + + def _add_random_masks(a): + # Generate random mask + n_masked = int(self.masking_ratio * a.shape[1]) + mask_locations = jnp.zeros(a.shape[:2], dtype=jnp.int32) + mask_locations = mask_locations.at[:, :n_masked].set(1) + mask_locations = jax.random.permutation( + self.make_rng("dropout"), mask_locations, axis=1, independent=True + ) + # Replace mask locations with mask token index (=vocab_size) + a_masked = jnp.where(mask_locations, self.output_vocab_size, a) + return a_masked + + def where(mask, x, y): + mask = mask.reshape((-1,) + (1,) * (x.ndim - 1)) + return jnp.where(mask, x, y) + + do_masked_pred = ( + jax.random.uniform(self.make_rng("dropout"), (len(y),)) + < self.masked_pred_prob + ) + y = where(do_masked_pred, _add_random_masks(y), shift_right(y)) + decoder_mask = where( + do_masked_pred, jnp.ones_like(decoder_mask), decoder_mask + ) + + else: + y = shift_right(y) + + embed = nn.Embed( + self.output_vocab_size + (1 if self.masked_pred_prob > 0.0 else 0), + self.emb_dim, + name="EmbedTargets", + embedding_init=nn.initializers.normal(stddev=1.0), + ) + y = embed(y) + + y = common.AddPositionEmbs( + decode=decode, name="PosEmbedTargets")(y, pos_emb) + # NOTE: One could apply dropout on the decoder's inputs here. Whether to do + # it or not, and if so, what is the best/common way, is to be determined. 
+ # y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic) + + if self.scan: + # Mostly followed + # https://github.com/google/maxtext/blob/4d99e30b3e0e0cb1d1aa11c7db7fffe18e301498/MaxText/layers.py#L1126 + # for the scanned version. + # 1. remat + enc_dec_block_remat = nn.remat( + EncoderDecoderBlock, + prevent_cse=False, + static_argnums=(-1,), + policy=getattr(jax.checkpoint_policies, self.remat_policy, None)) + # 2. scan + initializing = self.is_mutable_collection("params") + param_scan_axis = 1 + params_spec = (param_scan_axis if initializing + else partitioning.ScanIn(param_scan_axis)) + dec_scanned = nn.scan(enc_dec_block_remat, + variable_axes={ + "params": params_spec, + "cache": 0, + }, + split_rngs={"params": True, "dropout": True}, + in_axes=nn.broadcast, + length=self.num_layers) + # 3. fprop + y, _ = dec_scanned(num_heads=self.num_heads, mlp_dim=self.mlp_dim, + dropout_rate=self.dropout_rate, decode=decode, + use_bias=self.use_bias, name="EncDecBlock")( + y, encoded, decoder_mask, deterministic) + else: + for lyr in range(self.num_layers): + y, _ = EncoderDecoderBlock( + num_heads=self.num_heads, mlp_dim=self.mlp_dim, + dropout_rate=self.dropout_rate, decode=decode, + use_bias=self.use_bias, name=f"EncDecBlock{lyr}")( + y, encoded, decoder_mask=decoder_mask, + deterministic=deterministic) + + y = nn.LayerNorm(name="LayerNorm")(y) + + logits = nn.Dense( + self.output_vocab_size, + kernel_init=nn.initializers.zeros, + name="LogitsDense", + )(y) + return logits + + +class Model(nn.Module): + """Transformer Model for sequence to sequence translation.""" + # Encoder/decoder: + num_heads: int = 8 + num_layers: int = 6 + mlp_dim: int = 2048 + emb_dim: int = 512 + enc_dropout_rate: float = 0. + vocab_size: int = 32_000 + seq_len: int = 256 + + # Encoder: + patches: Sequence[int] = (16, 16) + input_seq_len: int = 768 + posemb_type: str = "learn" + patch_dropout: float = 0. 
+ + # Decoder: + decoder_num_heads: int = 0 + decoder_num_layers: int = 0 + decoder_mlp_dim: int = 0 + decoder_emb_dim: int = 0 + dec_dropout_rate: float = 0. + # Probability of masked prediction rather than autoregressive prediciton. + masked_pred_prob: float = 0. + # Masking ratio for masked prediction. + masking_ratio: float = 0. + # Whether to use bias in decoder MLP blocks and LN. + decoder_bias: bool = True + + scan: bool = False + remat_policy: str = "nothing_saveable" + + def setup(self): + + self.encoder = vit.Model( + patch_size=self.patches, + width=self.emb_dim, + depth=self.num_layers, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + dropout=self.enc_dropout_rate, + posemb=self.posemb_type, + scan=self.scan, + remat_policy=self.remat_policy, + ) + + self.pos_emb_for_decoder = vit.get_posemb( + self, + self.posemb_type, + (1, self.seq_len), + self.decoder_emb_dim or self.emb_dim, + "pos_embedding_decoder", + ) + self.decoder = Decoder( + num_layers=self.decoder_num_layers or self.num_layers, + mlp_dim=self.decoder_mlp_dim or self.mlp_dim, + num_heads=self.decoder_num_heads or self.num_heads, + dropout_rate=self.dec_dropout_rate, + emb_dim=self.decoder_emb_dim or self.emb_dim, + output_vocab_size=self.vocab_size, + masked_pred_prob=self.masked_pred_prob, + masking_ratio=self.masking_ratio, + use_bias=self.decoder_bias, + scan=self.scan, + remat_policy=self.remat_policy, + ) + + def encode(self, image, train=False, return_enc_features=False): + """Encodes input image or embeddings.""" + + _, out = self.encoder(image, train=train) + encoded = out["encoded"] + + # Return intermediate features if required + if return_enc_features: + return encoded, out + + return encoded + + def decode(self, encoded, targets, decode=False, train=False, + max_decode_length=None): + """Applies Transformer decoder-branch on encoded-input and target. + + Args: + encoded: encoded image patches from encoder [B, P, E]. + targets: target text tokens [B, L]. 
def load(init_params, init_files, model_params=None,
         dont_load=("head/kernel", "head/bias", "cls")):
  """Loads params from init checkpoint and merges into init_params.

  Args:
    init_params: Freshly initialised parameters of the model; they define the
      target structure and the target position-embedding shape.
    init_files: Either a string naming a single checkpoint with the full
      model, or a dict with the single key `"encoder"` whose value is passed
      to `vit.load` to initialise only the encoder.
    model_params: Optional mapping with the model config. Only the `"scan"`
      entry is read here. `None` is treated as an empty config.
    dont_load: Parameter-name patterns forwarded to `common.merge_params` /
      `vit.load`.

  Returns:
    The loaded parameter tree, with the encoder position embedding resampled
    to the shape found in `init_params`.

  Raises:
    NotImplementedError: If a non-scan checkpoint is loaded into a scan model.
  """
  # The signature advertises `None` as the default, but the code below calls
  # `.get` on it; normalise so that the default does not raise AttributeError.
  if model_params is None:
    model_params = {}

  if isinstance(init_files, str):
    # A shortcut for a single file checkpoint of a vtt model.
    ckpt_params = utils.load_params(init_files)
    ckpt_params = flax.training.checkpoints.convert_pre_linen(ckpt_params)
    ckpt_params = common.merge_params(ckpt_params, init_params, dont_load)

    use_scan = model_params.get("scan")
    ckpt_is_scan = "encoderblock" in ckpt_params["encoder"]["Transformer"]

    # Detect attempts to load non-scan checkpoint into scan model if possible.
    if use_scan and not ckpt_is_scan:
      raise NotImplementedError("Loading a non-scan checkpoint into a "
                                "scan model is not supported yet!")
    if not use_scan and ckpt_is_scan:
      # Only the encoder is converted below, so the decoder must be skipped.
      assert "decoder.*" in dont_load or "decoder/.*" in dont_load, (
          "Converting scan decoder to a non-scan one is not supported yet!")
      ckpt_params["encoder"] = utils.jit_cpu()(
          vit.scan_to_pyloop)(ckpt_params["encoder"])

  else:
    assert set(init_files) == {"encoder"}, "Only encoder init supported"
    enc_init = init_files["encoder"]
    # Freeze + unfreeze yields a copy, so `init_params` itself is not modified.
    ckpt_params = flax.core.freeze(init_params).unfreeze()
    vit_params = ckpt_params["encoder"]
    encoder_params = vit.load(
        vit_params, enc_init, model_cfg={},
        dont_load=dont_load)
    ckpt_params["encoder"] = encoder_params

  ckpt_params["encoder"]["pos_embedding"] = vit.resample_posemb(
      old=ckpt_params["encoder"]["pos_embedding"],
      new=init_params["encoder"]["pos_embedding"])

  return ckpt_params
def load(init_params, init_files, model_cfg, img_load_kw=None):
  """Loads the ViT parameters - adapted from proj/image_text/two_towers.py.

  Args:
    init_params: Freshly initialised parameters; entries that are not loaded
      are returned unchanged.
    init_files: Either a string naming a single checkpoint, which is expanded
      to `"<path>:img"` and `"<path>:t"`, or a dict with the optional keys
      `"image"`/`"img"` and `"temperature"`/`"t"`.
    model_cfg: Model config; `image_model` selects the module whose `load` is
      called and `image` is forwarded to it.
    img_load_kw: Optional extra keyword arguments for the image model's
      `load`. `None` means no extra arguments.

  Returns:
    A shallow copy of `init_params` with `"img"` and/or `"t"` replaced by the
    loaded values.
  """
  img_load_kw = {} if img_load_kw is None else img_load_kw

  if isinstance(init_files, str):
    # A shortcut for a single file checkpoint of a two_towers model.
    init_files = {k: f"{init_files}:{k}" for k in ("img", "t")}
  else:
    init_files = {**init_files}  # Shallow copy because we'll pop stuff off.

  restored_params = {**init_params}

  img_init = init_files.pop("image", init_files.pop("img", None))
  if img_init:
    restored_params["img"] = importlib.import_module(
        f"big_vision.models.{model_cfg.image_model}"
    ).load(init_params["img"], img_init, model_cfg.image, **img_load_kw)

  t_init = init_files.pop("temperature", init_files.pop("t", None))
  if t_init:
    # NOTE(review): this used to be `utils.load_params(None, t_init)`. The
    # other loaders in this package call `load_params` with the checkpoint path
    # as the only positional argument, so this call is aligned with them;
    # confirm against `big_vision.utils.load_params`.
    restored_params["t"] = utils.load_params(t_init)

  assert not init_files, (
      f"There's something unused left in `config.model_init`. You probably got "
      f"a typo. Here it is: {init_files}")

  return restored_params
def load(params, path, model_cfg=None, dont_load=()):
  """Returns `params` with BERT weights replaced from checkpoint at `path`.

  If `<path>/bert_model.ckpt.index` exists, the original TF BERT checkpoint is
  converted and merged under the `"BertEncoder_0"` key, with its position
  embedding cut to the length used by `params`. Otherwise `path` is loaded as
  a big_vision checkpoint.

  Args:
    params: Initialised parameters to merge the loaded weights into.
    path: Directory of an original BERT checkpoint, or a big_vision checkpoint.
    model_cfg: Unused.
    dont_load: Forwarded to `common.merge_params`.

  Returns:
    The merged parameter tree.
  """
  del model_cfg

  checkpoint_path = f"{path}/bert_model.ckpt"
  found_tf_checkpoint = gfile.exists(f"{checkpoint_path}.index")

  if not found_tf_checkpoint:
    logging.info(
        "Could not find original BERT checkpoint path '%s', "
        "loading big_vision checkpoint '%s'", checkpoint_path, path)
    return common.merge_params(utils.load_params(path), params, dont_load)

  logging.info("Loading original BERT checkpoint from '%s'", checkpoint_path)
  params = flax.core.FrozenDict(params).unfreeze()  # Recursive copy.
  target_posemb = (
      params["BertEncoder_0"]["embedder"]["embedders_position_ids"])
  max_len = target_posemb["embedding"].shape[0]

  # The converter also returns pooler weights, which are not used here.
  bert_params, _ = bert_checkpoint_converter.load_params_from_tf_checkpoint(
      checkpoint_path=f"{path}/bert_model.ckpt")
  if isinstance(bert_params, flax.core.FrozenDict):
    bert_params = bert_params.unfreeze()

  # Keep only as many position embeddings as the initialised model has.
  loaded_posemb = bert_params["embedder"]["embedders_position_ids"]
  loaded_posemb["embedding"] = loaded_posemb["embedding"][:max_len]

  return common.merge_params(
      {"BertEncoder_0": bert_params}, params, dont_load)
# BERT vocabulary for testing.
_BERT_VOCAB = [
    "[PAD]",
    "[UNK]",
    "this",
    "is",
    "a",
    "test",
    "[CLS]",
    "[SEP]",
]
_TOKEN_LEN = 16


class BertTest(tf.test.TestCase):
  """Smoke test: tokenize, init, load a fake checkpoint, apply the model."""

  def test_load_apply(self):
    inkey = "text"
    vocab_path = f"{tempfile.mkdtemp()}/vocab.txt"
    # Explicit encoding so the vocab file does not depend on the locale.
    with open(vocab_path, "w", encoding="utf-8") as f:
      f.write("\n".join(_BERT_VOCAB))
    ds2, _ = input_pipeline.make_for_inference(
        tf.data.Dataset.from_tensor_slices(
            {inkey: tf.ragged.constant([["this is a test"]])}),
        num_ex_per_process=[1],
        preprocess_fn=pp_builder.get_preprocess_fn(
            f"bert_tokenize(inkey='{inkey}', vocab_path='{vocab_path}', "
            f"max_len={_TOKEN_LEN})"
            "|keep('labels')"),
        batch_size=1,
    )
    text = jnp.array(next(iter(ds2))["labels"])
    model = bert.Model(config="base")
    variables = model.init(jax.random.PRNGKey(0), text)
    params = bert.load(flax.core.unfreeze(variables)["params"],
                       bert_test_util.create_base_checkpoint())
    x, out = model.apply({"params": params}, text)
    # `jax.tree_util.tree_map` instead of the deprecated top-level alias
    # `jax.tree_map`, which newer JAX releases no longer provide.
    self.assertAllEqual(jax.tree_util.tree_map(jnp.shape, x), (1, 768))
    self.assertAllEqual(
        jax.tree_util.tree_map(jnp.shape, out), {
            "transformed": (1, _TOKEN_LEN, 768),
            "pre_logits": (1, 768),
        })


if __name__ == "__main__":
  tf.test.main()
+# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for fake BERT checkpoint.""" + +import tempfile + +import tensorflow.compat.v1 as tf + +# Checkpoint structure was extracted with the following (Colab) snippet: +# +# !wget https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip # pylint: disable=line-too-long +# !unzip uncased_L-12_H-768_A-12.zip +# +# import tensorflow.compat.v1 as tf +# +# ckpt_reader = tf.train.load_checkpoint('bert_model.ckpt') +# tf_params = { +# tf_name: ckpt_reader.get_tensor(tf_name) +# for tf_name in ckpt_reader.get_variable_to_dtype_map() +# } +# +# 'shapes_dtypes = {\n%s\n}' % '\n'.join( +# f' "{k}": ({v.shape}, "{v.dtype}"),' +# for k, v, in tf_params.items() +# ) + +# pylint: disable=line-too-long +_BASE_SHAPES_DTYPES = { + "cls/seq_relationship/output_bias": ((2,), "float32"), + "cls/predictions/transform/LayerNorm/gamma": ((768,), "float32"), + "cls/predictions/transform/LayerNorm/beta": ((768,), "float32"), + "bert/pooler/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_5/attention/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_9/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_9/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_3/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_7/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_9/output/LayerNorm/beta": ((768,), 
"float32"), + "bert/encoder/layer_7/attention/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_9/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_9/attention/self/query/kernel": ((768, 768), "float32"), + "bert/encoder/layer_9/attention/self/query/bias": ((768,), "float32"), + "bert/encoder/layer_9/attention/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_9/attention/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_8/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_4/attention/self/value/kernel": ((768, 768), "float32"), + "bert/encoder/layer_8/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_8/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_11/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_11/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_8/intermediate/dense/kernel": ((768, 3072), "float32"), + "bert/encoder/layer_8/attention/self/value/kernel": ((768, 768), "float32"), + "bert/encoder/layer_2/output/dense/kernel": ((3072, 768), "float32"), + "bert/encoder/layer_8/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_1/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_8/attention/output/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_3/attention/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_8/attention/self/query/bias": ((768,), "float32"), + "bert/encoder/layer_8/attention/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_8/attention/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_7/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_7/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_8/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_7/attention/self/value/kernel": ((768, 768), "float32"), + 
"bert/encoder/layer_7/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_9/attention/self/key/kernel": ((768, 768), "float32"), + "bert/encoder/layer_7/attention/self/query/kernel": ((768, 768), "float32"), + "bert/encoder/layer_7/attention/self/key/kernel": ((768, 768), "float32"), + "bert/encoder/layer_6/output/dense/kernel": ((3072, 768), "float32"), + "bert/encoder/layer_6/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_6/intermediate/dense/kernel": ((768, 3072), "float32"), + "bert/encoder/layer_0/attention/self/query/bias": ((768,), "float32"), + "bert/encoder/layer_6/attention/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_7/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_4/attention/self/key/kernel": ((768, 768), "float32"), + "bert/encoder/layer_5/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_2/attention/self/value/kernel": ((768, 768), "float32"), + "bert/encoder/layer_5/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_5/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_5/intermediate/dense/kernel": ((768, 3072), "float32"), + "bert/encoder/layer_9/attention/self/value/kernel": ((768, 768), "float32"), + "bert/encoder/layer_3/attention/self/query/kernel": ((768, 768), "float32"), + "bert/encoder/layer_8/attention/self/key/kernel": ((768, 768), "float32"), + "bert/encoder/layer_5/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_5/attention/self/query/kernel": ((768, 768), "float32"), + "bert/encoder/layer_5/attention/self/key/bias": ((768,), "float32"), + "bert/encoder/layer_5/attention/output/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_5/attention/self/value/kernel": ((768, 768), "float32"), + "bert/encoder/layer_5/attention/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_5/attention/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_4/output/dense/bias": 
((768,), "float32"), + "bert/embeddings/token_type_embeddings": ((2, 768), "float32"), + "bert/encoder/layer_4/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_4/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_7/output/dense/kernel": ((3072, 768), "float32"), + "bert/encoder/layer_4/intermediate/dense/kernel": ((768, 3072), "float32"), + "bert/encoder/layer_9/attention/self/key/bias": ((768,), "float32"), + "bert/encoder/layer_10/attention/self/key/bias": ((768,), "float32"), + "bert/encoder/layer_6/attention/self/value/kernel": ((768, 768), "float32"), + "bert/encoder/layer_4/attention/self/query/bias": ((768,), "float32"), + "cls/seq_relationship/output_weights": ((2, 768), "float32"), + "bert/encoder/layer_7/intermediate/dense/kernel": ((768, 3072), "float32"), + "bert/encoder/layer_4/attention/output/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_4/attention/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_4/attention/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_3/output/dense/kernel": ((3072, 768), "float32"), + "bert/encoder/layer_1/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_2/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_8/attention/self/query/kernel": ((768, 768), "float32"), + "bert/encoder/layer_4/output/dense/kernel": ((3072, 768), "float32"), + "bert/encoder/layer_3/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_4/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_3/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_1/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_3/attention/self/value/kernel": ((768, 768), "float32"), + "bert/encoder/layer_10/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_3/attention/self/query/bias": ((768,), "float32"), + "bert/encoder/layer_1/attention/self/key/kernel": ((768, 768), "float32"), 
+ "bert/encoder/layer_0/attention/self/query/kernel": ((768, 768), "float32"), + "bert/encoder/layer_10/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_3/attention/self/key/kernel": ((768, 768), "float32"), + "bert/encoder/layer_3/attention/output/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_1/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_3/attention/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_1/attention/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_3/attention/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_2/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_6/attention/self/key/bias": ((768,), "float32"), + "bert/encoder/layer_11/attention/output/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_2/intermediate/dense/kernel": ((768, 3072), "float32"), + "bert/encoder/layer_2/attention/self/key/kernel": ((768, 768), "float32"), + "bert/encoder/layer_2/attention/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_2/attention/self/key/bias": ((768,), "float32"), + "bert/encoder/layer_6/attention/self/query/bias": ((768,), "float32"), + "bert/encoder/layer_11/intermediate/dense/kernel": ((768, 3072), "float32"), + "bert/encoder/layer_6/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_11/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_11/attention/self/query/bias": ((768,), "float32"), + "bert/encoder/layer_11/attention/self/value/kernel": ((768, 768), "float32"), + "bert/encoder/layer_10/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_11/attention/self/key/kernel": ((768, 768), "float32"), + "bert/encoder/layer_6/attention/self/key/kernel": ((768, 768), "float32"), + "bert/encoder/layer_6/attention/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_11/attention/self/key/bias": ((768,), "float32"), + 
"bert/encoder/layer_10/attention/self/value/kernel": ((768, 768), "float32"), + "bert/encoder/layer_4/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_11/attention/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_10/attention/self/query/bias": ((768,), "float32"), + "bert/embeddings/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_2/attention/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_11/output/dense/kernel": ((3072, 768), "float32"), + "bert/encoder/layer_11/attention/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_5/output/dense/kernel": ((3072, 768), "float32"), + "bert/encoder/layer_3/attention/self/key/bias": ((768,), "float32"), + "bert/encoder/layer_10/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_10/intermediate/dense/kernel": ((768, 3072), "float32"), + "bert/embeddings/word_embeddings": ((30522, 768), "float32"), + "bert/encoder/layer_9/output/dense/kernel": ((3072, 768), "float32"), + "bert/encoder/layer_9/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_6/attention/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_10/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_6/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_1/attention/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_5/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_2/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_0/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_3/intermediate/dense/kernel": ((768, 3072), "float32"), + "cls/predictions/output_bias": ((30522,), "float32"), + "bert/encoder/layer_0/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_6/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_0/attention/output/dense/kernel": ((768, 768), "float32"), + 
"bert/encoder/layer_2/attention/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_10/attention/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_5/attention/self/query/bias": ((768,), "float32"), + "bert/encoder/layer_4/attention/self/query/kernel": ((768, 768), "float32"), + "bert/encoder/layer_0/attention/self/key/kernel": ((768, 768), "float32"), + "bert/encoder/layer_0/output/dense/kernel": ((3072, 768), "float32"), + "bert/encoder/layer_10/attention/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_7/attention/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_3/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_2/attention/self/query/bias": ((768,), "float32"), + "bert/encoder/layer_8/output/dense/kernel": ((3072, 768), "float32"), + "bert/embeddings/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_1/attention/self/value/kernel": ((768, 768), "float32"), + "bert/encoder/layer_10/attention/output/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_2/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_6/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_2/attention/output/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_11/attention/self/value/bias": ((768,), "float32"), + "bert/encoder/layer_9/attention/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_0/attention/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_10/attention/output/dense/bias": ((768,), "float32"), + "bert/encoder/layer_10/output/dense/kernel": ((3072, 768), "float32"), + "bert/encoder/layer_1/attention/output/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_8/attention/self/key/bias": ((768,), "float32"), + "bert/encoder/layer_0/intermediate/dense/bias": ((3072,), "float32"), + "bert/encoder/layer_1/intermediate/dense/kernel": ((768, 3072), "float32"), + 
"bert/encoder/layer_1/attention/self/key/bias": ((768,), "float32"), + "bert/encoder/layer_7/attention/output/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_2/attention/self/query/kernel": ((768, 768), "float32"), + "bert/encoder/layer_8/attention/output/dense/bias": ((768,), "float32"), + "cls/predictions/transform/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_6/attention/self/query/kernel": ((768, 768), "float32"), + "bert/encoder/layer_5/attention/self/key/kernel": ((768, 768), "float32"), + "bert/encoder/layer_0/attention/self/value/kernel": ((768, 768), "float32"), + "bert/encoder/layer_7/attention/self/query/bias": ((768,), "float32"), + "bert/encoder/layer_7/attention/self/key/bias": ((768,), "float32"), + "bert/encoder/layer_1/output/dense/kernel": ((3072, 768), "float32"), + "bert/encoder/layer_11/attention/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_4/attention/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_1/attention/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_9/attention/output/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_2/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_0/output/LayerNorm/gamma": ((768,), "float32"), + "bert/encoder/layer_10/attention/self/query/kernel": ((768, 768), "float32"), + "bert/encoder/layer_1/attention/self/query/bias": ((768,), "float32"), + "bert/encoder/layer_3/output/LayerNorm/beta": ((768,), "float32"), + "bert/encoder/layer_6/attention/output/dense/kernel": ((768, 768), "float32"), + "bert/encoder/layer_1/attention/self/query/kernel": ((768, 768), "float32"), + "bert/encoder/layer_11/output/dense/bias": ((768,), "float32"), + "cls/predictions/transform/dense/bias": ((768,), "float32"), + "bert/encoder/layer_0/intermediate/dense/kernel": ((768, 3072), "float32"), + "bert/encoder/layer_11/attention/self/query/kernel": ((768, 768), "float32"), + 
def create_base_checkpoint():
  """Returns path to fake Bert "base" checkpoint directory (zero init).

  One zero-filled variable is created per entry of `_BASE_SHAPES_DTYPES` and
  saved as `bert_model.ckpt` inside a fresh temporary directory.

  Returns:
    The temporary directory containing the checkpoint files.
  """
  directory = tempfile.mkdtemp()
  path = f"{directory}/bert_model.ckpt"
  # Build everything in a private graph: the variables then never land in the
  # process-wide default graph (so repeated calls do not accumulate them), and
  # the Saver/Session code runs in graph mode regardless of the caller's mode.
  with tf.Graph().as_default(), tf.Session() as sess:
    for name, (shape, dtype) in _BASE_SHAPES_DTYPES.items():
      tf.Variable(tf.zeros(shape, dtype), name=name)
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, path)
  return directory
def resample_patchemb(old, new_hw):
  """Resample the weights of the patch embedding kernel to target resolution.

  The kernel is resampled by approximately inverting the effect of bilinearly
  resizing a patch: with `B` the matrix that resizes a flattened patch from
  the old to the new size, each spatial filter `w` is replaced by
  `pinv(B^T) @ w`, so that the response to a resized patch approximates the
  original response to the original patch. With this resizing, one can for
  example load a B/8 filter into a B/16 model and, on a 2x larger input image,
  the result will match.

  Args:
    old: original kernel with four dimensions; the first two are the spatial
      ones that get resampled, the last two are mapped over independently.
    new_hw: target shape, (height, width) only.

  Returns:
    `old` unchanged if its spatial shape already equals `new_hw`, otherwise
    the resampled kernel with spatial shape `new_hw`.

  Raises:
    AssertionError: if `old` is not 4-dimensional or `new_hw` is not of
      length two.
  """
  assert len(old.shape) == 4, "Four dimensions expected"
  assert len(new_hw) == 2, "New shape should only be hw"
  if tuple(old.shape[:2]) == tuple(new_hw):
    return old

  logging.info("FlexiViT: resize embedding %s to %s", old.shape, new_hw)

  def resize(x_np, new_shape):
    """Bilinearly resizes one 2D array via TensorFlow."""
    x_tf = tf.constant(x_np)[None, ..., None]
    # NOTE: we are using tf.image.resize here to match the resize operations in
    # the data preprocessing pipeline.
    x_upsampled = tf.image.resize(
        x_tf, new_shape, method="bilinear")[0, ..., 0].numpy()
    return x_upsampled

  def get_resize_mat(old_shape, new_shape):
    """Builds the resize operator by resizing every one-hot basis patch."""
    mat = []
    for i in range(np.prod(old_shape)):
      basis_vec = np.zeros(old_shape)
      basis_vec[np.unravel_index(i, old_shape)] = 1.
      mat.append(resize(basis_vec, new_shape).reshape(-1))
    # Rows index new pixels, columns index old pixels.
    return np.stack(mat).T

  resize_mat = get_resize_mat(old.shape[:2], new_hw)
  resize_mat_pinv = np.linalg.pinv(resize_mat.T)

  def resample_kernel(kernel):
    """Resamples a single 2D spatial filter."""
    resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
    return resampled_kernel.reshape(new_hw)

  # Map over the two trailing (channel) axes, keeping them in place.
  v_resample_kernel = jax.vmap(jax.vmap(resample_kernel, 2, 2), 3, 3)
  return v_resample_kernel(old)


class Patchify(nn.Module):
  """As a class just to match param names with original ViT."""

  # Native (height, width) of the stored patch-embedding kernel.
  patch_size: Sequence[int] = (32, 32)
  # Number of output channels of the embedding.
  width: int = 768
  # Default number of patches per side; can be overridden per call.
  seqhw: Optional[int] = None

  @nn.compact
  def __call__(self, image, seqhw=None):
    """Embeds `image` (NHWC) into a grid of patch tokens.

    Args:
      image: input batch, unpacked as (n, h, w, c).
      seqhw: requested number of patches per side; falls back to `self.seqhw`.

    Returns:
      The strided convolution of `image` with the (possibly resampled) kernel,
      plus bias.

    Raises:
      ValueError: if no `seqhw` is available outside of initialization.
    """
    _, h, w, c = image.shape

    w_emb = self.param(
        "kernel", nn.initializers.normal(stddev=1/np.sqrt(self.width)),
        (*self.patch_size, c, self.width), image.dtype)
    b_emb = self.param("bias", nn.initializers.zeros, self.width, image.dtype)

    # Compute required patch-size to reach `seqhw` given `image` size.
    seqhw = seqhw or self.seqhw
    if seqhw is None:
      if not self.is_initializing():
        # Previously this fell through to an integer division by None and
        # failed with an opaque numpy TypeError.
        raise ValueError(
            "Patchify needs `seqhw`, either as module attribute or as call "
            "argument, when applied outside of initialization.")
      patch_size = self.patch_size
    else:
      patch_size = tuple(np.array((h, w)) // np.array((seqhw, seqhw)))

    if tuple(patch_size) != tuple(self.patch_size):
      w_emb = resample_patchemb(old=w_emb, new_hw=patch_size)

    x = jax.lax.conv_general_dilated(
        image, w_emb, window_strides=patch_size, padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))
    return x + b_emb


class _Model(nn.Module):
  """ViT model."""

  num_classes: int
  patch_size: Sequence[int] = (32, 32)
  posemb_size: Sequence[int] = (7, 7)
  width: int = 768
  depth: int = 12
  mlp_dim: Optional[int] = None  # Defaults to 4x input dim
  num_heads: int = 12
  posemb: str = "learn"  # Can also be "sincos2d"
  pool_type: str = "gap"  # Can also be "map" or "tok"
  head_zeroinit: bool = True

  seqhw: Optional[int] = None

  @nn.compact
  def __call__(self, image, *, seqhw=None, train=False):
    """Runs the model and returns `(x, out)`.

    Args:
      image: input batch passed to the patch embedding.
      seqhw: optional number of patches per side, forwarded to `Patchify`.
      train: accepted for interface compatibility; not used in this method.

    Returns:
      A tuple `(x, out)` where `out` maps names to intermediate activations
      and `x` is the logits if `num_classes` is truthy, else the pooled
      features.

    Raises:
      ValueError: on an unknown `pool_type`.
    """
    out = {}

    x = out["stem"] = Patchify(
        self.patch_size, self.width, self.seqhw, name="embedding")(image, seqhw)

    # == Flattening + posemb
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    pos_emb = vit.get_posemb(
        self, self.posemb, self.posemb_size, c, "pos_embedding", x.dtype)
    if pos_emb.shape[1] != h * w:
      # The token grid differs from the posemb grid: resize the posemb in 2D.
      pos_emb = jnp.reshape(pos_emb, (1, *self.posemb_size, c))
      pos_emb = jax.image.resize(pos_emb, (1, h, w, c), "linear")
      pos_emb = jnp.reshape(pos_emb, (1, h * w, c))

    x = out["with_posemb"] = x + pos_emb

    # == Optional [cls] token
    if self.pool_type == "tok":
      cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype)
      x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)

    # == Encoder
    n, _, c = x.shape

    x, out["encoder"] = vit.Encoder(
        depth=self.depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        name="Transformer")(x)
    encoded = out["encoded"] = x

    if self.pool_type == "map":
      x = out["head_input"] = vit.MAPHead(
          num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
    elif self.pool_type == "gap":
      x = out["head_input"] = jnp.mean(x, axis=1)
    elif self.pool_type == "tok":
      x = out["head_input"] = x[:, 0]
      # Drop the [cls] token so the remainder reshapes to the patch grid.
      encoded = encoded[:, 1:]
    else:
      raise ValueError(f"Unknown pool type: '{self.pool_type}'")

    x_2d = jnp.reshape(encoded, [n, h, w, -1])

    out["pre_logits_2d"] = x_2d
    out["pre_logits"] = x

    if self.num_classes:
      kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
      # The same head is applied to the per-patch and the pooled features.
      head = nn.Dense(self.num_classes, name="head", **kw)
      x_2d = out["logits_2d"] = head(x_2d)
      x = out["logits"] = head(x)

    return x, out


def Model(num_classes, *, variant=None, **kw):  # pylint: disable=invalid-name
  """Factory function, because linen really don't like what I'm doing!"""
  return _Model(num_classes, **{**vit.decode_variant(variant), **kw})


def load(init_params, init_file, model_cfg, dont_load=()):  # pylint: disable=invalid-name because we had to CamelCase above.
  """Load init from checkpoint, both old model and this one. +Hi-res posemb.

  Args:
    init_params: freshly initialized parameters of the target model.
    init_file: checkpoint path, or a key of the vanity-name tables.
    model_cfg: model config; its `patch_size` is the target kernel size.
    dont_load: forwarded to `common.merge_params`.

  Returns:
    The restored parameters, resampled where needed and merged with
    `init_params`.
  """
  init_file = {**vit.VANITY_NAMES, **VANITY_NAMES}.get(init_file, init_file)
  restored_params = utils.load_params(init_file)

  restored_params = vit.fix_old_checkpoints(restored_params)

  # Potentially resize the position embeddings if seqlen differs. Only done
  # when the target model actually has such a parameter; presumably a
  # non-learned posemb (e.g. "sincos2d") has none -- TODO confirm.
  if init_params and "pos_embedding" in init_params:
    restored_params["pos_embedding"] = vit.resample_posemb(
        old=restored_params["pos_embedding"],
        new=init_params["pos_embedding"])

  # Potentially resize the patch embedding kernel.
  old_patchemb = restored_params["embedding"]["kernel"]
  restored_params["embedding"]["kernel"] = resample_patchemb(
      old=old_patchemb, new_hw=model_cfg.patch_size)

  # possibly use the random init for some of the params (such as, the head).
  restored_params = common.merge_params(restored_params, init_params, dont_load)

  return restored_params


# Shortcut names for some canonical paper checkpoints:
VANITY_NAMES = {
    # pylint: disable=line-too-long
    "FlexiViT-L i1k": "gs://big_vision/flexivit/flexivit_l_i1k.npz",
    "FlexiViT-B i1k": "gs://big_vision/flexivit/flexivit_b_i1k.npz",
    "FlexiViT-S i1k": "gs://big_vision/flexivit/flexivit_s_i1k.npz",
    "FlexiViT-B i21k 90ep": "gs://big_vision/flexivit/flexivit_b_i21k_90ep.npz",
    "FlexiViT-B i21k 300ep": "gs://big_vision/flexivit/flexivit_b_i21k_300ep.npz",
    "FlexiViT-B i21k 1000ep": "gs://big_vision/flexivit/flexivit_b_i21k_1000ep.npz",
    "ViT-B/16 i21k": "gs://big_vision/flexivit/vit_b16_i21k_300ep.npz",
    "ViT-B/30 i21k": "gs://big_vision/flexivit/vit_b30_i21k_300ep.npz",
    # pylint: enable=line-too-long
}
config.update("jax_enable_x64", True)


class PatchEmbTest(absltest.TestCase):
  """Tests for `vit.resample_patchemb`."""

  def setUp(self):
    super().setUp()
    # Seeded generator so that a failing case can be reproduced exactly.
    self._rng = np.random.RandomState(0)

  def _test_patch_emb_resize(self, old_shape, new_shape, n_patches=100):
    # This test verifies that if we resize the input image patch and resample
    # the patch embedding accordingly, the output does not change.
    # NOTE: if the image contains more than one patch, then the embeddings will
    # change due to patch interaction during the resizing.
    patch_shape = old_shape[:-2]
    resized_patch_shape = new_shape[:-2]
    patches = self._rng.randn(n_patches, *old_shape[:-1])
    w_emb = jnp.asarray(self._rng.randn(*old_shape))

    old_embeddings = jax.lax.conv_general_dilated(
        patches, w_emb, window_strides=patch_shape, padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"), precision="highest")

    patch_resized = tf.image.resize(
        tf.constant(patches), resized_patch_shape, method="bilinear").numpy()
    patch_resized = jnp.asarray(patch_resized).astype(jnp.float64)
    w_emb_resampled = vit.resample_patchemb(w_emb, resized_patch_shape)
    self.assertEqual(w_emb_resampled.shape, new_shape)

    new_embeddings = jax.lax.conv_general_dilated(
        patch_resized, w_emb_resampled, window_strides=resized_patch_shape,
        padding="VALID", dimension_numbers=("NHWC", "HWIO", "NHWC"),
        precision="highest")

    self.assertEqual(old_embeddings.shape, new_embeddings.shape)
    np.testing.assert_allclose(
        old_embeddings, new_embeddings, rtol=1e-1, atol=1e-4)

  def test_resize_square(self):
    out_channels = 256
    patch_sizes = [48, 40, 30, 24, 20, 16, 15, 12, 10, 8, 6, 5]
    for s in patch_sizes:
      for t in patch_sizes:
        if s > t:
          continue  # Only same-size and upsampling pairs are checked here.
        old_shape = (s, s, 3, out_channels)
        new_shape = (t, t, 3, out_channels)
        with self.subTest(old_shape=old_shape, new_shape=new_shape):
          self._test_patch_emb_resize(old_shape, new_shape)

  def test_resize_rectangular(self):
    out_channels = 256
    cases = [
        ((8, 10, 3, out_channels), (10, 12, 3, out_channels)),
        ((8, 6, 3, out_channels), (9, 15, 3, out_channels)),
        ((8, 6, 3, out_channels), (15, 9, 3, out_channels)),
    ]
    for old_shape, new_shape in cases:
      with self.subTest(old_shape=old_shape, new_shape=new_shape):
        self._test_patch_emb_resize(old_shape, new_shape)

  def test_input_channels(self):
    out_channels = 256
    for c in [1, 3, 10]:
      old_shape = (8, 10, c, out_channels)
      new_shape = (10, 12, c, out_channels)
      with self.subTest(in_channels=c):
        self._test_patch_emb_resize(old_shape, new_shape)

  def _test_works(self, old_shape, new_shape):
    old = jnp.asarray(self._rng.randn(*old_shape))
    resampled = vit.resample_patchemb(old, new_shape[:2])
    self.assertEqual(resampled.shape, new_shape)
    self.assertEqual(resampled.dtype, old.dtype)

  def test_downsampling(self):
    # NOTE: for downsampling we cannot guarantee that the outputs would match
    # before and after downsampling. So, we simply test that the code runs and
    # produces an output of the correct shape and type.
    out_channels = 256
    for t in [4, 5, 6, 7]:
      for c in [1, 3, 5]:
        old_shape = (8, 8, c, out_channels)
        new_shape = (t, t, c, out_channels)
        with self.subTest(old_shape=old_shape, new_shape=new_shape):
          self._test_works(old_shape, new_shape)

  def _test_raises(self, old_shape, new_shape):
    old = jnp.asarray(self._rng.randn(*old_shape))
    with self.assertRaises(AssertionError):
      vit.resample_patchemb(old, new_shape)

  def test_raises_incorrect_dims(self):
    # A full 4-element shape is passed as `new_hw` on purpose: only a
    # (height, width) pair is accepted.
    cases = [
        ((8, 10, 3, 256), (10, 12, 1, 256)),
        ((8, 10, 1, 256), (10, 12, 3, 256)),
        ((8, 10, 3, 128), (10, 12, 3, 256)),
    ]
    for old_shape, new_shape in cases:
      with self.subTest(old_shape=old_shape, new_shape=new_shape):
        self._test_raises(old_shape, new_shape)


if __name__ == "__main__":
  absltest.main()
+ +Based on the PyTorch version from: +https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/iRevNet.py +""" + +from typing import Any, Optional, Sequence + +from big_vision import utils +from big_vision.models import common +from big_vision.models.proj.givt import cnn +import einops +import flax.core +import flax.linen as nn +import jax +import jax.numpy as jnp + + +def _split(x: jax.Array) -> tuple[jax.Array, jax.Array]: + n = x.shape[-1] // 2 + x1 = x[:, :, :, :n] + x2 = x[:, :, :, n:] + return x1, x2 + + +def _merge(x1: jax.Array, x2: jax.Array) -> jax.Array: + return jnp.concatenate((x1, x2), axis=-1) + + +class IRevNetBlock(nn.Module): + """iRevNet Block.""" + first: int = False + dropout_rate: float = 0. + num_channels: int = 2 + num_channels_bottleneck: Optional[int] = None + num_grps_norm: int = 32 + + @nn.compact + def _fx2(self, x: jax.Array, train: bool = True) -> jax.Array: + if not self.first: + y = nn.GroupNorm(num_groups=self.num_grps_norm, name="gn_0")(x) + y = nn.relu(y) + else: + y = x + + ks = (3, 3) # hardcode kernel-size 3 for now + y = nn.Conv(self.num_channels_bottleneck or self.num_channels, + kernel_size=ks, padding=1, use_bias=False)(y) + y = nn.GroupNorm(num_groups=self.num_grps_norm, name="gn_1")(y) + y = nn.relu(y) + + y = nn.Conv(self.num_channels_bottleneck or self.num_channels, + kernel_size=ks, padding=1, use_bias=False)(y) + y = nn.Dropout(rate=self.dropout_rate, deterministic=(not train))(y) + y = nn.GroupNorm(num_groups=self.num_grps_norm, name="gn_2")(y) + y = nn.relu(y) + + y = nn.Conv(self.num_channels, kernel_size=ks, padding=1, use_bias=False)(y) + + return y + + def forward( + self, + x: tuple[jax.Array, jax.Array], + train: bool = True, + ) -> tuple[jax.Array, jax.Array]: + """Bijective block forward.""" + x1, x2 = x[0], x[1] + fx2 = self._fx2(x2, train=train) + y1 = fx2 + x1 + return (x2, y1) + + def inverse(self, + x: tuple[jax.Array, jax.Array], + train: bool = True + ) -> tuple[jax.Array, jax.Array]: + 
class IRevNet(nn.Module):
  """Invertible network built from a chain of `IRevNetBlock`s.

  The input is split into two channel halves, each block maps the pair to a
  new pair, and the halves are concatenated again at the end. `inverse`
  applies the block inverses in reverse order.
  """
  num_blocks: int = 4
  num_channels: int = 4
  num_channels_bottleneck: Optional[int] = None
  dropout_rate: float = 0.0

  def setup(self) -> None:
    # Every block only sees one half of the channels.
    half_channels = self.num_channels // 2
    half_bottleneck = (
        self.num_channels_bottleneck or self.num_channels) // 2
    groups = min(32, half_channels)

    blocks = []
    for idx in range(self.num_blocks):
      blocks.append(
          IRevNetBlock(
              first=(idx == 0),
              num_channels=half_channels,
              num_channels_bottleneck=half_bottleneck,
              num_grps_norm=groups,
              dropout_rate=self.dropout_rate,
          ))
    # Attribute name is kept as `modules` so parameter names stay unchanged.
    self.modules = blocks

  def forward(self, x: jax.Array, train: bool = True) -> jax.Array:
    """Maps `x` through all blocks in order."""
    halves = _split(x)
    for block in self.modules:
      halves = block.forward(halves, train=train)
    return _merge(halves[0], halves[1])

  def inverse(self, out_bij: jax.Array, train: bool = True) -> jax.Array:
    """Undoes `forward` by running the block inverses last-to-first."""
    halves = _split(out_bij)
    for block in reversed(self.modules):
      halves = block.inverse(halves, train=train)
    return _merge(halves[0], halves[1])

  def __call__(self, x: jax.Array, train: bool = True) -> jax.Array:
    """Alias for `forward`."""
    return self.forward(x, train=train)
def load(
    init_params: Any,
    init_file: str,
    model_params: Any = None,
    dont_load: Sequence[str] = (),
) -> Any:
  """Loads params from init checkpoint and merges into init_params.

  Args:
    init_params: Previously initialized params, or None to return the
      checkpoint params as-is.
    init_file: Path of the checkpoint to load.
    model_params: Unused; accepted for interface compatibility.
    dont_load: Forwarded to `common.merge_params`.

  Returns:
    The (optionally merged) parameter tree.
  """
  del model_params
  loaded = utils.load_params(init_file)
  loaded = flax.core.unfreeze(loaded)
  if init_params is None:
    return loaded
  return common.merge_params(loaded, init_params, dont_load)
class AdaptorTest(absltest.TestCase):
  """Tests for the invertible IRevNet adaptor."""

  def test_inversion(self):
    """`inverse(forward(x))` must reproduce `x` up to numerical tolerance."""
    num_channels = 8
    input_shape = (1, 24, 24, num_channels)

    rng = random.PRNGKey(758493)
    _, inp_rng, init_rng, data_rng = jax.random.split(rng, 4)

    # `dummy_x` is only used to initialize the parameters; the round trip is
    # checked on independently drawn data.
    dummy_x = random.normal(inp_rng, shape=input_shape)
    real_x = jax.random.normal(data_rng, shape=input_shape)

    model = adaptor.IRevNet(
        num_blocks=4,
        num_channels=num_channels,
        dropout_rate=0.0,
    )
    params = model.init(init_rng, dummy_x)

    real_y = model.apply(params, real_x, method=model.forward)
    real_x_ = model.apply(params, real_y, method=model.inverse)
    self.assertTrue(jnp.allclose(real_x, real_x_, atol=1e-5))


if __name__ == "__main__":
  # The module imports `absltest`; `googletest` is not defined here.
  absltest.main()
# pylint: disable=line-too-long +""" + +import dataclasses +import functools +import math +from typing import Any, Sequence + +from big_vision import utils +from big_vision.models import common +from big_vision.models.proj.givt import vae + +import einops +import flax.linen as nn +import flax.training.checkpoints + +import jax +import jax.numpy as jnp + + +def _get_norm_layer(train, dtype, norm_type="BN"): + """Create normalization layers. + + Args: + train: Whether to use the layer in training or inference mode. + dtype: Layer output type. + norm_type: Which normalization to use "BN", "LN", or "GN". + + Returns: + An instance of the the layer. + """ + if norm_type == "BN": + return functools.partial( + nn.BatchNorm, + use_running_average=not train, + momentum=0.9, + epsilon=1e-5, + axis_name=None, + axis_index_groups=None, + dtype=jnp.float32, + use_fast_variance=False) + elif norm_type == "LN": + return functools.partial(nn.LayerNorm, dtype=dtype, use_fast_variance=False) + elif norm_type == "GN": + return functools.partial(nn.GroupNorm, dtype=dtype, use_fast_variance=False) + else: + raise NotImplementedError + + +def _tensorflow_style_avg_pooling(x, window_shape, strides, padding: str): + """Avg pooling as done by TF (Flax layer gives different results). + + To be specific, Flax includes padding cells when taking the average, + while TF does not. + + Args: + x: Input tensor + window_shape: Shape of pooling window; if 1-dim tuple is just 1d pooling, if + 2-dim tuple one gets 2d pooling. + strides: Must have the same dimension as the window_shape. + padding: Either 'SAME' or 'VALID' to indicate pooling method. + + Returns: + pooled: Tensor after applying pooling. 
def get_h_w_pixelshuffle(hw, pixel_shuffle_patch_size):
  """Returns (h, w) of the grid after space-to-depth, before flattening.

  Assumes the image before the space-to-depth transformation was square.

  Args:
    hw: Flattened sequence length (h * w).
    pixel_shuffle_patch_size: Tuple (ph, pw) of the pixel-shuffle patch.

  Returns:
    Tuple (h, w) with h * w == hw.
  """
  ph, pw = pixel_shuffle_patch_size
  # Side length of the square grid prior to space-to-depth.
  side = int(math.sqrt(hw * ph * pw))
  height = side // ph
  width = side // pw
  assert height * width == hw, f"Length {hw} incompatible with pixelshuffle ({ph}, {pw})"
  return height, width
class Decoder(nn.Module):
  """CNN decoder: residual stages with 2x upsampling between them."""

  filters: int
  num_res_blocks: int
  channel_multipliers: list[int]
  norm_type: str = "GN"
  activation_fn_str: str = "swish"
  output_dim: int = 3
  dtype: Any = jnp.float32

  def setup(self) -> None:
    # Resolve the activation name once; unknown names are rejected.
    activations = {"relu": nn.relu, "swish": nn.swish}
    if self.activation_fn_str not in activations:
      raise NotImplementedError
    self.activation_fn = activations[self.activation_fn_str]

  @nn.compact
  def __call__(self, x: jax.Array, train: bool = False) -> jax.Array:
    make_conv = nn.Conv
    make_norm = _get_norm_layer(
        train=train, dtype=self.dtype, norm_type=self.norm_type)
    res_kwargs = dict(
        norm_fn=make_norm,
        conv_fn=make_conv,
        dtype=self.dtype,
        activation_fn=self.activation_fn,
        use_conv_shortcut=False,
    )

    # Input projection and residual blocks at the widest stage.
    width = self.filters * self.channel_multipliers[-1]
    x = make_conv(width, kernel_size=(3, 3), use_bias=True)(x)
    for _ in range(self.num_res_blocks):
      x = ResBlock(width, **res_kwargs)(x)

    # Walk the stages from last to first, upsampling after all but stage 0.
    last_stage = len(self.channel_multipliers) - 1
    for stage in range(last_stage, -1, -1):
      width = self.filters * self.channel_multipliers[stage]
      for _ in range(self.num_res_blocks):
        x = ResBlock(width, **res_kwargs)(x)
      if stage > 0:
        x = _upsample(x, 2)
        x = make_conv(width, kernel_size=(3, 3))(x)

    # Output head.
    x = make_norm()(x)
    x = self.activation_fn(x)
    return make_conv(self.output_dim, kernel_size=(3, 3))(x)
def load(
    init_params: Any,
    init_file: str,
    model_params: Any = None,
    dont_load: Sequence[str] = (),
    malib_ckpt: bool = False,
    use_ema_params: bool = False,
) -> Any:
  """Loads params from init checkpoint and merges into init_params.

  Args:
    init_params: pytree with (previously initialized) model parameters.
    init_file: Path of the checkpoint to load.
    model_params: Dict containing the model config.
    dont_load: Sequence of (flattened) parameter names which should not be
      loaded.
    malib_ckpt: Whether the given init_file is a malib checkpoint.
    use_ema_params: Whether to load the EMA params (for malib checkpoints).

  Returns:
    pytree containing the loaded model parameters.
  """
  # `model_params` is unused here, but we still include it to conform with the
  # general big_vision interface, cf. the core models in big_vision/models/.
  del model_params

  assert malib_ckpt or (not use_ema_params), (
      "Loading EMA parameters is only supported for malib checkpoints.")

  if not malib_ckpt:
    params = flax.core.unfreeze(utils.load_params(init_file))
  else:
    # Locally disable transfer guard since restore_checkpoint does not allow
    # for fine-grained sharding control.
    with jax.transfer_guard("allow"):
      restored = flax.training.checkpoints.restore_checkpoint(init_file, None)
    named_leaves = utils.tree_flatten_with_names(restored)[0]
    prefix_old = "ema_params/" if use_ema_params else "g_params/"
    # Keep only the selected sub-tree and rename it to the `cnn_*` modules.
    renamed = []
    for name, value in named_leaves:
      if prefix_old in name:
        renamed.append((name.replace(prefix_old, "cnn_"), value))
    params = utils.tree_unflatten(renamed)

  if init_params is not None:
    params = common.merge_params(params, init_params, dont_load)
  return params
def _cache_map(fn, cache, scan=False):
  """Maps function over cache.

  Args:
    fn: Function applied to every cache leaf except those whose last key is
      "cached_bias".
    cache: Nested cache dict; may be a `flax.core.FrozenDict`.
    scan: If True, `fn` is mapped over the leading axis of every non-scalar
      leaf (for scanned models) instead of being applied to the leaf itself.

  Returns:
    A new cache with the same structure (frozen iff the input was frozen).
  """
  if scan:
    # Assuming the cache is scanned over the first dimension, we apply a map
    # function over this dimension for scanned models
    fn_mod = lambda x: jax.lax.map(fn, x) if x.ndim > 0 else fn(x)
  else:
    fn_mod = fn

  frozen = isinstance(cache, flax.core.FrozenDict)
  if frozen:
    cache = flax.core.unfreeze(cache)
  flat_cache = flax.traverse_util.flatten_dict(cache)
  # Exclude cached relative position bias from beam expansion, etc.
  keyvals = {k: v for k, v in flat_cache.items() if k[-1] != "cached_bias"}
  # `jax.tree_map` is deprecated/removed in recent JAX releases;
  # `jax.tree_util.tree_map` is the equivalent available in all versions.
  keyvals = jax.tree_util.tree_map(fn_mod, keyvals)
  flat_cache.update(keyvals)
  new_cache = flax.traverse_util.unflatten_dict(flat_cache)
  if frozen:
    new_cache = flax.core.freeze(new_cache)
  return new_cache
disable=wrong-arg-types + raise ValueError(f"Invalid style: {model.style}") + if model.has_encoder != (cond_image is not None): + raise ValueError("Need cond_image if and only if the model has an encoder!") + + assert labels is not None or batch_size, ( + "Please provide either labels or batch_size.") + + config = config or {} + config = dict(config) # copy + + # For sampling, we support keep_gt (a bool mask), and gt (ground truth) + # tokens to use instead of samples. + keep_gt = config.pop("keep_gt", None) + gt = config.pop("gt", None) + + if isinstance(seed, int): + seed = jax.random.PRNGKey(seed) + + beam_size = config.pop("beam_size", 1) + fan_size = config.pop("fan_size", 1) + + if labels is not None: + batch_size = labels.shape[0] + # fold beams into batch dimension + labels = labels.repeat(beam_size, axis=0) + + # initialize sequence and logprobs (we track per feature dim logprobs) + init_sequence = jnp.zeros((batch_size * beam_size, seq_len, feature_dim)) + init_logprobs = jnp.zeros_like(init_sequence) + + if cond_image is not None: + # embed conditioning image if provided + def encode_cond_img(model, cond_img): + return model.encode(cond_img) + encoded = nn.apply(encode_cond_img, model)(params, cond_image) + encoded = jnp.repeat(encoded, beam_size, axis=0) + else: + encoded = None + + cache, prefill_logits = _create_cache( + labels, model, init_sequence, params, encoded + ) + + cfg_inference_weight = config.pop("cfg_inference_weight", None) + if cfg_inference_weight == 0.0: + cfg_inference_weight = None + cfg = cfg_inference_weight is not None + + get_pdf = functools.partial( + model.get_pdf, + temperature_scales=config.pop("temp", None), + temperature_probs=config.pop("temp_probs", None), + ) + + # setup sampling function + sample = functools.partial( + _sample_gmm, cfg_inference_weight=cfg_inference_weight + ) + + # draw first output token + pdf_first = get_pdf(prefill_logits) + rng_first, rng = jax.random.split(seed) + + if cfg: + assert beam_size == 1 
and fan_size == 1 # CFG + Beam not supported. + cache_u, prefill_logits_u = _create_cache( + labels, model, init_sequence, params, encoded, uncond=True + ) + pdf_first_u = get_pdf(prefill_logits_u) + else: + cache_u = None + pdf_first_u = None + + tokens_first, logprobs_first = sample( + pdf_first, rng=rng_first, gmm_pdf_uncond=pdf_first_u + ) + init_sequence = init_sequence.at[:, 0].set(tokens_first.squeeze(axis=1)) + init_logprobs = init_logprobs.at[:, 0].set(logprobs_first.squeeze(axis=1)) + + def tokens_to_logits(tokens, cache, uncond=False): + if uncond: + drop_labels = jnp.ones((labels.shape[0],), dtype=jnp.bool_) + else: + drop_labels = None + + def decode_step(model, tokens): + return model.decode(tokens, labels, encoded, + decode=True, drop_labels=drop_labels) + + logits, aux = nn.apply(decode_step, model, mutable=True)( + {"params": params["params"], "cache": cache}, tokens) + return logits, aux["cache"] + + init_state = LoopState( + cache=cache, + sequences=init_sequence, # (b * nb, s, d) + logprobs=init_logprobs, # (b * nb, s, d) + rng=rng, + cache_u=cache_u, + ) + + rand_top_k = config.pop("rand_top_k", False) + rand_top_k_temp = config.pop("rand_top_k_temp", 1.0) + + assert not config, f"Sampling config is expected to be empty: {config}" + + def sampling_iteration(i, state): + rng_sampling, rng_local = jax.random.split(state.rng) + cur_tokens = state.sequences[:, i][:, None] + # (b * nb, d) + cur_logits, cache = tokens_to_logits(cur_tokens, state.cache) + + # (b, nb, d) + cur_logits = _unflatten_samples_dim( + cur_logits, batch_size, beam_size).squeeze(axis=2) + + # (b, nb * nf, d) + cur_pdf = get_pdf(cur_logits.repeat(fan_size, axis=1)) + + if cfg: + cur_logits_u, cache_u = tokens_to_logits( + cur_tokens, state.cache_u, uncond=True + ) + cur_logits_u = _unflatten_samples_dim( + cur_logits_u, batch_size, beam_size).squeeze(axis=2) + cur_pdf_u = get_pdf(cur_logits_u.repeat(fan_size, axis=1)) + new_tokens, new_logprobs = sample( + cur_pdf, 
rng=rng_sampling, gmm_pdf_uncond=cur_pdf_u + ) + else: + new_tokens, new_logprobs = sample(cur_pdf, rng=rng_sampling) + cache_u = None + + if gt is not None: + assert keep_gt is not None + new_tokens = jnp.where(keep_gt[i], gt[:, i, :][:, None], new_tokens) + + # Skip beam search if not needed + if beam_size == fan_size == 1: + sampled_tokens = new_tokens.squeeze(axis=1) + sequences = state.sequences.at[:, i + 1].set(sampled_tokens) + return LoopState( + cache=cache, + rng=rng_local, + sequences=sequences, + logprobs=state.logprobs, + cache_u=cache_u, + ) + + # (b, nb, s, d) + logprobs = _unflatten_samples_dim(state.logprobs, batch_size, beam_size) + cur_logprobs = logprobs[:, :, i] # (b, nb, d) + # (b, nb * nf, d) + new_logprobs = new_logprobs + cur_logprobs.repeat(fan_size, axis=1) + beam_logprobs = new_logprobs.sum(axis=-1) # (b, nb * nf) + + if rand_top_k: + # randomize top-k sampling via sampling from a categorical distribution + def stoc_top_k(r, x, p): + return jax.random.choice(r, x, shape=(beam_size,), replace=False, p=p) + # construct index grid + index_grid = jnp.arange(beam_logprobs.shape[1], dtype=jnp.int32) + # (b, nb * nf) + index_grid = index_grid[None].repeat(beam_logprobs.shape[0], axis=0) + top_k_rng, rng_local = jax.random.split(rng_local) + top_k_rng = jax.random.split(top_k_rng, beam_logprobs.shape[0]) + # vmap categorical sampling + top_beam_fan_indices = jax.vmap(stoc_top_k, in_axes=(0, 0, 0))( + top_k_rng, + index_grid, + nn.softmax(beam_logprobs / rand_top_k_temp, axis=-1)) + else: + _, top_beam_fan_indices = lax.top_k(beam_logprobs, k=beam_size) # (b, nb) + + top_beam_indices = top_beam_fan_indices // fan_size + + def _gather_beams(x): + if x.ndim == 0: + return x + # checkify.check(jnp.all(top_beam_indices < x.shape[1]), + # f"`take_along_axis` out of bounds in `_gather_beams`: " + # f"{top_beam_indices.max()} vs. {x.shape[1]}") + # (b, nb, 1 ... 
1) + expanded_indices = top_beam_indices.reshape( + top_beam_indices.shape + (1,) * (x.ndim - 2)) + return jnp.take_along_axis(x, expanded_indices, axis=1) + + def _gather_tokens(x): + # (b, nb * nf, d) -> (b, nb, d) + # checkify.check(jnp.all(top_beam_fan_indices < x.shape[1]), + # f"`take_along_axis` out of bounds in `_gather_tokens`: " + # f"{top_beam_fan_indices.max()} vs. {x.shape[1]}") + return jnp.take_along_axis(x, top_beam_fan_indices[..., None], axis=1) + # (b, nb, s, d) + sequences = _unflatten_samples_dim(state.sequences, batch_size, beam_size) + sequences = _gather_beams(sequences) # (b, nb, s, d) + sequences = sequences.at[:, :, i + 1].set(_gather_tokens(new_tokens)) + # (b, nb, s, d) + sequences = _flatten_samples_dim(sequences) + + logprobs = _gather_beams(logprobs) + logprobs = logprobs.at[:, :, i + 1].set(_gather_tokens(new_logprobs)) + logprobs = _flatten_samples_dim(logprobs) + + scanned_cache = getattr(model, "scan", False) + cache = _cache_map( + lambda x: _unflatten_samples_dim(x, batch_size, beam_size), + cache, scanned_cache) + cache = _cache_map(_gather_beams, cache, scanned_cache) + cache = _cache_map(_flatten_samples_dim, cache, scanned_cache) + + if cfg: + assert cache_u is not None + cache_u = _cache_map( + lambda x: _unflatten_samples_dim(x, batch_size, beam_size), + cache_u, scanned_cache + ) + cache_u = _cache_map(_gather_beams, cache_u, scanned_cache) + cache_u = _cache_map(_flatten_samples_dim, cache_u, scanned_cache) + else: + assert cache_u is None + + return LoopState( + cache=cache, + rng=rng_local, + sequences=sequences, + logprobs=logprobs, + cache_u=cache_u, + ) + + final_state = lax.fori_loop(0, seq_len, sampling_iteration, init_state) + final_logprobs = final_state.logprobs[::beam_size][:, -1].sum(axis=-1) + + # return top beams and corresponding log probs + return final_state.sequences[::beam_size], final_logprobs diff --git a/big_vision/models/proj/givt/decode_test.py b/big_vision/models/proj/givt/decode_test.py new 
def _make_test_model(**overwrites):
  """Builds a tiny autoregressive `givt.Model` for the decode tests.

  Args:
    **overwrites: Constructor arguments that replace (or extend) the small
      default configuration below.

  Returns:
    The `givt.Model` constructed from the merged configuration.
  """
  defaults = {
      "num_heads": 2,
      "num_decoder_layers": 1,
      "mlp_dim": 64,
      "emb_dim": 16,
      "patches": (_PATCH_SIZE, _PATCH_SIZE),
      "input_size": (_IMG_DIM, _IMG_DIM),
      "seq_len": _SEQ_LEN,
      "out_dim": _OUT_DIM,
      "num_mixtures": _NUM_MIXTURES,
      "style": "ar",
  }
  # Caller-supplied values win over the defaults.
  merged = {**defaults, **overwrites}
  return givt.Model(**merged)
if __name__ == "__main__":
  # This module imports `absltest` (from absl.testing); `googletest` is never
  # imported, so the previous `googletest.main()` raised NameError whenever the
  # file was executed as a script.
  absltest.main()
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decoder-only and encoder-decoder GIVT model. + +Used abbreviations for dimension annotations: + B: batch size. + E: embedding size. + L: (soft) token sequence length. + D: soft token dimension. + P: number of patches (extracted by a ViT encoder in GIVT-based UViM) +""" + +import enum +import itertools +from typing import Literal, Optional, Sequence, Any, Mapping + +from absl import logging +from big_vision import utils +from big_vision.models import common +from big_vision.models import vit +import distrax +import einops +import flax.linen as nn +from flax.linen import partitioning +import jax +import jax.numpy as jnp +import numpy as np + + +class _SpecialLabel(enum.Enum): + + MASK = "mask" + NOMASK = "nomask" + REPLACE = "replace" + NOLABEL = "nolabel" # For CFG + + +def _random_mask_with_ratios(rng, ratios: jax.Array, seq_len: int): + """Generates masks where a fraction of tokens is uncovered. + + Args: + rng: RNG. + ratios: Ratios, must be a 1D matrix of shape (B,). Values must be in + [0, 1], and indicate at ratios[i] how many of the i-th tokens are + uncovered (ie. equal to `True`). + seq_len: How many tokens this mask has to cover. + + Returns: + Mask of dtype bool, shape (B, L). + + Raises: + ValueError: Incorrect inputs. 
+ """ + if ratios.ndim != 1: + raise ValueError("Ratios must have shape (B,)!") + ratios = jnp.clip(ratios, 0, 1) + indices = jnp.arange(seq_len, dtype=jnp.float32) # Shape: (L,) + ratios = ratios[:, jnp.newaxis] * seq_len # Shape: (B, 1) + # This is a binary array where the first ratios * seq_len positions are True + mask = (indices < ratios).astype(jnp.bool_) # Shape: (B, L) + # Shuffle to a actual mask. + return jax.random.shuffle(rng, mask, axis=-1) + + +def apply_mask_schedule(ratio: float | jax.Array, method: str) -> jax.Array: + """Generate a mask rate by scheduling mask functions R.""" + if method == "cosine": + mask_ratio = jax.lax.cos(jnp.pi / 2. * ratio) + elif "pow:" in method: + exponent = float(method.replace("pow:", "")) + mask_ratio = 1. - ratio**exponent + else: + raise NotImplementedError(method) + # Clamps mask into [epsilon, 1) + mask_ratio = jnp.clip(mask_ratio, 1e-6, 1.) + return mask_ratio + + +class EncoderDecoderBlock(nn.Module): + """Transformer encoder-decoder layer.""" + mlp_dim: int + num_heads: int + dropout_rate: float = 0. + decode: bool = False + + @nn.compact + def __call__( + self, + targets: jax.Array, + encoded: jax.Array | None = None, + decoder_mask: jax.Array | None = None, + deterministic: bool = True, + ) -> tuple[jax.Array, jax.Array]: + """Applies EncoderDecoderBlock module. + + Args: + targets: target text embeddings [B, L, D]. + encoded: encoded image patches from encoder [B, P, E]. + decoder_mask: decoder self-attention mask. + deterministic: bool, deterministic or not (to apply dropout). + + Returns: + output after transformer encoder-decoder block [B, L, E]. + """ + # Helper function for axis annotation. + def wlc(f): + dim_names = ("act_batch", "act_len", "act_emb") + return nn.with_logical_constraint(f, dim_names) + # Decoder block. 
class Decoder(nn.Module):
  """Transformer decoder model with optional cross-attention.

  Stacks `num_layers` `EncoderDecoderBlock`s (either as a Python loop or as a
  single scanned + rematerialized block when `scan=True`), applies a final
  LayerNorm and projects to `out_dim` with a zero-initialized Dense layer.
  """
  emb_dim: int
  mlp_dim: int
  num_heads: int
  num_layers: int
  out_dim: int
  seq_len: int
  style: Literal["ar", "masked"]
  dropout_rate: float = 0.
  # NOTE(review): declared but not read anywhere inside this class.
  zero_embedding_init: bool = False

  # If True, run the layers through `nn.scan` instead of a Python loop.
  scan: bool = False
  # Name of an attribute of `jax.checkpoint_policies`; unknown names fall back
  # to `None` (see the `getattr` below). Only used when `scan=True`.
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(
      self,
      targets: jax.Array,
      encoded: jax.Array | None = None,
      decoder_mask: jax.Array | None = None,
      decode: bool = False,
      deterministic: bool = True,
      return_reps: bool = False,
  ) -> jax.Array | tuple[jax.Array, Mapping[str, jax.Array]]:
    """Applies Transformer model on the inputs.

    Args:
      targets: decoder inputs. NOTE(review): position embeddings of width
        `emb_dim` are added directly to this tensor, so presumably it is
        already embedded, i.e. [B, L, E] rather than raw tokens [B, L] --
        confirm against callers.
      encoded: encoded sequence from an encoder [B, P, E], or None to skip
        cross-attention.
      decoder_mask: decoder self-attention mask.
      decode: bool, whether to perform fast autoregressive decoding with cache.
      deterministic: bool, deterministic or not (to apply dropout).
      return_reps: bool, whether to return intermediate representations.

    Returns:
      output of a transformer decoder [B, L, out_dim], where out_dim is usually
      a multiple of D. If `return_reps` is True, a tuple `(logits, out)` where
      `out` maps "block{i}_rep", "pre_logits" and "logits" to arrays.

    Raises:
      ValueError: If `style == "masked"` and `decode` is True.
    """
    if self.style == "masked" and decode:
      raise ValueError("Cannot run masked model in cached mode!")

    pos_emb = vit.get_posemb(
        self, "learn", self.seq_len, self.emb_dim,
        "pos_emb")

    y = common.AddPositionEmbs(
        decode=decode, name="PosEmbedTargets")(targets, pos_emb)

    # Collects per-layer representations plus "pre_logits" and "logits".
    out = {}
    if self.scan:
      # Mostly followed
      # https://github.com/google/maxtext/blob/4d99e30b3e0e0cb1d1aa11c7db7fffe18e301498/MaxText/layers.py#L1126
      # for the scanned version.

      # 1. remat
      enc_dec_block_remat = nn.remat(
          EncoderDecoderBlock,
          prevent_cse=False,
          static_argnums=(-1, -2),
          policy=getattr(jax.checkpoint_policies, self.remat_policy, None))
      # 2. scan
      initializing = self.is_mutable_collection("params")
      param_scan_axis = 1
      params_spec = (param_scan_axis if initializing
                     else partitioning.ScanIn(param_scan_axis))
      dec_scanned = nn.scan(enc_dec_block_remat,
                            variable_axes={
                                "params": params_spec,
                                "cache": 0,
                            },
                            split_rngs={"params": True, "dropout": True},
                            in_axes=nn.broadcast,
                            length=self.num_layers)
      # 3. fprop
      # Arguments are passed positionally here; `deterministic` is the last
      # positional argument of `EncoderDecoderBlock.__call__`.
      y, out = dec_scanned(num_heads=self.num_heads, mlp_dim=self.mlp_dim,
                           dropout_rate=self.dropout_rate, decode=decode,
                           name="EncDecBlock")(
                               y, encoded, decoder_mask, deterministic)
      # Extracting the intermediate representation from the stacked activation
      # tensor `out`, which is a [num_layers, B, L, E] tensor. Indexing along
      # the first axis to extract individual layers, and then averaging across
      # the second axis, which corresponds to the sequence dimension after
      # indexing.
      assert out.shape[0] == self.num_layers and (
          decode or out.shape[2] == self.seq_len), (
              (out.shape, self.num_layers, self.seq_len))
      out = {f"block{l}_rep": jnp.mean(out[l], axis=1)
             for l in range(self.num_layers)}
    else:
      for lyr in range(self.num_layers):
        # The block returns `(out, out)`; only the first element is needed in
        # the unrolled (non-scan) path.
        y, _ = EncoderDecoderBlock(
            num_heads=self.num_heads, mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate, decode=decode,
            name=f"EncDecBlock{lyr}")(y, encoded, decoder_mask=decoder_mask,
                                      deterministic=deterministic)
        out[f"block{lyr}_rep"] = jnp.mean(y, axis=1)
    y = nn.LayerNorm(name="LayerNorm")(y)
    out["pre_logits"] = jnp.mean(y, axis=1)

    # Zero-initialized kernel: the projection outputs only its bias at init.
    logits = nn.Dense(
        self.out_dim,
        kernel_init=nn.initializers.zeros,
        name="LogitsDense",
    )(y)
    out["logits"] = logits
    if return_reps:
      return logits, out
    return logits
@property
def num_logits(self) -> int:
  """Number of decoder outputs per token needed to parameterize the PDF.

  Returns:
    - multivariate: `out_dim**2` scale entries plus `out_dim` means.
    - per-channel mixtures: one (weight, mean, scale) triple per output
      dimension and mixture component.
    - otherwise: `num_mixtures` mixture weights plus a mean and a scale vector
      of size `out_dim` per component.
  """
  dim = self.out_dim
  mixtures = self.num_mixtures

  if self.multivariate:
    assert mixtures == 1
    # Note: `round` makes pytype happy.
    return round(dim ** 2) + dim

  if self.per_channel_mixtures:
    return 3 * mixtures * dim

  return mixtures + 2 * mixtures * dim
+ next_label = itertools.count(self.num_labels or 0) + special_labels = {} + + if self.style == "ar": + pass + elif self.style == "masked": + if self.mask_style == "replace": + special_labels = {_SpecialLabel.MASK: next(next_label)} + elif self.mask_style == "concat": + special_labels = { + _SpecialLabel.MASK: next(next_label), + _SpecialLabel.NOMASK: next(next_label), + _SpecialLabel.REPLACE: next(next_label), + } + else: + raise NotImplementedError(self.mask_style) + else: + raise NotImplementedError(self.style) + + if self.drop_labels_probability > 0: + special_labels[_SpecialLabel.NOLABEL] = next(next_label) + + self.special_labels = special_labels + lookup_size = (self.num_labels or 1) + len(self.special_labels) + + self.labels_emb = nn.Embed( + lookup_size, + self.emb_dim, + name="EmbedLabels", + embedding_init=nn.initializers.zeros + if self.zero_embedding_init + else nn.initializers.normal(stddev=1.0), + ) + + self.targets_emb = nn.Dense(self.emb_dim, name="EmbedTargets") + + self.decoder = Decoder( + num_layers=self.num_decoder_layers or self.num_layers, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + out_dim=self.num_logits, + # In masked mode, we run with 1 more token at the input. 
def embed_labels(
    self,
    labels: jax.Array | None = None,
    batch_size: int | None = None,
) -> jax.Array:
  """Embeds class labels, or a [BOS] placeholder when there are no labels.

  Args:
    labels: Optional integer label ids, one per batch element.
    batch_size: Number of [BOS] embeddings to create; only consulted (and then
      required) when `labels` is None.

  Returns:
    The looked-up embeddings with a sequence axis of length 1 inserted after
    the batch axis.
  """
  if labels is None:
    # Unconditional case: only valid when the model has no real label set.
    has_no_label_set = self.num_labels is None or self.num_labels == 1
    assert has_no_label_set and batch_size is not None
    # Create [BOS] token ids (all zeros).
    labels = jnp.zeros((batch_size,), jnp.int32)

  embedded = self.labels_emb(labels)
  return embedded[:, None, :]
def _get_special_label(self, size, label: _SpecialLabel):
  """Looks up the embedding of a special label, repeated to shape `size`.

  Args:
    size: Shape of the id array to embed.
    label: Which special label to look up; must be a key of
      `self.special_labels`.

  Returns:
    `self.labels_emb` applied to an int32 array of shape `size` filled with
    the id of `label`.
  """
  label_id = self.special_labels[label]
  ids = jnp.full(size, label_id, jnp.int32)
  return self.labels_emb(ids)
def _drop_labels(self, drop_labels_mask, labels):
  """Replaces labels with the NOLABEL id where requested.

  Args:
    drop_labels_mask: Optional mask; labels are replaced where it is True.
    labels: Optional label ids.

  Returns:
    None if `labels` is None. All-NOLABEL ids if `drop_labels_probability` is
    at least 0.999. `labels` unchanged if no mask is given. Otherwise `labels`
    with the masked entries replaced by the NOLABEL id.
  """
  if labels is None:
    return None

  drop_everything = self.drop_labels_probability >= 0.999
  if drop_everything:
    logging.warning("Dropping all labels...")
  elif drop_labels_mask is None:
    return labels
  else:
    # A mask only makes sense if the NOLABEL id was registered.
    assert _SpecialLabel.NOLABEL in self.special_labels

  nolabel = jnp.full_like(labels, self.special_labels[_SpecialLabel.NOLABEL])
  if drop_everything:
    return nolabel
  return jnp.where(drop_labels_mask, nolabel, labels)
def _square_plus(self, x):
  """Maps raw scale logits through a squareplus-style function.

  Via https://twitter.com/jon_barron/status/1387167648669048833

  With `fix_square_plus=True` this is `(x + sqrt(x^2 + 4)) / 2`. Otherwise only
  the square-root term is halved, i.e. `x + sqrt(x^2 + 4) / 2`, which is a
  different function (it can be negative for negative `x`).
  NOTE(review): the flag name suggests the second form is a legacy behavior
  kept on purpose -- confirm before changing the default.
  """
  root = jnp.sqrt(jnp.square(x) + 4)
  if self.fix_square_plus:
    return (x + root) / 2
  return x + root / 2
+ return distrax.MultivariateNormalTri(locs, scales) + + elif self.per_channel_mixtures: + # [..., 3 * num_mixtures * out_dim] -> [..., 3 * out_dim, num_mixtures] + logits = jnp.reshape(logits, logits.shape[: -1] + (-1, self.num_mixtures)) + # 3 tensors with shape [..., out_dim, num_mixtures] + probs, locs, scales = jnp.split(logits, 3, axis=-2) + if (t := temperature_probs) is not None: + probs = probs * t + + # normalize mixture probabilities + probs = nn.softmax(probs) + scales = self._square_plus(scales) + # threshold scale + scales = jnp.maximum(scales, self.scale_tol) + if (t := temperature_scales) is not None: + scales = scales * t + + # Note on output shapes: + # - .sample() -> shape (..., seq_len, out_dim) + # - .prob() -> shape (..., seq_len, out_dim). + return distrax.MixtureSameFamily( + mixture_distribution=distrax.Categorical(probs=probs), + components_distribution=distrax.Normal(loc=locs, scale=scales), + ) + else: + *shape, num_logits = logits.shape + assert num_logits == self.num_logits, (num_logits, self.num_logits) + prob_logits, other_logits = ( + logits[..., : self.num_mixtures], + logits[..., self.num_mixtures :], + ) + if (t := temperature_probs) is not None: + prob_logits = prob_logits * t + other_logits = jnp.reshape( + other_logits, (*shape, self.num_mixtures, 2, self.out_dim) + ) + locs = other_logits[..., 0, :] + scales = self._square_plus(other_logits[..., 1, :]) + + scales = jnp.maximum(scales, self.scale_tol) # Threshold scale + if (t := temperature_scales) is not None: + scales = scales * t + + # prob_logits has shape (b, seq_len, m) + # locs/scales has shape (b, seq_len, m, d) + assert prob_logits.ndim == locs.ndim - 1, (prob_logits.shape, locs.shape) + assert locs.shape == scales.shape, (locs.shape, scales.shape) + + # Note on output shapes: + # - .sample() -> shape (..., seq_len, out_dim) + # - .prob() -> shape (..., seq_len,) + # - .nll() -> shape (..., seq_len,) + return distrax.MixtureSameFamily( + 
def __call__(
    self,
    sequence: jax.Array,
    labels: jax.Array | None = None,
    *,
    image: jax.Array | None = None,
    decode: bool = False,
    input_mask: jax.Array | None = None,
    drop_labels: jax.Array | None = None,
    train: bool = False,
) -> tuple[jax.Array, distrax.Distribution]:
  """Applies Transformer model on the inputs.

  Args:
    sequence: batch of soft-token sequences [B, L, D].
    labels: class labels for class conditional generation [B].
    image: batch of images [B, H, W, 3]. Must be given if and only if the
      model has an encoder.
    decode: whether to prepare and use an autoregressive cache.
    input_mask: If given, mask input. Required for style=="masked" [B, L].
    drop_labels: If given, drop labels of the corresponding batches [B].
      Dropping via this mask requires the NOLABEL id to be registered, i.e.
      `drop_labels_probability > 0`.
    train: whether it is training.

  Returns:
    A tuple `(logits, pdf)`: the decoder logits [B, L, num_logits] and the
    distribution built from them by `get_pdf`.

  Raises:
    ValueError: If style=="masked" and no `input_mask` is given.
  """
  if self.style == "masked" and input_mask is None:
    raise ValueError("Cannot run masked model without input mask!")

  if self.encoder is not None:
    assert image is not None
    encoded = self.encode(image, train=train)
  else:
    assert image is None
    encoded = None

  logits = self.decode(
      sequence,
      labels=labels,
      encoded=encoded,
      decode=decode,
      input_mask=input_mask,
      # Previously this argument was accepted but never forwarded, so label
      # dropping requested through `__call__` was silently ignored.
      drop_labels=drop_labels,
      train=train,
  )
  pdf = self.get_pdf(logits)
  return logits, pdf
def get_input_mask_teacher_forced(
    self,
    shape: tuple[int, int],
) -> jax.Array | None:
  """Creates an all-`False` mask of shape (B, L) for masked models.

  Nothing is random here: every entry is `False`.
  NOTE(review): per the `decode` docstring, `True` marks tokens removed from
  the input, so this presumably means "mask nothing" -- confirm with callers.

  Args:
    shape: The (B, L) shape of the mask to create.

  Returns:
    None for style=="ar", otherwise a bool array of zeros with shape `shape`.
  """
  if self.style == "ar":
    return None
  return jnp.zeros(shape, dtype=jnp.bool_)
+ + enc_init = init_files.pop("encoder", None) + if enc_init: + ckpt_params = init_params.copy() + vit_params = { + "pos_embedding": ckpt_params["pos_embedding_encoder"], + "Transformer": ckpt_params["encoder"], + "embedding": ckpt_params["EmbedPatches"], + } + encoder_params = vit.load( + vit_params, enc_init, model_cfg={}, + dont_load=dont_load) + ckpt_params["encoder"] = encoder_params["Transformer"] + ckpt_params["pos_embedding_encoder"] = encoder_params["pos_embedding"] + ckpt_params["EmbedPatches"] = encoder_params["embedding"] + else: + raise ValueError("Only encoder init is supported: {}.".format(init_files)) + + return ckpt_params diff --git a/big_vision/models/proj/givt/givt_test.py b/big_vision/models/proj/givt/givt_test.py new file mode 100644 index 0000000000000000000000000000000000000000..acfc59e73448efffbb34bb73728ca3a52cbd02fb --- /dev/null +++ b/big_vision/models/proj/givt/givt_test.py @@ -0,0 +1,124 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _make_test_model(**overwrites):
  """Builds a tiny `givt.Model` for the model tests.

  Args:
    **overwrites: Constructor arguments that replace (or extend) the small
      default configuration below.

  Returns:
    The `givt.Model` constructed from the merged configuration.
  """
  defaults = {
      "num_heads": 2,
      "num_decoder_layers": 1,
      "mlp_dim": 64,
      "emb_dim": 16,
      "seq_len": _SEQ_LEN,
      "out_dim": _OUT_DIM,
      "num_mixtures": _NUM_MIXTURES,
  }
  # Caller-supplied values win over the defaults.
  merged = {**defaults, **overwrites}
  return givt.Model(**merged)
labels, input_mask=input_mask, train=train + ) + nll = -pdf.log_prob(sequence) + self.assertFalse(np.any(np.isnan(nll))) + if multivariate: + self.assertEqual( + logits.shape, (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM**2 + _OUT_DIM) + ) + self.assertEqual(nll.shape, (_BATCH_SIZE, _SEQ_LEN)) + elif per_channel_mixtures: + self.assertEqual( + logits.shape, + (_BATCH_SIZE, _SEQ_LEN, 3 * _NUM_MIXTURES * _OUT_DIM), + ) + self.assertEqual(nll.shape, (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM)) + else: + self.assertEqual( + logits.shape, + (_BATCH_SIZE, _SEQ_LEN, _NUM_MIXTURES + _NUM_MIXTURES * _OUT_DIM * 2), + ) + self.assertEqual(nll.shape, (_BATCH_SIZE, _SEQ_LEN)) + + sample = pdf.sample(seed=jax.random.PRNGKey(0)) + self.assertEqual(sample.shape, (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM)) + + +if __name__ == "__main__": + googletest.main() diff --git a/big_vision/models/proj/givt/parallel_decode.py b/big_vision/models/proj/givt/parallel_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..b05903a7e30a870ba5b09b0965932321090f1190 --- /dev/null +++ b/big_vision/models/proj/givt/parallel_decode.py @@ -0,0 +1,523 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decode autoregressive/bidirectional masked transformers. + + +Currently, we implement MaskGIT style temperature sampling: + +In each step: +1. Get P = model(inputs), predicted GMMs +2. Get samples = sample_from(P) +3. Get probs = P[samples], ie, model evaluated at samples. 
+   We use this now as a confidence metric, but we scale the probs:
+4. probs = probs ^ 1/choice_temperature
+5. set probs[already_uncovered_points] = inf, ie, we will always keep
+   uncovered points (no resampling!)
+6. Now pick top K points from probs to keep for the next steps, where
+   K = some monotonically increasing ratio of points as we go along decoding
+"""
+
+import dataclasses
+from typing import Literal
+
+from absl import logging
+from big_vision.models.proj.givt import givt
+import distrax
+import flax
+import jax
+import jax.numpy as jnp
+
+
+_CONFIDENCE_OF_KNOWN_TOKENS = jnp.inf
+
+
+@jax.vmap
+def _get_per_batch_mask(arr, k):
+  (d,) = arr.shape
+  indices = jnp.argsort(arr)
+  valid_indices = jnp.arange(d) < k
+  return jnp.zeros((d,), jnp.bool_).at[indices].set(valid_indices)
+
+
+def _get_bottom_k_mask(arr, k):
+  *leading, d = arr.shape
+  arr = arr.reshape((-1, d))
+  mask = _get_per_batch_mask(arr, k)
+  return mask.reshape(*leading, -1)
+
+
+def mask_by_random_topk(rng, mask_len, probs, temperature=1.0):
+  """Create a mask.
+
+  Adaptation of jax.random.choice where probabilities are changed by scaling
+  with `temperature` (probs = probs ^ (1/temperature)).
+
+  Additionally, this function returns a mask of tokens to mask out, which
+  are picked to be the low confidence ones. Thus, this function is roughly
+  equivalent to (but not exactly at edge cases such as prob = inf..):
+
+    keep = jax.random.choice(
+        rng, seq_len,
+        shape=(seq_len - mask_len,),
+        # NOTE: probabilities are updated with `temperature`.
+        p=jnp.power(probs, 1/temperature),
+        replace=False
+    )
+    mask = jnp.ones((seq_len,), dtype=jnp.bool_)
+    return mask.at[..., keep].set(False)
+
+  Args:
+    rng: a PRNG key used as the random key.
+    mask_len: the number to mask.
+    probs: the probabilities associated with each entry.
+    temperature: when temperature = 1.0, it's identical to jax's implementation.
+      The larger this value is, the more random the masking is picked.
+
+  Returns:
+    A binary masking map [batch_size, seq_len]. Contains True where we should
+    mask (at mask_len locations), and False where we should keep.
+  """
+  confidence = jnp.log(probs) + temperature * jax.random.gumbel(
+      rng, probs.shape)
+  return _get_bottom_k_mask(confidence, mask_len)
+
+
+@flax.struct.dataclass
+class DecodeState:
+  """Holds decoding state data."""
+
+  rng: jax.Array  # Sampling random state.
+  # The position of the decoding loop in the length dimension. Scalar int32.
+  step: jax.Array
+  # What we input at each step. Starts from all masks and is uncovered by
+  # sampling. Note that this is an array with leading
+  # dimension `num_steps + 1` because we start with all masked tokens and then
+  # need `num_steps` to uncover all, i.e., the final output is given by
+  # all_inputs_q[-1, ...].
+  all_inputs_q: jax.Array  # float32 [num_steps + 1, batch, seq_len, c]
+  # Has a 1 for every _uncovered_ point.
+  uncovered_per_step: jax.Array  # bool_ [num_steps, batch, seq_len]
+  logits_per_step: jax.Array  # [num_steps, batch, seq_len, num_logits]
+  uncond_logits_per_step: jax.Array  # [num_steps, batch, seq_len, num_logits]
+  prob_per_step: jax.Array  # Probability per step.
+  # If CFG: Rejection sampling success rate.
+  rejection_sampling_success_per_step: jax.Array
+
+  @classmethod
+  def make(
+      cls,
+      initial_rng: jax.Array,
+      all_masked_input: jax.Array,
+      num_logits: int,
+      num_steps: int,
+  ) -> "DecodeState":
+    """Creates the initial state."""
+    b, seq_len, c = all_masked_input.shape
+    all_inputs_q = jnp.broadcast_to(
+        all_masked_input,
+        (num_steps + 1, b, seq_len, c),
+    )
+    return cls(
+        initial_rng,
+        step=jnp.array(0),
+        all_inputs_q=all_inputs_q,
+        uncovered_per_step=jnp.full((num_steps, b, seq_len), False, jnp.bool_),
+        logits_per_step=jnp.full(
+            (num_steps, b, seq_len, num_logits), jnp.nan, jnp.float32
+        ),
+        uncond_logits_per_step=jnp.full(
+            (num_steps, b, seq_len, num_logits), jnp.nan, jnp.float32
+        ),
+        prob_per_step=jnp.full((num_steps, b, seq_len), jnp.nan, jnp.float32),
+        rejection_sampling_success_per_step=jnp.full(
+            (num_steps,), jnp.nan, jnp.float32
+        ),
+    )
+
+  @property
+  def current_inputs_q(self) -> jax.Array:
+    """Returns the current quantized input."""
+    return self.all_inputs_q[self.step, ...]
+
+  @property
+  def num_steps(self) -> int:
+    """Returns number of decode steps."""
+    return self.uncovered_per_step.shape[0]
+
+  def _steps_mask(self) -> jax.Array:
+    return jnp.arange(self.num_steps) <= self.step
+
+  @property
+  def total_uncovered(self) -> jax.Array:
+    """Returns the total uncovered mask up to and including current step."""
+    return self.uncovered_per_step.sum(
+        axis=0, where=self._steps_mask()[:, jnp.newaxis, jnp.newaxis]
+    ).astype(jnp.bool_)
+
+  def split_rng(self) -> tuple["DecodeState", jax.Array]:
+    """Splits off an RNG key for the current step."""
+    rng, step_rng = jax.random.split(self.rng, 2)
+    return self.replace(rng=rng), step_rng
+
+  def set_next_input(self, next_input_q: jax.Array) -> "DecodeState":
+    """Sets the input for the next step."""
+    return self._set_row("all_inputs_q", self.step + 1, next_input_q)
+
+  def set_uncover_at_current_step(self, uncovered: jax.Array) -> "DecodeState":
+    """Sets what was uncovered after the current step."""
+    return self._set_row("uncovered_per_step", self.step, uncovered)
+
+  def set_logits_at_current_step(self, logits: jax.Array) -> "DecodeState":
+    return self._set_row("logits_per_step", self.step, logits)
+
+  def set_uncond_logits_at_current_step(
+      self, logits: jax.Array
+  ) -> "DecodeState":
+    return self._set_row("uncond_logits_per_step", self.step, logits)
+
+  def set_rejection_sampling_success_at_current_step(
+      self, success: jax.Array
+  ) -> "DecodeState":
+    return self._set_row(
+        "rejection_sampling_success_per_step", self.step, success
+    )
+
+  def set_prob_at_current_step(self, prob: jax.Array) -> "DecodeState":
+    return self._set_row("prob_per_step", self.step, prob)
+
+  def increment_step(self) -> "DecodeState":
+    """Increments step."""
+    return self.replace(step=self.step + 1)
+
+  def _set_row(self, attr_name, row_index, row_value):
+    """Sets one row of the variables that have shape (num_steps, ...)."""
+    current_value = getattr(self, attr_name)
+    _, *expected_shape = current_value.shape
+    if row_value.shape != tuple(expected_shape):
+      raise ValueError(f"Expected {row_value.shape} == {expected_shape}!")
+    if row_value.dtype != current_value.dtype:
+      raise ValueError(f"Expected {row_value.dtype} == {current_value.dtype}")
+    new_value = current_value.at[row_index, ...].set(row_value)
+    return self.replace(**{attr_name: new_value})
+
+
+@dataclasses.dataclass(frozen=True)
+class MaskedGenerationConfig:
+  """Config for masked generation.
+
+  Attributes:
+    num_steps: Number of sampling steps.
+    should_anneal_temperature: If True, anneal choice temperature as we go
+      through the sampling steps.
+    choice_temperature: Temperature for picking points.
+    ordering: How to order to select. Supports:
+      maskgit: Maskgit style, use P[samples]
+    schedule: Inference mask schedule.
+    cfg_inference_weight: CFG Inference weight.
+  """
+  num_steps: int = 16
+  should_anneal_temperature: bool = True
+  choice_temperature: float = 1.0
+  ordering: Literal["maskgit"] = "maskgit"
+  schedule: str = "cosine"
+  cfg_inference_weight: float = 0.0
+
+
+def _assert_single_component_get_loc_scale(
+    pdf: distrax.Distribution, rng=None, mixture=None
+):
+  """Extracts loc and scale from a single mixture GMM."""
+  if not isinstance(pdf, distrax.MixtureSameFamily):
+    raise ValueError(f"Expected mixture! Got {type(pdf)}")
+  components_d = pdf.components_distribution
+  if isinstance(components_d, distrax.MultivariateNormalDiag):
+    loc, scale_diag = components_d.loc, components_d.scale_diag
+    b, s, m, _ = loc.shape
+    if mixture is None:
+      assert rng is not None
+      # Shape (b, seq)
+      mixture = pdf.mixture_distribution.sample(seed=rng)
+      mixture = jax.nn.one_hot(mixture, num_classes=m, axis=-1)
+    assert mixture.shape == (b, s, m), (mixture.shape, loc.shape)
+    loc = (loc * mixture[..., None]).sum(-2)
+    scale_diag = (scale_diag * mixture[..., None]).sum(-2)
+    return loc, scale_diag, mixture
+  else:
+    loc, scale = components_d.loc, components_d.scale
+    if loc.shape[-1] != 1 or scale.shape[-1] != 1:
+      raise ValueError(f"Expected one mixture! {loc.shape}/{scale.shape}")
+    return loc[..., 0], scale[..., 0], None
+
+
+class CFGDensity:
+  """Helper to get probability and samples via CFG."""
+
+  pdf_c: distrax.Distribution
+  pdf_u: distrax.Distribution
+  w: float
+  simple: distrax.Distribution
+  fac: jax.Array
+
+  def __init__(
+      self,
+      pdf_c: distrax.Distribution,
+      pdf_u: distrax.Distribution,
+      w: float,
+      rng: jax.Array,
+  ) -> None:
+    loc_c, scale_c, mixture = _assert_single_component_get_loc_scale(pdf_c, rng)
+    # Note: RNG only needed when we have mixtures, to select components.
+    loc_u, scale_u, _ = _assert_single_component_get_loc_scale(
+        pdf_u, rng, mixture=mixture
+    )
+
+    # Definitely wider than whatever we had before. The mean should be slightly
+    # away though!
+    loc_simple = loc_c
+    scale_simple = jnp.stack([scale_c, scale_u], -1).max(-1) * 2
+    self.simple = distrax.Normal(loc_simple, scale_simple)
+
+    self.pdf_c = distrax.Normal(loc_c, scale_c)
+    self.pdf_u = distrax.Normal(loc_u, scale_u)
+    self.w = w
+
+    assert loc_c.ndim == 3, loc_c.shape
+    points = loc_c[jnp.newaxis, ...] + jnp.linspace(-10, 10, 1001).reshape(
+        -1, 1, 1, 1
+    )
+    p_at_c, _ = self._unnormalized_p(points)
+
+    self.fac = jnp.max(p_at_c / self.simple.prob(loc_c), axis=0)
+    jax.debug.print("🎲 CFG {fac}", fac=self.fac.mean())
+
+  def _unnormalized_p(self, x):
+    w = self.w
+    logp_cfg = (1 + w) * self.pdf_c.log_prob(x) - w * self.pdf_u.log_prob(x)
+    return jnp.exp(logp_cfg), logp_cfg
+
+  def rejection_sample(
+      self,
+      seed: jax.Array,
+      max_samples: int = 1_000,
+  ) -> tuple[jax.Array, jax.Array]:
+    """Rejection sampling, try `max_samples`, take first match."""
+    rng_sample, rng_uni, rng_fallback = jax.random.split(seed, 3)
+    # Shape (max_samples, b, seq_len, c)
+    xs = self.simple.sample(seed=rng_sample, sample_shape=(max_samples,))
+    facq = self.fac * self.simple.prob(xs)
+    ys = jax.random.uniform(rng_uni, shape=facq.shape, minval=0.0, maxval=facq)
+    # Shape (max_samples, b, seq_len, c), True where `xs` is a valid sample
+    # from p. We might have anywhere between 0 and `max_samples` valid samples!
+    p, _ = self._unnormalized_p(xs)
+    mask = ys < p
+    # Now we need to do fancy tricks to get the first element in `mask` that is
+    # True. We do this by making a shifted mask that is False for every element
+    # after the first True.
+    # > Example:
+    #     mask          [0, 1, 0, 1, 0, 0, 1, 0]
+    # > implies:
+    #     cmask         [0, 1, 1, 1, 1, 1, 1, 1]
+    #     shifted_cmask [0, 0, 1, 1, 1, 1, 1, 1]
+    #     keep          [0, 1, 0, 0, 0, 0, 0, 0]  # <- picks the first valid!
+    cmask = jnp.cumsum(mask, axis=0).astype(jnp.bool_)
+    shifted_cmask = jnp.pad(
+        cmask, [(1, 0), (0, 0), (0, 0), (0, 0)], constant_values=False
+    )[:-1]
+    assert shifted_cmask.shape == mask.shape
+    keep = jnp.logical_and(cmask, jnp.logical_not(shifted_cmask))
+    # Now we can grab the first valid sample by doing a sum over the
+    # `max_samples` dimension.
+    sample = jnp.where(keep, xs, 0).sum(0)
+    # If the rejection sampler fails, we fall back to the conditional
+    # distribution, sampled with a dedicated key (JAX keys are single-use).
+    ok = mask.sum(0) > 0  # Shape (b, seq_len, c)
+    # jax.debug.print("🎲 CFG ok {ok}%", ok=ok.mean() * 100)
+    sample = jnp.where(
+        ok, sample, self.pdf_c.sample(seed=rng_fallback)
+    )
+    return sample, ok.mean() * 100
+
+  def sample(
+      self,
+      seed: jax.Array,
+      max_samples: int = 1_000,
+  ) -> jax.Array:
+    result, ok = self.rejection_sample(seed, max_samples)
+    jax.debug.print("Debug ok={ok}%", ok=ok)
+    return result
+
+  # Unnormalized! But we only use it for ordering.
+  def prob(self, xs: jax.Array) -> jax.Array:
+    p, _ = self._unnormalized_p(xs)
+    return p
+
+  def log_prob(self, xs: jax.Array) -> jax.Array:
+    _, lp = self._unnormalized_p(xs)
+    return lp
+
+
+def decode_masked(
+    rng: jax.Array,
+    labels: jax.Array,
+    seq_len: int,
+    feature_dim: int,
+    model: givt.Model,
+    variables: flax.core.FrozenDict,
+    config: MaskedGenerationConfig,
+) -> DecodeState:
+  """Implements a masked bidirectional sampling loop.
+
+  This function implements the loop from the docstring.
+
+  Args:
+    rng: RNG, only required if sampling.
+    labels: Shape (b,), labels per batch. Determines batch size.
+    seq_len: How many tokens to sample per batch.
+    feature_dim: Output dimension of the VAE, i.e., number of channels, `c`.
+    model: GIVT model to sample from.
+    variables: Variables of the model.
+    config: Configures style.
+
+  Returns:
+    Final state.
+  """
+  logging.info("Masked Generation Config:\n%s", config)
+
+  if model.style != "masked":
+    raise ValueError(f"Need masked model! Got `{model.style}`.")
+
+  (b,) = labels.shape
+  all_masked_input = jnp.zeros((b, seq_len, feature_dim))
+  init_state = DecodeState.make(
+      rng,
+      all_masked_input,
+      num_logits=model.num_logits,
+      num_steps=config.num_steps,
+  )
+
+  def loop_cond_fn(state: DecodeState):
+    return state.step < state.num_steps
+
+  def tokens_to_logits(tokens, input_mask, drop_labels=None):
+    return model.apply(
+        variables,
+        tokens,
+        labels=labels,
+        # Note that the model applies the mask token internally given the input.
+        input_mask=input_mask,
+        drop_labels=drop_labels,
+        method="decode",
+    )
+
+  def loop_body_fn(state: DecodeState) -> DecodeState:
+    # 1 where we should mask, cumulative.
+    unknown = jnp.logical_not(state.total_uncovered)
+
+    # Defines the mask ratio for the next round. The number to mask out is
+    # determined by mask_ratio * unknown_number_in_the_beginning.
+    ratio = (state.step + 1) / config.num_steps
+    # Note that the mask schedule inverts the function, so `mask_ratio` starts
+    # near 1 and goes to 0 monotonically.
+    mask_ratio = givt.apply_mask_schedule(ratio, method=config.schedule)
+    mask_len = jnp.floor(seq_len * mask_ratio).reshape(1, 1)
+    num_unknown = jnp.sum(unknown, axis=-1, keepdims=True)
+    mask_len = jnp.maximum(
+        0,
+        # Keeps at least one prediction in this round: Avoids the case where
+        # mask_len is equal to num_unknown, in which case the mask is not
+        # updated! We subtract 1 to always remove at least one masked token.
+        jnp.minimum(num_unknown - 1, mask_len))
+
+    # Run model ---
+    logits = tokens_to_logits(state.current_inputs_q, unknown)
+    # Book keeping: store all logits.
+    state = state.set_logits_at_current_step(logits)
+
+    pdf = model.get_pdf(logits)
+    state, sample_rng = state.split_rng()
+    if config.cfg_inference_weight > 0:
+      drop_all_labels = jnp.full((b,), True, jnp.bool_)
+      logits_uncond = tokens_to_logits(
+          state.current_inputs_q, unknown, drop_labels=drop_all_labels
+      )
+      state = state.set_uncond_logits_at_current_step(logits_uncond)
+      pdf_uncond = model.get_pdf(logits_uncond)
+      state, cfg_rng = state.split_rng()
+      pdf = CFGDensity(
+          pdf_c=pdf,
+          pdf_u=pdf_uncond,
+          w=config.cfg_inference_weight,
+          rng=cfg_rng,
+      )
+      sample, rejection_sampling_success = pdf.rejection_sample(sample_rng)
+      state = state.set_rejection_sampling_success_at_current_step(
+          rejection_sampling_success
+      )
+    else:
+      sample = pdf.sample(seed=sample_rng)
+
+    # Sample at the unknown spots.
+    sampled = jnp.where(unknown[:, :, None], sample, state.current_inputs_q)
+    assert sampled.shape == (b, seq_len, feature_dim), (
+        sampled.shape,
+        b,
+        seq_len,
+        feature_dim,
+    )
+
+    prob = pdf.prob(sampled)
+    if model.multivariate:
+      assert prob.ndim == 2  # (b, seq_len) already
+    elif model.per_channel_mixtures or config.cfg_inference_weight > 0:
+      # Independence across channels.
+      # This reduction is also required when using CFG and also
+      # `model.per_channel_mixtures == False` due to the 2-step CFG redefining
+      # the pdf, but the reduction is not needed without CFG.
+      prob = prob.prod(-1)
+    state = state.set_prob_at_current_step(prob)
+
+    if config.ordering == "maskgit":
+      ordering = jnp.where(unknown, prob, _CONFIDENCE_OF_KNOWN_TOKENS)
+    else:
+      raise NotImplementedError(config.ordering)
+
+    assert ordering.shape == (b, seq_len), (ordering.shape, b, seq_len)
+
+    temp = config.choice_temperature
+    if config.should_anneal_temperature:
+      temp *= (1. - ratio)
+
+    # True where we should mask input. Note that this is cumulative (ie this
+    # starts with all True and keeps getting more False entries as we go through
+    # the steps).
+    state, choice_rng = state.split_rng()
+    masking = mask_by_random_topk(choice_rng, mask_len, ordering, temp)
+    assert masking.shape == (b, seq_len)
+    masking = jnp.where(mask_len == 0, jnp.zeros_like(masking), masking)
+
+    # Remove the masked tokens from the sampled array for safety (the model will
+    # again apply the mask anyway...).
+    sampled = jnp.where(masking[:, :, None], jnp.zeros_like(sampled), sampled)
+
+    # Get next_uncover ---
+    # New tokens to uncover (non cumulative): where it was unknown
+    # but is now known.
+    next_uncover = jnp.logical_and(unknown, jnp.logical_not(masking))
+    assert next_uncover.shape == (b, seq_len), (next_uncover.shape, b, seq_len)
+    state = state.set_uncover_at_current_step(next_uncover)
+    state = state.set_next_input(sampled)
+    return state.increment_step()
+
+  return jax.lax.while_loop(loop_cond_fn, loop_body_fn, init_state)
diff --git a/big_vision/models/proj/givt/parallel_decode_test.py b/big_vision/models/proj/givt/parallel_decode_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c03effc5f944c99db7ec3e392abe255bb8583eb9
--- /dev/null
+++ b/big_vision/models/proj/givt/parallel_decode_test.py
@@ -0,0 +1,154 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import parameterized
+from big_vision.models.proj.givt import givt
+from big_vision.models.proj.givt import parallel_decode
+import chex
+import jax
+import jax.numpy as jnp
+
+from absl.testing import absltest
+
+
+_BATCH_SIZE = 2
+_OUT_DIM = 4
+_SEQ_LEN = 6
+_NUM_MIXTURES = 4
+
+
+def _make_test_model(**overwrites):
+  config = dict(
+      num_heads=2,
+      num_decoder_layers=1,
+      mlp_dim=64,
+      emb_dim=16,
+      seq_len=_SEQ_LEN,
+      out_dim=_OUT_DIM,
+      num_mixtures=_NUM_MIXTURES,
+      style="masked",
+  )
+  config.update(overwrites)
+  return givt.Model(**config)
+
+
+def _mask(*flags):
+  return jnp.asarray(flags).astype(jnp.bool_)
+
+
+class HelperTest(absltest.TestCase):
+
+  def test_get_first_n(self):
+    with self.subTest("ordered"):
+      values = jnp.asarray([4, 3, 2, 1, 0])
+      k = jnp.asarray([3], jnp.int32)
+      chex.assert_trees_all_equal(
+          parallel_decode._get_bottom_k_mask(values, k), _mask(0, 0, 1, 1, 1)
+      )
+
+    with self.subTest("equal_values"):
+      values = jnp.ones((5,))
+      k = jnp.asarray([3], jnp.int32)
+      chex.assert_trees_all_equal(
+          parallel_decode._get_bottom_k_mask(values, k), _mask(1, 1, 1, 0, 0)
+      )
+
+    with self.subTest("partially_equal_values"):
+      values = jnp.asarray([1, 2, 2, 2, 3])
+      k = jnp.asarray([3], jnp.int32)
+      chex.assert_trees_all_equal(
+          parallel_decode._get_bottom_k_mask(values, k), _mask(1, 1, 1, 0, 0)
+      )
+
+
+class ParallelDecodeTest(parameterized.TestCase):
+
+  def _make_model(self, **overwrites):
+    model = _make_test_model(**overwrites)
+    sequence = jax.random.uniform(
+        jax.random.PRNGKey(0), (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM)
+    )
+    labels = jax.random.uniform(
+        jax.random.PRNGKey(0), (_BATCH_SIZE,), maxval=10
+    ).astype(jnp.int32)
+    input_mask = jax.random.uniform(
+        jax.random.PRNGKey(0), (_BATCH_SIZE, _SEQ_LEN)
+    ).astype(jnp.bool_)
+    variables = model.init(
+        jax.random.PRNGKey(0),
+        sequence,
+        labels,
+        input_mask=input_mask,
+        train=False,
+    )
+    return model, variables
+
+  def _test_model(self, rng, model, variables, config):
+    labels = jnp.ones((_BATCH_SIZE,), dtype=jnp.int32)
+    state = parallel_decode.decode_masked(
+        rng,
+        seq_len=_SEQ_LEN,
+        feature_dim=_OUT_DIM,
+        labels=labels,
+        model=model,
+        variables=variables,
+        config=config,
+    )
+    self.assertEqual(int(state.step), config.num_steps)
+    # Each point uncovered exactly once.
+    chex.assert_trees_all_equal(
+        state.uncovered_per_step.sum(0),
+        jnp.ones((_BATCH_SIZE, _SEQ_LEN), dtype=jnp.int32),
+    )
+
+  @parameterized.product(
+      rng_seed=[1, 2],
+      choice_temperature=[1.0, 4.0],
+      multivariate=[True, False],
+  )
+  def test_decode_masked(self, rng_seed, choice_temperature, multivariate):
+    rng = jax.random.PRNGKey(rng_seed)
+    model, variables = self._make_model(
+        num_mixtures=1 if multivariate else _NUM_MIXTURES,
+        multivariate=multivariate,
+    )
+    config = parallel_decode.MaskedGenerationConfig(
+        num_steps=4,
+        choice_temperature=choice_temperature,
+    )
+    self._test_model(rng, model, variables, config)
+
+  @parameterized.product(
+      rng_seed=[1, 2],
+      choice_temperature=[1.0, 4.0],
+      w=[0.0, 1.0, 3.0],
+      per_channel_mixtures=[True, False],
+  )
+  def test_cfg(self, rng_seed, choice_temperature, w, per_channel_mixtures):
+    rng = jax.random.PRNGKey(rng_seed)
+    model, variables = self._make_model(
+        num_mixtures=1 if per_channel_mixtures else 3,
+        drop_labels_probability=0.1,
+        per_channel_mixtures=per_channel_mixtures,
+    )
+    config = parallel_decode.MaskedGenerationConfig(
+        num_steps=4,
+        choice_temperature=choice_temperature,
+        cfg_inference_weight=w,
+    )
+    self._test_model(rng, model, variables, config)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/big_vision/models/proj/givt/vae.py b/big_vision/models/proj/givt/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..db1ab66721e0fc6131f2f98a2cbbddad9d66007c
--- /dev/null
+++ b/big_vision/models/proj/givt/vae.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Abstract VAE model class.
+
+Gaussian encoder and decoder (the latter assumed to have constant variance).
+
+Inspiration drawn from https://github.com/pytorch/examples/tree/main/vae.
+"""
+
+import abc
+from typing import Optional, Mapping
+
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+
+
+class Model(nn.Module, metaclass=abc.ABCMeta):
+  """Abstract VAE model class."""
+
+  codeword_dim: Optional[int] = None
+  code_len: int = 256
+  code_dropout: str = "none"
+
+  @abc.abstractmethod
+  def encode(
+      self,
+      x: jax.Array,
+      *,
+      train: bool = False,
+  ) -> tuple[jax.Array, jax.Array]:
+    ...
+
+  def reparametrize(
+      self,
+      mu: jax.Array,
+      logvar: jax.Array,
+      rng: jax.Array | None = None,
+  ) -> jax.Array:
+    std = jnp.exp(0.5 * logvar)
+    if rng is None:
+      rng = self.make_rng("dropout")
+    eps = jax.random.normal(rng, shape=std.shape, dtype=std.dtype)
+    return mu + std * eps
+
+  @abc.abstractmethod
+  def decode(
+      self, x: jax.Array,
+      train: bool = False,
+  ) -> jax.Array | Mapping[str, jax.Array]:
+    ...
+
+  def code_dropout_fn(self, z: jax.Array, *, train: bool = False) -> jax.Array:
+    # "seq" drops out tokens later in the sequence with higher probability than
+    # tokens earlier in the sequence.
+    assert self.code_dropout in ["none", "seq", "random"]
+    if train and self.code_dropout != "none":
+      importance = jnp.linspace(1.0, 0.0, self.code_len + 2)[1:-1]
+      thr = jax.random.uniform(self.make_rng("dropout"), z.shape[:1])
+      mask = importance[None, :] > thr[:, None]
+      if self.code_dropout == "random":
+        mask = jax.random.permutation(
+            self.make_rng("dropout"), mask, axis=-1, independent=True)
+      z = z * mask[:, :, None]
+    return z
+
+  def __call__(
+      self,
+      x: jax.Array,
+      *,
+      train: bool = False,
+  ) -> tuple[jax.Array | Mapping[str, jax.Array], Mapping[str, jax.Array]]:
+    mu, logvar = self.encode(x, train=train)
+    # Only reparametrize when training for simplicity.
+    if train:
+      z = self.reparametrize(mu, logvar)
+    else:
+      z = mu
+    z = self.code_dropout_fn(z, train=train)
+    x = self.decode(z, train=train)
+    return x, {"mu": mu, "logvar": logvar, "z": z}
diff --git a/big_vision/models/proj/givt/vit.py b/big_vision/models/proj/givt/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ed36ef069db09868a8b3be97ce6cb607b7e4053
--- /dev/null
+++ b/big_vision/models/proj/givt/vit.py
@@ -0,0 +1,188 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple VAE fork of the UViM VQ-VAE (proj/uvim/vit.py) with small changes."""
+
+from typing import Optional, Sequence, Mapping, Any
+
+from big_vision import utils
+from big_vision.models import common
+from big_vision.models import vit
+from big_vision.models.proj.givt import vae
+
+import einops
+import flax.linen as nn
+import flax.training.checkpoints
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+class Model(vae.Model):
+  """ViT model."""
+
+  input_size: Sequence[int] = (256, 256)
+  patch_size: Sequence[int] = (16, 16)
+  width: int = 768
+  enc_depth: int = 6
+  dec_depth: int = 6
+  mlp_dim: Optional[int] = None
+  num_heads: int = 12
+  posemb: str = "learn"  # Can also be "sincos2d"
+  dropout: float = 0.0
+  head_zeroinit: bool = True
+  bottleneck_resize: bool = False
+  inout_specs: Optional[Mapping[str, tuple[int, int]]] = None
+  scan: bool = False
+  remat_policy: str = "nothing_saveable"
+
+  def setup(self) -> None:
+    self.grid_size = np.array(self.input_size) // np.array(self.patch_size)
+
+    self.embedding = nn.Conv(
+        self.width, self.patch_size, strides=self.patch_size,
+        padding="VALID", name="embedding")
+
+    self.pos_embedding_encoder = vit.get_posemb(
+        self, self.posemb, self.grid_size, self.width, "pos_embedding_encoder")
+    self.encoder = vit.Encoder(
+        depth=self.enc_depth,
+        mlp_dim=self.mlp_dim,
+        num_heads=self.num_heads,
+        dropout=self.dropout,
+        scan=self.scan,
+        remat_policy=self.remat_policy,
+        name="encoder")
+
+    if not self.bottleneck_resize:
+      self.bottleneck_downsample = self.param(
+          "bottleneck_downsample",
+          nn.initializers.xavier_uniform(),
+          (np.prod(self.grid_size), self.code_len))
+
+    if not self.bottleneck_resize:
+      self.bottleneck_upsample = self.param(
+          "bottleneck_upsample",
+          nn.initializers.xavier_uniform(),
+          (self.code_len, np.prod(self.grid_size)))
+
+    self.pos_embedding_decoder = vit.get_posemb(
+        self, self.posemb, self.grid_size, self.width, "pos_embedding_decoder")
+    self.decoder = vit.Encoder(
+        depth=self.dec_depth,
+        mlp_dim=self.mlp_dim,
+        num_heads=self.num_heads,
+        dropout=self.dropout,
+        scan=self.scan,
+        remat_policy=self.remat_policy,
+        name="decoder")
+
+    # Output 2 * codeword_dim features (mean and log-variance per element);
+    # fall back to `width` when `codeword_dim` is unset (None).
+    self.encoder_head = nn.Dense(2 * (self.codeword_dim or self.width))
+    self.decoder_stem = nn.Dense(self.width)
+
+    kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
+
+    if self.inout_specs is not None:
+      num_out_channels = sum(
+          num_classes for _, num_classes in self.inout_specs.values())
+    else:
+      num_out_channels = 3
+
+    self.head = nn.Dense(
+        num_out_channels * np.prod(self.patch_size),
+        name="decoder_head", **kw)
+
+  def encode(
+      self,
+      x: jax.Array,
+      *,
+      train: bool = False,
+  ) -> tuple[jax.Array, jax.Array]:
+    if self.inout_specs is not None:
+      one_hot_inputs = []
+      for in_ch, num_classes in self.inout_specs.values():
+        one_hot_inputs.append(nn.one_hot(x[..., in_ch], num_classes))
+      x = jnp.concatenate(one_hot_inputs, axis=-1)
+    x = self.embedding(x)
+    x = einops.rearrange(x, "b h w c -> b (h w) c")
+
+    x, _ = self.encoder(x + self.pos_embedding_encoder, deterministic=not train)
+
+    if self.bottleneck_resize:
+      x = einops.rearrange(x, "b (h w) c -> b h w c",
+                           h=self.grid_size[0], w=self.grid_size[1])
+      l = int(np.round(self.code_len ** 0.5))
+      x = jax.image.resize(
+          x, (x.shape[0], l, l, x.shape[3]),
+          method="linear")
+      x = einops.rearrange(x, "b h w c -> b (h w) c")
+    else:
+      x = jnp.einsum("btc,tn->bnc", x, self.bottleneck_downsample)
+
+    x = self.encoder_head(x)
+
+    mu, logvar = jnp.split(x, 2, axis=-1)
+    return mu, logvar
+
+  def decode(
+      self,
+      x: jax.Array,
+      train: bool = False,
+  ) -> jax.Array | Mapping[str, jax.Array]:
+    x = self.decoder_stem(x)
+
+    if self.bottleneck_resize:
+      l = int(np.round(self.code_len ** 0.5))
+      x = einops.rearrange(x, "b (h w) c -> b h w c", h=l, w=l)
+      x = jax.image.resize(
+          x, (x.shape[0], self.grid_size[0], self.grid_size[1], x.shape[3]),
+          method="linear")
+      x = einops.rearrange(x, "b h w c -> b (h w) c")
+    else:
+      x = jnp.einsum("bnc,nt->btc", x, self.bottleneck_upsample)
+
+    x, _ = self.decoder(x + self.pos_embedding_decoder, deterministic=not train)
+    x = self.head(x)
+    # c = 3 for RGB images
+    x = einops.rearrange(x, "b (h w) (p q c) -> b (h p) (w q) c",
+                         h=self.grid_size[0], w=self.grid_size[1],
+                         p=self.patch_size[0], q=self.patch_size[1])
+
+    if self.inout_specs is None:
+      x = jnp.clip(x, -1.0, 1.0)
+    else:
+      x_dict = {}
+      channel_index = 0
+      for name, (_, num_channels) in self.inout_specs.items():
+        x_dict[name] = x[..., channel_index : channel_index + num_channels]
+        channel_index += num_channels
+      x = x_dict
+
+    return x
+
+
+def load(
+    init_params: Any,
+    init_file: str,
+    model_params: Any = None,
+    dont_load: Sequence[str] = (),
+) -> Any:
+  """Loads params from init checkpoint and merges into init_params."""
+  del model_params
+  params = flax.core.unfreeze(utils.load_params(init_file))
+  if init_params is not None:
+    params = common.merge_params(params, init_params, dont_load)
+  return params
diff --git a/big_vision/models/proj/image_text/text_transformer.py b/big_vision/models/proj/image_text/text_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..696eab5b4f2bf6e37a50ea3b2b02688a860bb254
--- /dev/null
+++ b/big_vision/models/proj/image_text/text_transformer.py
@@ -0,0 +1,119 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer encoders for text, similar to CLIP.""" + +from typing import Any + +from big_vision import utils +from big_vision.models import common +from big_vision.models import vit +import flax.linen as nn +import flax.training.checkpoints +import numpy as np + +ConfigDict = Any + + +class _Model(nn.Module): + """Text transformer similar to CLIP.""" + + # Differences to CLIP text encoder (gpt-2) that I am aware of: + # 1. https://imgur.com/HNi3jix (gpt-1) + # 2. https://imgur.com/qKGZgBR (gpt-2) + # 3. https://imgur.com/a/xrpYHF0 (clip) + # - LayerNorm is on res-path (like pre-activation resnet) + # - dropout 0.1 everywhere + # - init as var=0.02, scaled by depth + # - BOS and EOS tokens, take repr from EOS. + # - self-attention is autoregressively masked. + # - scaled in width only, with the image model. + + num_classes: int + width: int = 512 + depth: int = 12 + mlp_dim: int = 2048 + num_heads: int = 8 + dropout: float = 0.0 + vocab_size: int = 32_000 + pool_type: str = "last" + scan: bool = False + remat_policy: str = "nothing_saveable" + + @nn.compact + def __call__(self, text, *, train=False): + out = {} + + # We can't use where/argwhere since the output shape is not fixed. + # Here we use the fact that sequences are padded with EOS tokens, that the + # EOS token has value 1, and that argmin returns the first index. 
+ # eos_indices = jnp.argmin(text, axis=1) + + embedding = nn.Embed(num_embeddings=self.vocab_size, features=self.width) + x = out["embedded"] = embedding(text) + + # Add posemb + n, l, d = x.shape # pylint: disable=unused-variable + x = x + self.param("pos_embedding", + nn.initializers.normal(stddev=1/np.sqrt(d)), + (1, l, d), x.dtype) + + x, encoder_out = vit.Encoder( + depth=self.depth, mlp_dim=self.mlp_dim, num_heads=self.num_heads, + scan=self.scan, remat_policy=self.remat_policy, dropout=self.dropout)( + x, deterministic=not train) + + out.update({"transformed": x, **encoder_out}) + + # Share weights between embeddings and logit transformation. + out["vocab_logits"] = embedding.attend(x) + + if self.pool_type == "last": + # Assuming "sticky" EOS tokenization, last token is always EOS. + x = out["pre_logits"] = x[:, -1, :] + elif self.pool_type == "first": + x = out["pre_logits"] = x[:, 0, :] + elif self.pool_type in ("mean", "gap"): + x = out["pre_logits"] = x.mean(axis=1) + elif self.pool_type in ("max", "gmp"): + x = out["pre_logits"] = x.max(axis=1) + elif self.pool_type == "map": + x = out["pre_logits"] = vit.MAPHead( + num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) + else: + raise NotImplementedError(f"Cannot do pooling '{self.pool_type}'") + + if self.num_classes: + x = out["logits"] = nn.Dense(self.num_classes, name="head")(x) + return x, out + + +def Model(num_classes, *, variant=None, **kw): # pylint: disable=invalid-name + """Factory function, because linen really don't like what I'm doing!""" + return _Model(num_classes, **{**vit.decode_variant(variant), **kw}) + + +def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=invalid-name + """Load init from checkpoint, both old model and this one. 
+Hi-res posemb.""" + del model_cfg # unused + params = utils.load_params(init_file) + params = flax.core.unfreeze( + flax.training.checkpoints.convert_pre_linen(params)) + + # Some older (but expensive to train) checkpoints had the posemb added twice + # by mistake. We detect this here and merge them. + extra_posemb = params["Encoder_0"].pop("pos_embedding", 0) + params["pos_embedding"] += extra_posemb + + return common.merge_params(params, init_params, dont_load) diff --git a/big_vision/models/proj/image_text/two_towers.py b/big_vision/models/proj/image_text/two_towers.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a045eaeb72383795bf903b51d89a6150bc3a16 --- /dev/null +++ b/big_vision/models/proj/image_text/two_towers.py @@ -0,0 +1,154 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Transformer encoders both for text and for images.""" + +import importlib +from typing import Any, Optional, Tuple, Union +from absl import logging + +from big_vision import utils +import flax.linen as nn +import jax.numpy as jnp + +ConfigDict = Any + + +class Model(nn.Module): + """Two towers transformer.""" + image: Optional[ConfigDict] = None + text: Optional[ConfigDict] = None + text_model: str = "proj.image_text.text_transformer" + image_model: str = "vit" + out_dim: Union[int, Tuple[int, int]] = 128 + temperature_init: float = 1.0 + bias_init: Optional[float] = None + + @nn.compact + def __call__(self, image, text=None, **kw): + """Returns (B,C) image and (B,C) text representations.""" + + # Support calling without text or without image, for example for few-shot. + ztxt, zimg = None, None + out = {} + out_dims = self.out_dim + if isinstance(out_dims, int): + out_dims = (out_dims, out_dims) + + # Embed the text: + if text is not None: + text_model = importlib.import_module( + f"big_vision.models.{self.text_model}" + ).Model(**{"num_classes": out_dims[1], **(self.text or {})}, name="txt") + + ztxt, out_txt = text_model(text, **kw) + for k, v in out_txt.items(): + out[f"txt/{k}"] = v + + # Normalize the embeddings the models give us. + out["txt/norm"] = jnp.linalg.norm(ztxt, axis=1, keepdims=True) + out["txt/normalized"] = ztxt = ztxt / (out["txt/norm"] + 1e-8) + + if image is not None: + image_model = importlib.import_module( + f"big_vision.models.{self.image_model}" + ).Model(**{"num_classes": out_dims[0], **(self.image or {})}, name="img") # pylint: disable=not-a-mapping + + zimg, out_img = image_model(image, **kw) + for k, v in out_img.items(): + out[f"img/{k}"] = v + + # Normalize the embeddings the models give us. 
+ out["img/norm"] = jnp.linalg.norm(zimg, axis=1, keepdims=True) + out["img/normalized"] = zimg = zimg / (out["img/norm"] + 1e-8) + + temp_init = jnp.log(self.temperature_init) + t = self.param("t", + lambda key, shape, dtype: temp_init * jnp.ones(shape, dtype), + (1,), jnp.float32) + out["t"] = jnp.exp(t) + + out["t/parameter"] = t + if (b_init := self.bias_init) is not None: + out["b"] = self.param("b", lambda k, s, d: b_init * jnp.ones(s, d), + (1,), jnp.float32) + + # We could actually play with pre-multiplying by temperature here, such + # that out["t"] is nothing special to the trainer anymore. + + return zimg, ztxt, out + + +def load(init_params, init_files, model_cfg, img_load_kw={}, txt_load_kw={}): # pylint: disable=dangerous-default-value + """Loads both towers, `init_files` is now a dict with `img` and `txt` keys.""" + if isinstance(init_files, str): + init_files = VANITY_NAMES.get(init_files, init_files) + + if isinstance(init_files, str): + # A shortcut for a single file checkpoint of a two_towers model. + if "bias_init" in model_cfg.keys(): + logging.info("loading img, txt, t, and b from a single checkpoint.") + init_files = {k: f"{init_files}:{k}" for k in ("img", "txt", "t", "b")} + else: + logging.info("loading img, txt, and t from a single checkpoint.") + init_files = {k: f"{init_files}:{k}" for k in ("img", "txt", "t")} + else: + init_files = {**init_files} # Shallow copy because we'll pop stuff off. + + if not init_params: # Convenience to skip checks in colab. 
+ init_params = {"img": None, "txt": None} + restored_params = {**init_params} + + img_init = init_files.pop("image", init_files.pop("img", None)) + if img_init: + restored_params["img"] = importlib.import_module( + f"big_vision.models.{model_cfg.get('image_model', 'vit')}" + ).load(init_params["img"], img_init, model_cfg.image, **img_load_kw) + + txt_init = init_files.pop("text", init_files.pop("txt", None)) + if txt_init: + restored_params["txt"] = importlib.import_module( + f"big_vision.models.{model_cfg.get('text_model', 'proj.image_text.text_transformer')}" # pylint: disable=line-too-long + ).load(init_params["txt"], txt_init, model_cfg.text, **txt_load_kw) + + t_init = init_files.pop("temperature", init_files.pop("t", None)) + if t_init: + restored_params["t"] = utils.load_params(t_init) + + b_init = init_files.pop("bias", init_files.pop("b", None)) + if b_init: + restored_params["b"] = utils.load_params(b_init) + + assert not init_files, ( + f"There's something unused left in `config.model_init`. You probably got " + f"a typo. 
Here it is: {init_files}") + + return restored_params + + +# Shortcut names for some canonical paper checkpoints: +VANITY_NAMES = { + # pylint: disable=line-too-long + # SigLIP image encoder checkpoints from https://arxiv.org/abs/2303.15343 + "SigLIP B/16 224": "gs://big_vision/siglip/webli_en_b16_224_63724782.npz", + "SigLIP B/16 256": "gs://big_vision/siglip/webli_en_b16_256_60500360.npz", + "SigLIP B/16 384": "gs://big_vision/siglip/webli_en_b16_384_68578854.npz", + "SigLIP B/16 512": "gs://big_vision/siglip/webli_en_b16_512_68580893.npz", + "SigLIP L/16 256": "gs://big_vision/siglip/webli_en_l16_256_60552751.npz", + "SigLIP L/16 384": "gs://big_vision/siglip/webli_en_l16_384_63634585.npz", + "SigLIP So400m/14 224": "gs://big_vision/siglip/webli_en_so400m_224_57633886.npz", + "SigLIP So400m/14 384": "gs://big_vision/siglip/webli_en_so400m_384_58765454.npz", + "SigLIP B/16-i18n 256": "gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz", + # pylint: enable=line-too-long +} diff --git a/big_vision/models/proj/paligemma/__pycache__/gemma_bv.cpython-310.pyc b/big_vision/models/proj/paligemma/__pycache__/gemma_bv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dece1ba5ca107ce8e1382cd3a4acf8f3906a3fdd Binary files /dev/null and b/big_vision/models/proj/paligemma/__pycache__/gemma_bv.cpython-310.pyc differ diff --git a/big_vision/models/proj/paligemma/__pycache__/paligemma.cpython-310.pyc b/big_vision/models/proj/paligemma/__pycache__/paligemma.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee61aa3134e2da9be775dc80e22fd2644b0f9647 Binary files /dev/null and b/big_vision/models/proj/paligemma/__pycache__/paligemma.cpython-310.pyc differ diff --git a/big_vision/models/proj/paligemma/gemma_bv.py b/big_vision/models/proj/paligemma/gemma_bv.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc01e20c50454991e6ffb2c5399e96e24e3f346 --- /dev/null +++ 
b/big_vision/models/proj/paligemma/gemma_bv.py @@ -0,0 +1,182 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gemma wrapper to make it work for us.""" + +from big_vision.models.ppp import gemma +import flax.linen as nn +import jax +import jax.numpy as jnp + + +def _get_config(model): + config = gemma.get_config(model.variant) + config.scan = model.scan + config.remat_policy = model.remat_policy + if model.vocab_size is not None: + config.vocab_size = model.vocab_size + config.dropout = model.dropout + config.dropout_bdims = model.dropout_bdims + config.cache_dtype = model.cache_dtype + return config + + +@jax.vmap +def _left_to_right_align(x, input_mask, attn_mask): + """Converts input from left-align to right-aligned.""" + # Due to vmap, this is operating in a single example (not batch level). + assert x.ndim == 2 and input_mask.ndim == 1 and attn_mask.ndim == 2 + assert x.shape[0] == input_mask.shape[0] + assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape + seqlen = jnp.sum(input_mask) + x = jnp.roll(x, -seqlen, axis=0) + input_mask = jnp.roll(input_mask, -seqlen, axis=0) + attn_mask = jnp.roll(attn_mask, -seqlen, axis=(0, 1)) + return x, input_mask, attn_mask + + +class Model(nn.Module): + """Wrapping gemma big_vision model.""" + variant: str = "gemma_2b" + scan: bool = True + remat_policy: str = "nothing_saveable" + vocab_size: int | None = None + + dropout: float = 0.0 + dropout_bdims: tuple[int, ...] 
= () # Every float is dropped independently. + cache_dtype: str | None = "bfloat16" # bfloat16 to save memory and transfers. + + def setup(self): + # The parent+name avoids an unnecessary nesting in params pytree. + self.model = gemma.Model(**_get_config(self), parent=self.scope, name="") + + def embed_tokens(self, tokens, train=False): + # Turns int32[B,T] tokens into float32[B,T,d_model] embeddings. + # Really just the vocab embedding. + return self.model(tokens, embed_only=True, deterministic=not train) + + def compute_logits(self, pre_logits, train=False): + return self.model(None, pre_logits=pre_logits, deterministic=not train)[0] + + def __call__(self, embs, mask=None, train=False): + # Turns float32[B,T,d_model] embedding sequence to logits. + # call(emb_tokens(tokens)) should be a forward pass. + # Allow for specifying int32[B,T,T] attention masks. For convenience + # default to triangular autorgressive mask when None, but not P0. + # Return float32[B,T,vocab_size] logits and out-dict. + + batch_size, _, d_model = embs.shape + assert d_model == self.embdim + logits, out = self.model( + tokens=jnp.zeros([batch_size, 0], dtype=jnp.int32), + embedded_prefix=embs, + mask=mask, + deterministic=not train, + ) + return logits, out + + def prefill_cache(self, x, input_mask, attn_mask, *, cache_size): + """Initializes decoding cache with `x` [B, N, E] as prompt. + + IMPORTANT: Inputs MUST be left-aligned and attn_mask should not allow + input tokens to attend to padding tokens. + + TODO: Relax left-align requirement by converting any input into + a right aligned input with no attention to padding tokens. + + Args: + x: float[B, N, E] with prompt tokens. + input_mask: bool[B, N]. True indicates tokens are part of the prompt. + False indicates padding tokens. This class doesn't combine this with + attn_mask, so mask out the attention to padding tokens beforehand. + attn_mask: bool[B, N, N]. 
Indicates which tokens can attend to which while + processing the prompt tokens. During extend_cache tokens, it is assumed + that tokens can attend all previous valid tokens. + cache_size: int. Indicates the size of the cache. The prompt will consume + the first N entries of the cache. Each subsequent extend_cache will + consume one entry. Behaviour is undefined when prefill_len plus number + of extend_cache exceeds the cache_size. + + Returns: + logits of the last valid token (i.e. last logits where input_mask=True). + """ + # To call the model with decode=True we need to be able to provide: + # (a) positions of tokens [B, N], ([B, 1] for extend) + # (b) attention mask [B, N, cache_size] ([B, 1, cache_size] for extend) + # + # To do so we track how many tokens each example has seen so far, and we + # align the prompt to the right so that cache usage for each example is in + # a continuous subsequent of (cache_begin, cache_end] such that cache_end + # is the same for all sequences (this allows to do faster row updates of + # the cache during decoding). + x, input_mask, attn_mask = _left_to_right_align(x, input_mask, attn_mask) + + # Track sequence len + seq_len = jnp.sum(input_mask, axis=-1) + self.put_variable("cache", "seq_len", seq_len) + positions = jnp.cumsum(input_mask, axis=-1) - 1 + + # Initialize cache_begin and cache_end. Note: cache_end is the same for all + # sequences but we keep it per example to allow easy sharding rules with + # batch as the first axis. + batch_size, prefill_len, _ = x.shape + self.put_variable("cache", "cache_begin", prefill_len - seq_len) + self.put_variable( + "cache", "cache_end", jnp.full((batch_size,), prefill_len, jnp.int32) + ) + + # Pad attention to set the cache size. 
+ mask = jnp.pad(attn_mask, ((0, 0), (0, 0), (0, cache_size - prefill_len))) + + _, aux = self.model( + tokens=None, + embedded_prefix=x, + positions=positions, + mask=mask, + decode=True, + ) + return self.compute_logits(aux["pre_logits"][:, -1:]) + + def extend_cache(self, x): + """Extends decoding cache with `x` [B, 1, E] and returns logits.""" + assert x.shape[1] == 1, "Only supports extend the cache by one token." + if self.model.scan: + cache_size = self.variables["cache"]["layers"]["attn"]["k_cache"].shape[2] + else: + raise NotImplementedError("Not implemented yet.") + + # Lookup current token position and increment by one for next call. + positions = self.get_variable("cache", "seq_len") + self.put_variable("cache", "seq_len", positions + 1) + + # Update which cache positions are in use and construct attention mask. + # Tokens can attend to all cache positions which are in use including self. + cache_begin = self.get_variable("cache", "cache_begin") + cache_end = self.get_variable("cache", "cache_end") + 1 + self.put_variable("cache", "cache_end", cache_end) + mask = jnp.logical_and( + jnp.arange(cache_size)[None, None, :] >= cache_begin[:, None, None], + jnp.arange(cache_size)[None, None, :] < cache_end[:, None, None]) + + logits, _ = self.model( + tokens=None, embedded_prefix=x, + positions=positions[:, None], mask=mask, decode=True) + return logits + + @property + def embdim(self): + return _get_config(self).width + + +load = gemma.load diff --git a/big_vision/models/proj/paligemma/paligemma.py b/big_vision/models/proj/paligemma/paligemma.py new file mode 100644 index 0000000000000000000000000000000000000000..95aa62b6a82ba3015b25f1525e78b10f1957b368 --- /dev/null +++ b/big_vision/models/proj/paligemma/paligemma.py @@ -0,0 +1,289 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Image encoder + AR-decoder LLM.""" + +import importlib +from typing import Any, Optional + +import flax.linen as nn +import jax +import jax.numpy as jnp + +ConfigDict = Any + + +def make_attn_mask(input_mask, mask_ar): + """Returns attention mask bool[B, N, N] to use in transformer. + + Tokens can attend to valid inputs tokens which have a cumulative mask_ar + smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to + setup several types of attention, for example: + + [[1 1 1 1 1 1]]: pure causal attention. + + [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between + themselves and the last 3 tokens have a causal attention. The first + entry could also be a 1 without changing behaviour. + + [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a + block can attend all previous blocks and all tokens on the same block. + + Args: + input_mask: bool[B, N] true if its part of the input, false if padding. + mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. 
+ """ + cumsum = jnp.cumsum(mask_ar, axis=1) + attn_mask = (cumsum[:, None, :] <= cumsum[:, :, None]) + valid_mask = (input_mask[:, None, :] * input_mask[:, :, None]) + return jnp.logical_and(attn_mask, valid_mask) + + +class Model(nn.Module): + """Two towers transformer.""" + img_model: str = "vit" + img: Optional[ConfigDict] = None + llm_model: str = "proj.paligemma.gemma_bv" + llm: Optional[ConfigDict] = None + + def setup(self): + self._llm = importlib.import_module( + f"big_vision.models.{self.llm_model}" + ).Model(**(self.llm or {}), name="llm") + + img_config = {"num_classes": self._llm.embdim, **(self.img or {})} + self._img_model = importlib.import_module( + f"big_vision.models.{self.img_model}" + ).Model(**img_config, name="img") + + def embed_image(self, image, train=False): + out = {} + + # if we have video, fold frame dimension into the batch dimension + image_shape = image.shape + if len(image_shape) == 5: # video frames + image = jnp.reshape(image, (-1, *image.shape[-3:])) + + # Do we want to normalize? are they huge? + zimg, out_img = self._img_model(image, train=train) + + if len(image_shape) == 5: # concatenate tokens from all video frames + zimg = jnp.reshape(zimg, (image_shape[0], -1, zimg.shape[-1])) + + out["img/zimg"] = zimg + for k, v in out_img.items(): + out[f"img/{k}"] = v + return zimg, out + + def embed_text(self, tokens, train=False): + out = {} + ztxt = out["llm/ztxt"] = self._llm.embed_tokens(tokens, train=train) + return ztxt, out + + def embed_image_and_text(self, image, text, *, + input_mask=None, mask_ar=None, train=False): + """Concats image/text into a sequence of embeded tokens to pass to `llm`. + + Args: + image: float[B, H, W, 3] image to be embedded by the `img` model and used + as prefix to the sequence passed to the `llm` model. + text: int32[B, T] token sequence to embedded by the `llm`. + input_mask: bool[B, T] true if the text token is a valid token and false + if its a token to pad the sequence. 
Defaults to all being input tokens. + mask_ar: int32[B, T] mask that's 1 where `text` should be attended to + causally, and 0 where it can be attended to with full self-attention. + Defaults to all text tokens being auto-regressive. + train: bool whether we're in train or test mode (dropout etc). + + Returns: + Tuple (x: float[B, N, E], input_mask: bool[B, N], mask_ar: int[B, N]) and + auxiliary outputs. + """ + zimg, out_img = self.embed_image(image, train=train) + ztxt, out_txt = self.embed_text(text, train=train) + + if input_mask is None: + input_mask = jnp.full(text.shape, True) + if mask_ar is None: + mask_ar = jnp.full(text.shape, 1) + + # Concatenate embeded image and text into a single token sequence. + x = jnp.concatenate([zimg, ztxt], axis=1) + _, img_len, _ = zimg.shape + pad_width = ((0, 0), (img_len, 0)) + mask_ar = jnp.pad(mask_ar, pad_width, constant_values=0) + input_mask = jnp.pad(input_mask, pad_width, constant_values=True) + + return (x, input_mask, mask_ar), {**out_img, **out_txt} + + def __call__(self, image, text, mask_ar, train=False): + """Concats image/text and returns text logits. + + Args: + image: float32[B, H, W, 3] image that can be passed to the `img` model. + text: int32[B, T] token sequence that can be embedded by the `txt` model. + mask_ar: int32[B, T] mask that's 1 where `text` should be attended to + causally, and 0 where it can be attended to with full self-attention. + train: bool whether we're in train or test mode (dropout etc). + + Returns: + float32[B, T, V] logits for the `text` input, and an out-dict of named + intermediates. + """ + # Embed the image and text. + (x, input_mask, mask_ar), out = self.embed_image_and_text( + image, text, mask_ar=mask_ar, train=train) + + # Call transformer on the embedded token sequence. 
+ attn_mask = out["attn_mask"] = make_attn_mask(input_mask, mask_ar) + _, out_llm = self._llm(x, mask=attn_mask, train=train) + for k, v in out_llm.items(): + out[f"llm/{k}"] = v + + # Extract the logits for the text tokens. + zimg = out["img/zimg"] + text_pre_logits = out["llm/pre_logits"][:, zimg.shape[1]:, :] + text_logits = self._llm.compute_logits(text_pre_logits, train=train) + out["text_logits"] = text_logits + out["text_tokens"] = jnp.argmax(text_logits, axis=-1) + return text_logits, out + + def prefill_cache(self, x, input_mask, mask_ar, *, cache_size): + """Initializes decoding cache with `x` [B, N, E] as prompt.""" + if hasattr(self._llm, "prefill_cache"): + attn_mask = make_attn_mask(input_mask, mask_ar) + return self._llm.prefill_cache( + x, input_mask, attn_mask, cache_size=cache_size) + else: + return self._fallback_prefill_cache(x, input_mask, mask_ar, cache_size) + + def extend_cache(self, x): + """Advances decoding cache with `x` [B, 1, E].""" + if hasattr(self._llm, "prefill_cache"): + return self._llm.extend_cache(x) + else: + return self._fallback_extend_cache(x) + + def _fallback_prefill_cache(self, x, input_mask, mask_ar, cache_size): + # FALLBACK: only cache inputs and call the model with the full sequence + # for each and every decode step. Very slowwww... + # + # This very slow codepath does not requires the model to implement caching. + # It is intended to allow to plug any model under development quite early + # into some decoding tasks and not as a long term decoding solution. + attn_mask = make_attn_mask(input_mask, mask_ar) + logits, _ = self._llm(x, mask=attn_mask) + + # Save the prefill inputs for subsequent extend_calls in the cache. + # Unused entries are zero-initialized. 
+ pad_size = cache_size - x.shape[1] + x = jnp.pad(jnp.where(input_mask[..., None], x, 0), + [(0, 0), (0, pad_size), (0, 0)]) + mask_ar = jnp.pad(jnp.where(input_mask, mask_ar, 0), + [(0, 0), (0, pad_size)]) + input_mask = jnp.pad(input_mask, [(0, 0), (0, pad_size)]) + self.put_variable("cache", "x_cache", x) + self.put_variable("cache", "input_mask_cache", input_mask) + self.put_variable("cache", "mask_ar_cache", mask_ar) + + # Extract logits of the last token (using einsum). + last_pos = jnp.sum(input_mask, axis=1)[:, None] - 1 + last_onehot = jax.nn.one_hot(last_pos, logits.shape[1], dtype=jnp.int32) + last_logits = jnp.einsum("bnh,ben->beh", logits, last_onehot) + + return last_logits + + def _fallback_extend_cache(self, x): + # FALLBACK: append inputs to cache and call the model with the full sequence + # for each and every decode step. Very slowwww... + assert x.shape[1] == 1 + mask_ar = jnp.full(x.shape[:-1], 1) + input_mask = jnp.full(x.shape[:-1], True) + + # Append inputs to cache by add/or on the next available cache position, + # which is zero-initialized. + c_x = self.get_variable("cache", "x_cache") + c_input_mask = self.get_variable("cache", "input_mask_cache") + c_mask_ar = self.get_variable("cache", "mask_ar_cache") + next_pos = jnp.sum(c_input_mask, axis=1)[:, None] + move_onehot = jax.nn.one_hot(next_pos, c_x.shape[1], dtype=jnp.int32) + x = jnp.add(c_x, jnp.einsum("beh,ben->bnh", x, move_onehot)) + mask_ar = jnp.add(c_mask_ar, jnp.einsum("be,ben->bn", mask_ar, move_onehot)) + input_mask = jnp.logical_or( + c_input_mask, jnp.einsum("be,ben->bn", input_mask, move_onehot)) + self.put_variable("cache", "x_cache", x) + self.put_variable("cache", "input_mask_cache", input_mask) + self.put_variable("cache", "mask_ar_cache", mask_ar) + + # Call model on the full cached sequence. + attn_mask = make_attn_mask(input_mask, mask_ar) + logits, _ = self._llm(x, mask=attn_mask) + + # Extract logits of the last token. 
+ last_pos = jnp.sum(input_mask, axis=1)[:, None] - 1 + last_onehot = jax.nn.one_hot(last_pos, logits.shape[1], dtype=jnp.int32) + last_logits = jnp.einsum("bnh,ben->beh", logits, last_onehot) + + return last_logits + + +# pylint: disable=line-too-long +import os +GEMMA_DIR = os.environ.get("BV_GEMMA_DIR", "PLEASE_SET_BV_GEMMA_DIR") +VANITY_NAMES = { + # Because checkpoints are behind an ACK-wall, the user has to download them + # to some folder (or bucket), take that from an environment variable. + "pt_224": os.path.join(GEMMA_DIR, "pt_224.npz"), + "pt_224.bf16": os.path.join(GEMMA_DIR, "pt_224.bf16.npz"), + "pt_224.f16": os.path.join(GEMMA_DIR, "pt_224.f16.npz"), + "pt_448": os.path.join(GEMMA_DIR, "pt_448.npz"), + "pt_448.bf16": os.path.join(GEMMA_DIR, "pt_448.bf16.npz"), + "pt_448.f16": os.path.join(GEMMA_DIR, "pt_448.f16.npz"), + "pt_896": os.path.join(GEMMA_DIR, "pt_896.npz"), + "pt_896.bf16": os.path.join(GEMMA_DIR, "pt_896.bf16.npz"), + "pt_896.f16": os.path.join(GEMMA_DIR, "pt_896.f16.npz"), +} +# pylint: enable=line-too-long + + +def load(init_params, init_files, model_cfg, img_load_kw={}, llm_load_kw={}): # pylint: disable=dangerous-default-value + """Loads both pieces, `init_files` is now a dict with `img` and `llm` keys.""" + + # A slight shortcut when loading an already combined model: + if isinstance(init_files, str): + init_files = VANITY_NAMES.get(init_files, init_files) + init_files = {"img": f"{init_files}:img", "llm": f"{init_files}:llm"} + + if not init_params: # Convenience to skip checks in colab. + init_params = {"img": None, "llm": None} + restored_params = {**init_params} + + init_files = {**init_files} # Needed because ConfigDict but we'll pop stuff. 
+ + if img_init := init_files.pop("img", None): + restored_params["img"] = importlib.import_module( + f"big_vision.models.{model_cfg.get('img_model', 'vit')}" + ).load(init_params["img"], img_init, model_cfg.img, **img_load_kw) + + if llm_init := init_files.pop("llm", None): + restored_params["llm"] = importlib.import_module( + f"big_vision.models.{model_cfg.get('llm_model', 'proj.paligemma.gemma_bv')}" + ).load(init_params["llm"], llm_init, model_cfg.llm, **llm_load_kw) + + assert not init_files, ( + f"There's something unused left in `config.model_init`. You probably got " + f"a typo. Here it is: {init_files}") + + return restored_params diff --git a/big_vision/models/proj/uvim/decode.py b/big_vision/models/proj/uvim/decode.py new file mode 100644 index 0000000000000000000000000000000000000000..809d4f342b2ae35a2b83a5edcba5f52fcd165248 --- /dev/null +++ b/big_vision/models/proj/uvim/decode.py @@ -0,0 +1,384 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference.""" +import functools + +from typing import Any, Callable, Optional, Tuple + +import flax +from flax import linen as nn +import jax +from jax import lax +from jax import numpy as jnp + +import numpy as np + + +EOS_ID = 1 +NEG_INF = np.array(-1.0e7) # Effective negative infinity. 
+ + +GenerateFn = Callable[..., + Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray]]] + + +def temperature_sampling(*args, temperature=1.0, top_k=0, top_p=0.0, **kwargs): + """Convenience wrapper for temperature sampling.""" + return generate(*args, generate_fn=_temperature_sampling, + temperature=temperature, + top_k=top_k, + top_p=top_p, + **kwargs) + + +def topk_sampling(*args, temperature=1.0, top_k=20, **kwargs): + """Convenience wrapper for top-k sampling.""" + return generate(*args, generate_fn=_temperature_sampling, + temperature=temperature, + top_k=top_k, + top_p=0.0, + **kwargs) + + +def nucleus_sampling(*args, temperature=1.0, top_p=0.2, **kwargs): + """Convenience wrapper for nucleus sampling.""" + return generate(*args, generate_fn=_temperature_sampling, + temperature=temperature, + top_k=0, + top_p=top_p, + **kwargs) + + +def argmax_sampling(*args, **kwargs): + """Convenience wrapper for argmax sampling.""" + return generate(*args, generate_fn=_temperature_sampling, + temperature=1e-7, + top_k=0, + top_p=0.0, + **kwargs) + + +def generate(params, inputs, prompts, seed, *, + model: nn.Module, + generate_fn: GenerateFn, + num_samples: int = 1, + prefill: bool = False, + eos_token: int = EOS_ID, + **generate_fn_kwargs): + """Generate sequence with fast decoding beam search on a batch. + + Model must support: + encode(inputs) -> encoded, or encode(*inputs) -> encoded. + decode(encoded, prompts, decode=True/False, max_decode_length) -> logits + + Args: + params: model parameters. + inputs: either a single `jnp.ndarray` of e.g. images, or + a tuple of inputs which are passed via `model.encode(*inputs)`. + prompts: [batch_size, max_decode_len] forced tokens for generation. + prompts need to finish with 0 token, they should not contain the end + markers. If no prompting is required, pass an all zeros tensor. + seed: PRNG key for random sampling. + model: object with methods encode and decode. 
def generate(params, inputs, prompts, seed, *,
             model: nn.Module,
             generate_fn: GenerateFn,
             num_samples: int = 1,
             prefill: bool = False,
             eos_token: int = EOS_ID,
             **generate_fn_kwargs):
  """Encodes the inputs once and decodes sequences with `generate_fn`.

  Model must support:
    encode(inputs) -> encoded, or encode(*inputs) -> encoded.
    decode(encoded, prompts, decode=True/False, max_decode_length) -> logits

  If `model.encode` returns a tuple, it is unpacked as
  `(encoded, enc_pos_emb)` and `enc_pos_emb` is forwarded to every
  `model.decode` call.

  Args:
    params: model variables; `params["params"]` is combined with the decoding
      cache when applying the model.
    inputs: either a single `jnp.ndarray` of e.g. images, or
      a tuple of inputs which are passed via `model.encode(*inputs)`.
    prompts: [batch_size, max_decode_len] forced tokens for generation.
      prompts need to finish with 0 token, they should not contain the end
      markers. If no prompting is required, pass an all zeros tensor.
    seed: PRNG key for random sampling.
    model: object with methods encode and decode.
    generate_fn: search or sampling function to generate sequences.
    num_samples: number of samples to generate per item.
    prefill: whether to prefill cache (calls `model.decode(..., prefill=True)`
      on the prompts before sampling).
    eos_token: id of end-of-sentence token for target vocabulary.
    **generate_fn_kwargs: generate fn specific kwargs.

  Returns:
    The triple returned by `generate_fn`. For `_temperature_sampling`:
      sequences [batch_size, num_samples, max_decode_len],
      scores [batch_size, num_samples],
      per-token log probs [batch_size, num_samples, max_decode_len].
  """
  _, max_decode_len = prompts.shape
  decode_kwargs = {"max_decode_length": max_decode_len}

  def encode(model, inputs):
    if not isinstance(inputs, tuple):
      inputs = (inputs,)
    return model.encode(*inputs)

  encoded_inputs = nn.apply(encode, model)(params, inputs)
  if isinstance(encoded_inputs, tuple):
    encoded_inputs, enc_pos_emb = encoded_inputs
    decode_kwargs["enc_pos_emb"] = enc_pos_emb

  # Run the decoder once on zero-valued inputs of the right shapes, only to
  # obtain the "cache" collection it creates.
  def init_cache(model):
    encoded = jnp.zeros_like(encoded_inputs)
    targets = jnp.zeros_like(prompts)
    return model.decode(encoded, targets, decode=True, **decode_kwargs)

  cache = nn.apply(init_cache, model, mutable=True)(params)[1]["cache"]

  def prefill_cache(model, encoded, targets):
    return model.decode(encoded, targets, prefill=True, **decode_kwargs)

  if prefill:
    cache = nn.apply(prefill_cache, model, mutable=True)(
        {"params": params["params"], "cache": cache},
        encoded_inputs, prompts)[1]["cache"]

  def tokens_to_logits(tokens, cache):
    def decode_step(model, tokens):
      # Repeat every encoded item `num_samples` times along the batch axis.
      encoded = expand_samples_dim_and_flatten(
          encoded_inputs, num_samples)
      return model.decode(encoded, tokens, decode=True, **decode_kwargs)

    logits, aux = nn.apply(decode_step, model, mutable=True)(
        {"params": params["params"], "cache": cache}, tokens)
    # Drop the length-1 sequence axis of the single-token step.
    return logits.squeeze(axis=1), aux["cache"]

  beam_seqs, scores, logprobs = generate_fn(
      prompts,
      cache,
      tokens_to_logits,
      num_samples=num_samples,
      eos_token=eos_token,
      max_decode_len=max_decode_len,
      seed=seed,
      **generate_fn_kwargs)
  return beam_seqs, scores, logprobs
def expand_samples_dim(x, num_samples):
  """Inserts a samples axis at position 1 and tiles `x` along it.

  Args:
    x: array of shape [B, ...]; 0-d arrays are returned unchanged.
    num_samples: number of copies to create along the new axis.

  Returns:
    Array of shape [B, num_samples, ...], or `x` itself if it is a scalar.
  """
  if x.ndim == 0:  # ignore scalars (e.g. cache index)
    return x
  x = jnp.expand_dims(x, axis=1)
  tile_dims = [1] * x.ndim
  tile_dims[1] = num_samples
  return jnp.tile(x, tile_dims)


def flatten_samples_dim(x):
  """Merges the first two axes ([B, N, ...] -> [B*N, ...]); scalars pass."""
  if x.ndim == 0:  # ignore scalars (e.g. cache index)
    return x
  return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])


def unflatten_samples_dim(x, batch_size, num_samples):
  """Splits the first axis ([B*N, ...] -> [B, N, ...]); scalars pass."""
  if x.ndim == 0:  # ignore scalars (e.g. cache index)
    return x
  assert batch_size * num_samples == x.shape[0]
  return x.reshape((batch_size, num_samples) + x.shape[1:])


def expand_samples_dim_and_flatten(x, num_samples):
  """Repeats each batch item `num_samples` times along the batch axis."""
  return flatten_samples_dim(expand_samples_dim(x, num_samples))


def cache_map(fn, cache):
  """Maps `fn` over every cache leaf except those named "cached_bias".

  Args:
    fn: function applied to each selected cache array.
    cache: nested dict (or `FrozenDict`) of cache arrays, possibly spanning
      several layers.

  Returns:
    A cache with the same structure; it is frozen again if the input was a
    `FrozenDict`.
  """
  frozen = isinstance(cache, flax.core.FrozenDict)
  if frozen:
    cache = flax.core.unfreeze(cache)
  flat_cache = flax.traverse_util.flatten_dict(cache)
  # Exclude cached relative position bias from beam expansion, etc.
  keyvals = {k: v for k, v in flat_cache.items() if k[-1] != "cached_bias"}
  # `jax.tree_map` is a deprecated alias; use the `jax.tree_util` entry point.
  keyvals = jax.tree_util.tree_map(fn, keyvals)
  flat_cache.update(keyvals)
  new_cache = flax.traverse_util.unflatten_dict(flat_cache)
  if frozen:
    new_cache = flax.core.freeze(new_cache)
  return new_cache
@flax.struct.dataclass
class LoopState:
  """Internal state of the temperature sampling loop."""
  # Position in the sequence that we are currently looking at.
  cur_index: int
  # Cache for fast auto-regressive decoding.
  cache: Any
  # Flags indicating whether the sequence reached eos [B*N].
  flags_finished: jnp.ndarray
  # Sequences being generated [B*N, L+1]. Note: sequences start with 0 token.
  sequences: jnp.ndarray
  scores: jnp.ndarray  # Total sequence scores per batch element [B*N].
  logprobs: jnp.ndarray  # Logprobs of selected tokens [B*N, L].
  rng: jnp.ndarray  # PRNGKey of the loop state.


def _init_state(prompts, cache, init_rng_key, num_samples):
  """Builds the initial `LoopState` for the sampling loop.

  Args:
    prompts: [B, L+1] prompt tokens (the second axis is read as
      "max decode length plus one").
    cache: decoding cache; every non-scalar leaf except "cached_bias" entries
      is repeated `num_samples` times along the batch axis.
    init_rng_key: PRNG key stored in the state.
    num_samples: number of samples generated per batch item.

  Returns:
    A `LoopState` at index 0 with zeroed scores, logprobs and finished flags.
  """
  batch_size, max_decode_len_plus_one = prompts.shape
  # Add extra samples dim to attention cache pytree elements.
  cache = cache_map(
      lambda x: expand_samples_dim_and_flatten(x, num_samples), cache)
  return LoopState(
      cur_index=0,
      cache=cache,
      flags_finished=jnp.zeros((batch_size*num_samples), dtype=jnp.bool_),
      sequences=expand_samples_dim_and_flatten(prompts, num_samples),
      scores=jnp.zeros((batch_size*num_samples)),
      logprobs=jnp.zeros((batch_size*num_samples, max_decode_len_plus_one-1)),
      rng=init_rng_key)


def _should_temperature_sampling_continue(state, max_decode_len):
  """Check if we should continue or not.

  Returns True while `state.cur_index < max_decode_len - 1` and at least one
  sequence is not flagged as finished.
  """

  max_length_not_reached = state.cur_index < max_decode_len - 1
  all_seqs_finished = jnp.all(state.flags_finished)
  return max_length_not_reached & (~all_seqs_finished)
def _temperature_sampling_iteration(state, tokens_to_logits, temperature, eos,
                                    top_k, top_p, mask_token_ids=()):
  """Temperature sampling step function.

  Args:
    state: current `LoopState`.
    tokens_to_logits: callable `(tokens [B*N, 1], cache) -> (logits [B*N, V],
      new_cache)`.
    temperature: logits are divided by this value before sampling.
    eos: end-of-sequence token id; sampling it marks a sequence as finished.
    top_k: if non-zero, sample only among the `top_k` largest logits.
    top_p: if non-zero, mask logits below the nucleus cutoff before sampling.
    mask_token_ids: token ids whose probability is set to zero before
      sampling.

  Returns:
    The `LoopState` for the next position.
  """

  rng_sampling, rng = jax.random.split(state.rng)

  # 1. Use the model to generate a distribution over the vocabulary (for the
  # next token) and sample from it, optionally applying the temperature.
  # --> [B,].
  cur_tokens = state.sequences[:, state.cur_index]
  logits, new_cache = tokens_to_logits(cur_tokens[:, None], state.cache)
  assert logits.ndim == 2, ("tokens_to_logits expected to return a "
                            f"2-dimensional array [B, V], got {logits.ndim} "
                            "dimensions.")
  # Reported log-probs come from the unmodified logits (before any masking or
  # top-k/top-p filtering below).
  logprobs = jax.nn.log_softmax(logits)

  # Do not sample special tokens with ids in mask_token_ids.
  if mask_token_ids:
    probs = jax.nn.softmax(logits)
    for i in mask_token_ids:
      probs = probs.at[:, i].set(0.)
    probs = probs / jnp.sum(probs, -1, keepdims=True)
    logits = jnp.log(probs)

  if top_p:  # Nucleus sampling.
    logits_sorted = jnp.sort(logits, axis=-1)[:, ::-1]
    sorted_cum_probs = jnp.cumsum(
        jax.nn.softmax(logits_sorted, axis=-1), axis=-1)
    cutoff_index = jnp.sum(sorted_cum_probs < top_p, axis=-1, keepdims=True)
    # With top_p >= 1 every cumulative prob can be below the threshold, which
    # would put the index one past the end; clamp so the cutoff becomes the
    # smallest logit and nothing is masked.
    cutoff_index = jnp.minimum(cutoff_index, logits.shape[-1] - 1)
    cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1)
    logits = jnp.where(logits < cutoff_logit,
                       jnp.full_like(logits, NEG_INF), logits)
  if top_k:
    topk_logits, topk_indices = jax.lax.top_k(logits, top_k)
    topk_token = jax.random.categorical(rng_sampling, topk_logits / temperature)
    # Map positions within the top-k slice back to vocabulary ids.
    sampled_tokens = jnp.squeeze(
        jnp.take_along_axis(topk_indices, jnp.expand_dims(topk_token, -1),
                            axis=-1), axis=-1)
  else:
    sampled_tokens = jax.random.categorical(rng_sampling, logits / temperature)

  sampled_logprobs = jnp.squeeze(jnp.take_along_axis(
      logprobs, jnp.expand_dims(sampled_tokens, axis=1), axis=-1), axis=-1)

  # 2. Use the sampled tokens to update the sequences that did not finish yet,
  # but only if they are out of prompt.
  next_tokens = state.sequences[:, state.cur_index + 1]
  next_logprobs = jnp.squeeze(jnp.take_along_axis(
      logprobs, jnp.expand_dims(next_tokens, axis=1), axis=-1), axis=-1)
  # A zero at the next position means the prompt has ended for that sequence.
  out_of_prompt = next_tokens == 0
  update_pos = out_of_prompt * (~state.flags_finished)
  next_tokens = sampled_tokens * update_pos + next_tokens * (~update_pos)
  sampled_logprobs = update_pos*sampled_logprobs + ~update_pos*next_logprobs
  sequences = state.sequences.at[:, state.cur_index + 1].set(next_tokens)
  scores = state.scores + sampled_logprobs
  seqs_logprobs = state.logprobs.at[:, state.cur_index].set(sampled_logprobs)

  # 3. Update the finished flags. Only out of prompts seqs can finish.
  flags_finished = out_of_prompt & (state.flags_finished |
                                    (sampled_tokens == eos))
  return LoopState(
      cur_index=state.cur_index+1,
      cache=new_cache,
      flags_finished=flags_finished,
      sequences=sequences,
      scores=scores,
      logprobs=seqs_logprobs,
      rng=rng)
def _temperature_sampling(prompts, cache, tokens_to_logits, num_samples=1,
                          eos_token=EOS_ID, max_decode_len=None,
                          seed=0, temperature=1., top_k=0, top_p=0.0,
                          mask_token_ids=()):
  """Temperature sampling.

  Purely stochastic sampling-based greedy procedure to generate sequences. Every
  next token in the sequence is sampled from the discrete vocab distribution
  produced by the auto-regressive sequence model. Optionally we can adjust the
  distribution by changing the temperature before sampling from it. Generated
  sequences are no longer than max_decode_len.

  Args:
    prompts: prompts [B, L]. Required: the array is padded and its shape is
      read unconditionally. Prompt sequences should finish with trailing zeros
      and should not contain eos tokens; pass all zeros for free-form
      generation.
    cache: cache for fast decoding (generation).
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and cache and returning next-token logits and updated cache.
    num_samples: int: number of samples to generate per batch item. Note, no
      deduplication is performed, and in dependence of parameter settings, same
      sequences could be generated and returned.
    eos_token: end-of-sentence token.
    max_decode_len: maximal length of generated sequences (L). Defaults to
      `prompts.shape[1]` when None.
    seed: PRNGKey for random sampling; a Python int is converted with
      `jax.random.PRNGKey`.
    temperature: positive real-valued sampling temperature. By default we sample
      from the original distribution. As the temperature approaches 0., the
      entire distribution concentrates on the most probable outcome(s).
    top_k: limit sampling to only top-k logits. Zero means no limit.
    top_p: limit sampling to smallest number of top logits with max cumulative
      prob <= top_p. Zero means no limit. Cannot use both top_p and top_k.
    mask_token_ids: if set then tokens with given ids are not sampled.

  Returns:
    sequences: generated sequences [B, num_samples, L].
    scores: per-sequence sum of the log probs accumulated by the sampling loop
      [B, num_samples].
    logprobs: Log probabilities for the generated tokens [B, num_samples, L].

  Raises:
    ValueError: if both `top_k` and `top_p` are positive.
  """
  if top_k > 0 and top_p > 0.0:
    raise ValueError(f"Cannot use both top_k {top_k} and top_p {top_p}.")
  if max_decode_len is None:
    max_decode_len = prompts.shape[1]
  # We will start generating sequences from 0 token.
  prompts = jnp.pad(prompts, ((0, 0), (1, 0)))
  eos = jnp.array(eos_token)
  if isinstance(seed, int):
    seed = jax.random.PRNGKey(seed)

  # Initialize the state.
  loop_init_state = _init_state(prompts, cache, seed, num_samples)
  should_temperature_sampling_continue_fn = functools.partial(
      _should_temperature_sampling_continue,
      max_decode_len=max_decode_len+1)  # Account for prompt padding with 0's.
  temperature_sampling_iteration_fn = functools.partial(
      _temperature_sampling_iteration,
      tokens_to_logits=tokens_to_logits,
      temperature=temperature, top_k=top_k, top_p=top_p,
      eos=eos, mask_token_ids=mask_token_ids)

  # Run the temperature sampling and generate the sequences.
  final_state = lax.while_loop(
      should_temperature_sampling_continue_fn,
      temperature_sampling_iteration_fn,
      loop_init_state)

  # Return the generated sequences, discarding the 0 token in the beginning.
  return (
      final_state.sequences[:, 1:].reshape((-1, num_samples, max_decode_len)),
      final_state.scores.reshape((-1, num_samples)),
      final_state.logprobs.reshape((-1, num_samples, max_decode_len)))
+ final_state = lax.while_loop( + should_temperature_sampling_continue_fn, + temperature_sampling_iteration_fn, + loop_init_state) + + # Return the generated sequences, discarding the 0 token in the beginning. + return ( + final_state.sequences[:, 1:].reshape((-1, num_samples, max_decode_len)), + final_state.scores.reshape((-1, num_samples)), + final_state.logprobs.reshape((-1, num_samples, max_decode_len))) diff --git a/big_vision/models/proj/uvim/vit.py b/big_vision/models/proj/uvim/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..7f3ac42b357b2f2f7315373df4419d3c69e7073c --- /dev/null +++ b/big_vision/models/proj/uvim/vit.py @@ -0,0 +1,338 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""VQ-VAE autoencoder with ViT backbone.""" + +import functools +from typing import Mapping, Optional, Sequence, Union + +from big_vision import utils +from big_vision.models import common +from big_vision.models import vit + +import einops +import flax.linen as nn +import flax.training.checkpoints +import jax +import jax.numpy as jnp +import numpy as np + + +partial = functools.partial + +# Multiplicative perturbation applied to codewords when doing the split. +# Note, the multiplicative pertubation is not perfectly symmetric and rep. +# applications can shrink the embedding. However, in practice it does not matter +# for the value we use. 
PERTURB = 0.001


# `quantize` takes a vector `x` and a dictionary of vectors `e`. It returns a
# "quantized" version of `x` (the entry of `e` closest to `x`) together with
# the index of that entry in `e`.
# Two extra features:
#   1. The double `vmap` vectorizes the function over two leading dimensions
#      of `x` (batch and space). The Euclidean distance is computed in a
#      decomposed way because that is more efficient under vmap.
#   2. Quantization is a discrete operation and has no gradient w.r.t. `x`, so
#      a "straight-through" gradient estimator is implemented with
#      `stop_gradient`. It does not affect the forward pass, only the gradient.
@partial(jax.vmap, in_axes=(0, None), out_axes=(0, 0))
@partial(jax.vmap, in_axes=(0, None), out_axes=(0, 0))
def quantize(x, e):
  x_sq_norm = jnp.sum(x * x)[None]
  cross_term = x.dot(e.T)
  e_sq_norms = jnp.sum(e * e, axis=1)
  dist = x_sq_norm - 2 * cross_term + e_sq_norms
  idx = jnp.argmin(dist)
  nearest = e[idx]
  # Forward value is `nearest`; the gradient flows to `x` unchanged.
  x_q = jax.lax.stop_gradient(nearest - x) + x
  return x_q, idx


def split_the_most_frequent_embedding(state):
  """Replaces the least frequent embedding by a split of the most frequent.

  The least-used row is overwritten with a slightly perturbed copy of the
  most-used row, after which the counts and dictionary rows of both entries
  are halved.

  Args:
    state: dict with keys "rng" (jax PRNG key), "dictionary" (embeddings) and
      "counts" (per-embedding counts).

  Returns:
    New dict with the updated "rng", "dictionary" and "counts".
  """
  key, codebook, usage = state["rng"], state["dictionary"], state["counts"]
  key, perturb_key = jax.random.split(key)

  most_used = jnp.argmax(usage)
  least_used = jnp.argmin(usage)

  jitter = jax.random.uniform(perturb_key, (codebook.shape[1],), jnp.float32,
                              1.0-PERTURB, 1.0+PERTURB)
  codebook = codebook.at[least_used].set(codebook[most_used] * jitter)

  # The updates below are applied one after another on purpose; keep the order.
  usage = usage.at[least_used].set(usage[most_used] / 2.0)
  usage = usage.at[most_used].set(usage[most_used] / 2.0)

  codebook = codebook.at[least_used].set(codebook[least_used] / 2.0)
  codebook = codebook.at[most_used].set(codebook[most_used] / 2.0)

  return {"rng": key, "dictionary": codebook, "counts": usage}
class Model(nn.Module):
  """ViT model.

  VQ-VAE style autoencoder: inputs are embedded and encoded, reduced to
  `code_len` tokens, quantized against a learned dictionary, then expanded and
  decoded into one output head per entry of `outputs`.
  """

  inputs: Mapping[str, Sequence[int]]
  outputs: Mapping[str, Sequence[int]]
  input_size: Sequence[int] = (256, 256)
  patch_size: Sequence[int] = (8, 8)
  code_len: int = 256
  width: int = 768
  enc_depth: int = 6
  dec_depth: int = 6
  mlp_dim: Optional[int] = None
  num_heads: int = 12
  posemb: str = "learn"  # Can also be "sincos2d"
  rep_size: Union[int, bool] = False
  dropout: float = 0.0
  reinit: Optional[Sequence[str]] = None
  head_zeroinit: bool = True
  dict_size: int = 512  # Number of words in dict.
  codeword_dim: Optional[int] = None
  dict_momentum: float = 0.995  # Exp. moving average coeff. for dict. learning.
  quantize: bool = True
  # Useful to set to None when running without pmap, e.g. testing.
  statistics_axis_name: Optional[str] = "batch"
  # Threshold for the discounted count after which the codeword will be
  # considered unused. For the `dict_momentum` param of 0.995 the codeword
  # should not be present in ~500 batches in a row.
  min_count: float = 0.1  # ~= 0.995 ** 500
  with_encoder_ctx: bool = False
  with_decoder_ctx: bool = False
  code_dropout: str = "none"
  bottleneck_resize: bool = False
  zero_decoder_seq: bool = False

  def setup(self):
    """Creates embeddings, heads, encoder/decoder, bottleneck and dictionary."""

    self.grid_size = np.array(self.input_size) // np.array(self.patch_size)

    # One input projection per entry of `inputs`; it contracts the trailing
    # `len(shape)` axes of that input into `width` features.
    self.embeddings = {
        k: nn.DenseGeneral(features=(self.width,), axis=range(-len(shape), 0),
                           name=f"embedding_{k}")
        for k, shape in self.inputs.items()
    }

    kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
    self.heads = {
        k: nn.DenseGeneral(features=shape, name=f"head_{k}", **kw)
        for k, shape in self.outputs.items()
    }

    if self.with_encoder_ctx:
      self.stem_conv_ctx_enc = nn.Conv(
          self.width, self.patch_size, strides=self.patch_size,
          padding="VALID", name="ctx_enc_embedding")

    if self.with_decoder_ctx:
      self.stem_conv_ctx_dec = nn.Conv(
          self.width, self.patch_size, strides=self.patch_size,
          padding="VALID", name="ctx_dec_embedding")

    self.pos_embedding_encoder = vit.get_posemb(
        self, self.posemb, self.grid_size, self.width, "pos_embedding_encoder")
    self.encoder = vit.Encoder(
        depth=self.enc_depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        name="encoder")

    # Learned linear map from the patch grid to `code_len` tokens; only needed
    # when the bottleneck is not implemented via image resizing.
    if not self.bottleneck_resize:
      self.bottleneck_downsample = self.param(
          "bottleneck_downsample",
          nn.initializers.xavier_uniform(),
          (np.prod(self.grid_size), self.code_len))

    # The dictionary and its usage counts live in the "state" collection.
    norm_init = nn.initializers.normal(stddev=1.0 / np.sqrt(self.dict_size))
    self.dictionary = self.variable(
        "state", "dictionary",
        lambda shape: norm_init(self.make_rng("state"), shape),
        (self.dict_size, self.codeword_dim or self.width))
    self.counts = self.variable("state", "counts", jnp.ones, (self.dict_size,))

    if not self.bottleneck_resize:
      self.bottleneck_upsample = self.param(
          "bottleneck_upsample",
          nn.initializers.xavier_uniform(),
          (self.code_len, np.prod(self.grid_size)))

    self.pos_embedding_decoder = vit.get_posemb(
        self, self.posemb, self.grid_size, self.width, "pos_embedding_decoder")
    self.decoder = vit.Encoder(
        depth=self.dec_depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        name="decoder")

    self.encoder_head = nn.Dense(self.codeword_dim or self.width)
    self.decoder_stem = nn.Dense(self.width)

  def get_codewords(self):
    """Returns the unit-norm codewords: dictionary / counts, then L2-normalized."""
    e = self.dictionary.value / self.counts.value[:, None]
    e = e / jnp.linalg.norm(e, axis=-1, keepdims=True)
    return e

  def encode(self, x, *, ctx=None, train=False, update_dict=True):
    """Encodes inputs into (optionally quantized) code vectors.

    Args:
      x: mapping from input name to array; each entry is embedded with its own
        projection and the embeddings are summed.
      ctx: context image, used only if `with_encoder_ctx` is set.
      train: enables encoder dropout and the dictionary statistics below.
      update_dict: when training, whether to update dictionary and counts.

    Returns:
      A pair `(x, out)`: the bottleneck vectors (quantized unless
      `self.quantize` is False) and a dict of intermediate outputs, including
      "code" with the selected dictionary indices.
    """
    out = {}

    out["stem"] = {}
    for key, embed in self.embeddings.items():
      out["stem"][key] = embed(x[key])
    x = sum(out["stem"].values())

    if self.with_encoder_ctx:
      ctx_tokens = self.stem_conv_ctx_enc(ctx)
      ctx_tokens = einops.rearrange(ctx_tokens, "b h w c -> b (h w) c")
      x = x + ctx_tokens

    x, _ = self.encoder(x + self.pos_embedding_encoder, deterministic=not train)

    if self.bottleneck_resize:
      x = einops.rearrange(x, "b (h w) c -> b h w c",
                           h=self.grid_size[0], w=self.grid_size[1])
      # The code is laid out as a square grid of side round(sqrt(code_len)).
      l = int(np.round(self.code_len ** 0.5))
      x = jax.image.resize(
          x, (x.shape[0], l, l, x.shape[3]),
          method="linear")
      x = einops.rearrange(x, "b h w c -> b (h w) c")
    else:
      x = jnp.einsum("btc,tn->bnc", x, self.bottleneck_downsample)

    x = self.encoder_head(x)

    x = jax.nn.standardize(x, axis=-1)
    x_pre_q = out["bottleneck"] = x
    e = self.get_codewords()
    x, idx = quantize(x, e)
    out["bottleneck_q"] = x
    out["code"] = idx

    # Implements explicit dictionary learning algo outlined in the VQ-VAE paper.
    # We slightly deviate from the paper's formulation, as we find it confusing,
    # especially in the multi-host scenario. What is implemented below can be
    # seen as computing discounted counts and sums of all embeddings.
    if train:
      # Compute counts and sum(x) of code in the global batch.
      counts = jnp.zeros(self.dict_size, dtype=jnp.int32)
      counts = counts.at[idx].add(1)

      # Below we introduce redundant stop_gradient, because jax' dead code
      # elimination for our program's gradient fails to infer that the code
      # below does not require gradient computation.
      # Relevant github issue: https://github.com/google/jax/issues/9042.
      # TODO: remove stop_gradient when the bug is fixed.
      x_sum = jnp.zeros_like(self.dictionary.value)
      x_sum = x_sum.at[idx].add(jax.lax.stop_gradient(x_pre_q))

      if self.statistics_axis_name:
        counts = jax.lax.psum(counts, axis_name=self.statistics_axis_name)
        x_sum = jax.lax.psum(x_sum, axis_name=self.statistics_axis_name)

      out["codebook_max_ratio"] = jnp.max(counts) / jnp.sum(counts)
      out["codebook_zeros_ratio"] = jnp.sum(counts == 0) / len(counts)

      if update_dict:
        self.counts.value = self.counts.value * self.dict_momentum + counts
        self.dictionary.value = (self.dictionary.value * self.dict_momentum +
                                 x_sum)

        # Keep splitting the most frequent codeword until no discounted count
        # is below `min_count`.
        state = {"dictionary": self.dictionary.value,
                 "counts": self.counts.value,
                 "rng": self.make_rng("vqvae")}
        new_state = jax.lax.while_loop(
            lambda state: jnp.any(state["counts"] < self.min_count),
            split_the_most_frequent_embedding,
            state)
        self.counts.value = new_state["counts"]
        self.dictionary.value = new_state["dictionary"]

    if not self.quantize:
      x = x_pre_q
      out["bottleneck_q"] = x
    return x, out

  def decode(self, x, ctx=None, discrete_input=False, train=False):
    """Decodes code vectors (or code indices) into per-output logits.

    Args:
      x: code vectors, or dictionary indices if `discrete_input` is True.
      ctx: context image, used only if `with_decoder_ctx` is set.
      discrete_input: if True, `x` is looked up in the codeword table first.
      train: enables code dropout when `code_dropout` is not "none".

    Returns:
      A pair `(logits, out)` where `logits` maps each output name to the
      corresponding head output and `out` contains the same under "logits".
    """
    out = {}

    if discrete_input:
      e = self.get_codewords()
      x = e[x]

    if self.zero_decoder_seq:
      x = jnp.zeros_like(x)

    if train and self.code_dropout != "none":
      # Keep-probability decreases linearly along the code; one random
      # threshold is drawn per batch element.
      importance = jnp.linspace(1.0, 0.0, self.code_len + 2)[1:-1]
      thr = jax.random.uniform(self.make_rng("dropout"), x.shape[:1])
      mask = importance[None, :] > thr[:, None]
      if self.code_dropout == "random":
        mask = jax.random.permutation(
            self.make_rng("dropout"), mask, axis=-1, independent=True)
      x = x * mask[:, :, None]

    x = self.decoder_stem(x)

    if self.bottleneck_resize:
      l = int(np.round(self.code_len ** 0.5))
      x = einops.rearrange(x, "b (h w) c -> b h w c", h=l, w=l)
      x = jax.image.resize(
          x, (x.shape[0], self.grid_size[0], self.grid_size[1], x.shape[3]),
          method="linear")
      x = einops.rearrange(x, "b h w c -> b (h w) c")
    else:
      x = jnp.einsum("bnc,nt->btc", x, self.bottleneck_upsample)

    if self.with_decoder_ctx:
      ctx_tokens = self.stem_conv_ctx_dec(ctx)
      ctx_tokens = einops.rearrange(ctx_tokens, "b h w c -> b (h w) c")
      x = x + ctx_tokens

    # NOTE(review): unlike the encoder call in `encode`, `train` is not
    # forwarded as `deterministic` here -- confirm this is intended.
    x, _ = self.decoder(x + self.pos_embedding_decoder)

    out["logits"] = {}
    for key, head in self.heads.items():
      out["logits"][key] = head(x)

    return out["logits"], out

  def __call__(self, x, *, ctx=None, train=False, update_dict=True):
    """Runs `encode` then `decode`; returns the logits and merged aux dicts."""
    x, out_enc = self.encode(x, ctx=ctx, train=train, update_dict=update_dict)
    x, out_dec = self.decode(x, ctx=ctx, train=train)
    return x, {**out_enc, **out_dec}
def load(init_params, init_file, model_params=None, dont_load=()):
  """Restores params and state from `init_file`, merged into `init_params`.

  Args:
    init_params: freshly initialised variables to merge into, or None to
      return the checkpoint contents as they are.
    init_file: checkpoint location handed to `utils.load_checkpoint`.
    model_params: unused.
    dont_load: forwarded to `common.merge_params`.

  Returns:
    A `(params, state)` tuple.
  """
  del model_params  # Unused.
  checkpoint = flax.core.unfreeze(utils.load_checkpoint(None, init_file))
  restored = flax.training.checkpoints.convert_pre_linen(
      {"params": checkpoint["params"], "state": checkpoint["state"]})

  # Old checkpoints use capitalised module names; rename them.
  weights = restored["params"]
  if "Encoder" in weights:
    for legacy, current in (("Encoder", "encoder"), ("Decoder", "decoder")):
      weights[current] = weights.pop(legacy)
    restored["params"] = weights

  if init_params is not None:
    restored = common.merge_params(restored, init_params, dont_load)
  return restored["params"], restored["state"]
class ViTVQVAEModelTest(absltest.TestCase):
  """Smoke test: init and one training-mode apply of the VQ-VAE ViT."""

  def test_model(self):
    model_config = ml_collections.ConfigDict({
        "input_size": (32, 32),
        "code_len": 4,
        "width": 16,
        "mlp_dim": 64,
        "num_heads": 4,
        "enc_depth": 1,
        "dec_depth": 1,
        "with_encoder_ctx": True,
        "with_decoder_ctx": True,
        "statistics_axis_name": None,
        "inputs": {
            "in1": (10, 3),
            "in2": (25,),
        },
        "outputs": {
            "out1": (5,),
            "out2": (20,),
        },
    })
    model = vit.Model(**model_config)

    batch_size = 4
    seq_len = (32 // 8) ** 2  # Patches per image for the default 8x8 patch.
    x = {
        "in1": jnp.zeros((batch_size, seq_len, 10, 3)),
        "in2": jnp.zeros((batch_size, seq_len, 25)),
    }
    ctx_image = jnp.zeros((batch_size,) + model_config.input_size + (3,))

    params = model.init(
        {"params": jax.random.PRNGKey(0), "state": jax.random.PRNGKey(1)},
        x, ctx=ctx_image)
    self.assertEqual(params.keys(), {"params", "state"})

    (logits, _), params = model.apply(
        params, x, ctx=ctx_image, train=True, update_dict=True,
        rngs={"dropout": jax.random.PRNGKey(0), "vqvae": jax.random.PRNGKey(0)},
        mutable=["state"])
    self.assertEqual(logits.keys(), {"out1", "out2"})
    self.assertEqual(logits["out1"].shape, (batch_size, seq_len, 5))
    self.assertEqual(logits["out2"].shape, (batch_size, seq_len, 20))


if __name__ == "__main__":
  absltest.main()
def shift_right(x, axis=1):
  """Shifts `x` one step to the right along `axis`, padding with 0.

  The first position along `axis` becomes 0 and the last element is dropped,
  so the output has the same shape as `x`.

  Args:
    x: input array.
    axis: axis to shift along (may be negative).

  Returns:
    Array with the same shape as `x`.
  """
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[axis] = (1, 0)
  padded = jnp.pad(x, pad_widths, constant_values=0)
  # Trim along the same axis that was padded; a hard-coded `[:, :-1]` would
  # only be correct for axis=1.
  keep = [slice(None)] * padded.ndim
  keep[axis] = slice(0, -1)
  return padded[tuple(keep)]
class EncoderDecoderBlock(nn.Module):
  """Transformer encoder-decoder layer.

  Pre-LayerNorm block with three residual sub-layers: self-attention over the
  targets, cross-attention to the encoded inputs, and an MLP.
  """
  mlp_dim: int
  num_heads: int
  dropout_rate: float = 0.
  decode: bool = False  # Passed to the self-attention layer.

  @nn.compact
  def __call__(self, targets, encoded, decoder_mask=None, deterministic=True):
    """Applies EncoderDecoder1DBlock module.

    Args:
      targets: target text embeddings [B, L, E].
      encoded: encoded image patches from encoder [B, P, E].
      decoder_mask: decoder self-attention mask.
      deterministic: bool, deterministic or not (to apply dropout).

    Returns:
      output after transformer encoder-decoder block [B, L, E].
    """
    # Decoder block.
    x = nn.LayerNorm(name="LayerNorm1")(targets)
    x = nn.SelfAttention(
        num_heads=self.num_heads, use_bias=False, broadcast_dropout=False,
        dropout_rate=self.dropout_rate, decode=self.decode, name="SelfAttn")(
            x, decoder_mask, deterministic=deterministic)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
    x = x + targets

    # Encoder-Decoder block.
    y = nn.LayerNorm(name="LayerNorm2")(x)
    y = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads, use_bias=False, broadcast_dropout=False,
        dropout_rate=self.dropout_rate, name="CrossAttn")(
            y, encoded, deterministic=deterministic)
    y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)
    y = y + x

    # MLP block.
    z = nn.LayerNorm(name="LayerNorm3")(y)
    z = vit.MlpBlock(mlp_dim=self.mlp_dim, dropout=self.dropout_rate,
                     name="MLP")(z, deterministic=deterministic)

    return y + z
class Decoder(nn.Module):
  """Transformer Model Decoder for sequence to sequence translation."""
  emb_dim: int
  mlp_dim: int
  num_heads: int
  num_layers: int
  dropout_rate: float = 0.
  output_vocab_size: int = 32000
  # If set, the embedded target tokens are replaced by zeros, so only the
  # positional embeddings enter the decoder stack.
  zero_decoder_seq: bool = False

  @nn.compact
  def __call__(self,
               encoded,
               targets,
               pos_emb,
               decoder_mask=None,
               decode=False,
               deterministic=True,
               max_decode_length=None):
    """Applies Transformer model on the inputs.

    Args:
      encoded: encoded image patches from encoder [B, P, E].
      targets: target text tokens [B, L].
      pos_emb: positional embeddings.
      decoder_mask: decoder self-attention mask.
      decode: bool, whether to perform fast autoregressive decoding with cache.
      deterministic: bool, deterministic or not (to apply dropout).
      max_decode_length: optional max length for positional embeddings. Not
        read anywhere in this method body.

    Returns:
      output of a transformer decoder [B, L, V].
    """
    y = targets.astype("int32")
    # Targets are shifted right only when not decoding with a cache.
    if not decode:
      y = shift_right(y)
    y = nn.Embed(self.output_vocab_size, self.emb_dim, name="EmbedTargets",
                 embedding_init=nn.initializers.normal(stddev=1.0))(y)
    if self.zero_decoder_seq:
      y = jnp.zeros_like(y)
    y = common.AddPositionEmbs(
        decode=decode, name="PosEmbedTargets")(y, pos_emb)
    y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)

    for lyr in range(self.num_layers):
      y = EncoderDecoderBlock(
          num_heads=self.num_heads, mlp_dim=self.mlp_dim,
          dropout_rate=self.dropout_rate, decode=decode,
          name=f"EncDecBlock{lyr}")(y, encoded, decoder_mask=decoder_mask,
                                    deterministic=deterministic)
    y = nn.LayerNorm(name="LayerNorm")(y)
    # The output projection kernel is zero-initialised.
    logits = nn.Dense(self.output_vocab_size, kernel_init=nn.initializers.zeros,
                      name="LogitsDense")(y)
    return logits
class Model(nn.Module):
  """Transformer Model for sequence to sequence translation."""
  patches: ml_collections.ConfigDict  # `patches.size` is the patch/stride size.
  # Encoder/decoder shared params:
  num_heads: int = 8
  num_layers: int = 6
  mlp_dim: int = 2048
  dropout_rate: float = 0.
  # Decoder params:
  emb_dim: int = 512
  vocab_size: int = 32000
  seq_len: int = 256
  # Encoder params:
  input_size: Sequence[int] = (256, 256)
  posemb_type: str = "sincos2d"  # Can also be "learn"
  zero_decoder_seq: bool = False

  def setup(self):
    """Creates positional embeddings, encoder, decoder and patch embedding."""
    grid_size = np.array(self.input_size) // np.array(self.patches.size)
    self.pos_emb_for_encoder = vit.get_posemb(
        self, self.posemb_type, grid_size, self.emb_dim,
        "pos_embedding_encoder")
    # The decoder positional embedding is requested for a (1, seq_len) grid.
    self.pos_emb_for_decoder = vit.get_posemb(
        self, self.posemb_type, (1, self.seq_len), self.emb_dim,
        "pos_embedding_decoder")

    self.encoder = vit.Encoder(
        depth=self.num_layers,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout_rate)
    self.decoder = Decoder(
        num_layers=self.num_layers,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        emb_dim=self.emb_dim,
        output_vocab_size=self.vocab_size,
        zero_decoder_seq=self.zero_decoder_seq,
    )
    self.conv = nn.Conv(self.emb_dim, self.patches.size, padding="VALID",
                        strides=self.patches.size, name="EmbedPatches")

  def encode(self, image, train=False):
    """Encodes input image or embeddings.

    Args:
      image: input batch passed to the patch-embedding convolution.
      train: whether it is training (enables encoder dropout).

    Returns:
      Encoded patches [B, P, E].
    """
    emb = self.conv(image)
    patch_embeddings = einops.rearrange(emb, "B PH PW E -> B (PH PW) E")
    encoded, _ = self.encoder(
        patch_embeddings + self.pos_emb_for_encoder, deterministic=not train)
    return encoded

  def decode(self, encoded, targets, decode=False, train=False,
             max_decode_length=None):
    """Applies Transformer decoder-branch on encoded-input and target.

    Args:
      encoded: encoded image patches from encoder [B, P, E].
      targets: target text tokens [B, L].
      decode: whether to prepare and use an autoregressive cache.
      train: whether it is training.
      max_decode_length: optional max length for positional embeddings.

    Returns:
      logits array from transformer decoder [B, L, V].
    """
    # No causal mask is built when decoding with a cache.
    decoder_mask = None if decode else nn.make_causal_mask(targets)
    logits = self.decoder(
        encoded,
        targets,
        pos_emb=self.pos_emb_for_decoder,
        decoder_mask=decoder_mask,
        decode=decode,
        deterministic=not train,
        max_decode_length=max_decode_length)
    return logits

  def __call__(self, image, text, *, decode=False, train=False):
    """Applies Transformer model on the inputs.

    Args:
      image: batch of images [B, H, W, 3].
      text: batch of tokenized texts [B, L].
      decode: whether to prepare and use an autoregressive cache.
      train: whether it is training.

    Returns:
      logits array from full transformer [B, L, V].
    """
    encoded = self.encode(image, train=train)
    return self.decode(encoded, text, decode=decode, train=train)
+ """ + decoder_mask = None if decode else nn.make_causal_mask(targets) + logits = self.decoder( + encoded, + targets, + pos_emb=self.pos_emb_for_decoder, + decoder_mask=decoder_mask, + decode=decode, + deterministic=not train, + max_decode_length=max_decode_length) + return logits + + def __call__(self, image, text, *, decode=False, train=False): + """Applies Transformer model on the inputs. + + Args: + image: batch of images [B, H, W, 3]. + text: batch of tokenized texts [B, L]. + decode: whether to prepare and use an autoregressive cache. + train: whether it is training. + + Returns: + logits array from full transformer [B, L, V]. + """ + encoded = self.encode(image, train=train) + return self.decode(encoded, text, decode=decode, train=train) + + +def load(init_params, init_files, model_params=None, + dont_load=("head/kernel", "head/bias", "cls")): + """Loads params from init checkpoint and merges into init_params.""" + del model_params + if isinstance(init_files, str): + # A shortcut for a single file checkpoint of a vtt model. + ckpt_params = utils.load_params(None, init_files) + ckpt_params = flax.training.checkpoints.convert_pre_linen(ckpt_params) + if init_params is not None: + ckpt_params = common.merge_params(ckpt_params, init_params, dont_load) + else: + init_files = {**init_files} # Shallow copy because we'll pop stuff off. 
+ + enc_init = init_files.pop("encoder", None) + if enc_init: + ckpt_params = init_params.copy() + vit_params = { + "pos_embedding": ckpt_params["pos_embedding_encoder"], + "Transformer": ckpt_params["encoder"], + "embedding": ckpt_params["EmbedPatches"], + } + encoder_params = vit.load( + vit_params, enc_init, model_cfg={}, + dont_load=dont_load) + ckpt_params["encoder"] = encoder_params["Transformer"] + ckpt_params["pos_embedding_encoder"] = encoder_params["pos_embedding"] + ckpt_params["EmbedPatches"] = encoder_params["embedding"] + else: + raise ValueError("Only encoder init is supported: {}.".format(init_files)) + + return ckpt_params diff --git a/big_vision/models/proj/uvim/vtt_test.py b/big_vision/models/proj/uvim/vtt_test.py new file mode 100644 index 0000000000000000000000000000000000000000..50b279b565af6cb644bae83253b297cae1569cd7 --- /dev/null +++ b/big_vision/models/proj/uvim/vtt_test.py @@ -0,0 +1,50 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for vision-text-transformer.""" +from absl.testing import absltest + +from big_vision.models.proj.uvim import vtt +import jax +import jax.numpy as jnp +import ml_collections + + +class VTTTest(absltest.TestCase): + + def test_vtt_with_1_step(self): + model_config = ml_collections.ConfigDict(dict( + input_size=(224, 224), + patches={"size": (16, 16)}, + num_heads=2, + num_layers=2, + mlp_dim=128, + emb_dim=64, + vocab_size=500)) + batch_size, max_len = 8, 50 + image = jnp.ones((batch_size, 224, 224, 3)) + text = jnp.ones((batch_size, max_len), dtype=jnp.int32) + + m = vtt.Model(**model_config) + variables = m.init(jax.random.PRNGKey(42), image, text) + self.assertCountEqual(variables.keys(), ["params"]) + + params = variables["params"] + out = m.apply({"params": params}, image, text) + expected_shape = (batch_size, max_len, model_config.vocab_size) + self.assertEqual(out.shape, expected_shape) + + +if __name__ == "__main__": + absltest.main() diff --git a/big_vision/models/vit.py b/big_vision/models/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..376cb6bd8616eb49af40f0815cfbc4985df156b9 --- /dev/null +++ b/big_vision/models/vit.py @@ -0,0 +1,478 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A refactored and simplified ViT. + +However, the names of modules are made to match the old ones for easy loading. 
+""" + +from typing import Optional, Sequence, Union + +from absl import logging +from big_vision import utils +from big_vision.models import common +import flax +import flax.linen as nn +import flax.training.checkpoints +import jax +import jax.numpy as jnp +import numpy as np +import scipy.ndimage + + +def posemb_sincos_2d(h, w, width, temperature=10_000., dtype=jnp.float32): + """Follows the MoCo v3 logic.""" + y, x = jnp.mgrid[:h, :w] + + assert width % 4 == 0, "Width must be mult of 4 for sincos posemb" + omega = jnp.arange(width // 4) / (width // 4 - 1) + omega = 1. / (temperature**omega) + y = jnp.einsum("m,d->md", y.flatten(), omega) + x = jnp.einsum("m,d->md", x.flatten(), omega) + pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) + return jnp.asarray(pe, dtype)[None, :, :] + + +def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32): + if typ == "learn": + return self.param(name, nn.initializers.normal(stddev=1/np.sqrt(width)), + (1, np.prod(seqshape), width), dtype) + elif typ == "sincos2d": + return posemb_sincos_2d(*seqshape, width, dtype=dtype) + else: + raise ValueError(f"Unknown posemb type: {typ}") + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim + dropout: float = 0.0 + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, x, deterministic=True): + """Applies Transformer MlpBlock module.""" + inits = dict( + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.normal(stddev=1e-6), + ) + + n, l, d = x.shape # pylint: disable=unused-variable + x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x) + x = nn.gelu(x) + x = nn.Dropout(rate=self.dropout)(x, deterministic) + x = nn.Dense(d, dtype=self.dtype_mm, **inits)(x) + return x + + +class Encoder1DBlock(nn.Module): + """Single transformer encoder block (MHSA + MLP).""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim + 
num_heads: int = 12 + dropout: float = 0.0 + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, x, deterministic=True): + out = {} + x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb")) + y = nn.LayerNorm()(x) + y = out["sa"] = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=deterministic, + dtype=self.dtype_mm, + )(y, y) + y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb")) + y = nn.Dropout(rate=self.dropout)(y, deterministic) + x = out["+sa"] = x + y + + y = nn.LayerNorm()(x) + y = out["mlp"] = MlpBlock( + mlp_dim=self.mlp_dim, dropout=self.dropout, + dtype_mm=self.dtype_mm, + )(y, deterministic) + y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb")) + y = nn.Dropout(rate=self.dropout)(y, deterministic) + x = out["+mlp"] = x + y + x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb")) + return x, out + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + depth: int + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + dropout: float = 0.0 + scan: bool = False + remat_policy: str = "nothing_saveable" + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, x, deterministic=True): + out = {} + + if self.scan: + block = nn.remat( + Encoder1DBlock, + prevent_cse=False, + static_argnums=(2,), # 0=self, 2=deterministic + policy=getattr(jax.checkpoint_policies, self.remat_policy, None), + ) + x, scan_out = nn.scan( + block, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=nn.broadcast, + length=self.depth)( + name="encoderblock", + dtype_mm=self.dtype_mm, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + dropout=self.dropout)(x, deterministic) + for lyr in range(self.depth): + out[f"block{lyr:02d}"] = jax.tree.map(lambda o, l=lyr: o[l], scan_out) + else: + # Input Encoder + for lyr in 
range(self.depth): + block_cur = Encoder1DBlock( + name=f"encoderblock_{lyr}", + dtype_mm=self.dtype_mm, + mlp_dim=self.mlp_dim, num_heads=self.num_heads, + dropout=self.dropout) + x, out[f"block{lyr:02d}"] = block_cur(x, deterministic) + out["pre_ln"] = x # Alias for last block, but without the number in it. + + return nn.LayerNorm(name="encoder_norm")(x), out + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + + @nn.compact + def __call__(self, x): + # TODO + n, l, d = x.shape # pylint: disable=unused-variable + probe = self.param("probe", nn.initializers.xavier_uniform(), + (1, 1, d), x.dtype) + probe = jnp.tile(probe, [n, 1, 1]) + + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform())(probe, x) + + # TODO: dropout on head? + y = nn.LayerNorm()(x) + x = x + MlpBlock(mlp_dim=self.mlp_dim)(y) + return x[:, 0] + + +class _Model(nn.Module): + """ViT model.""" + + num_classes: Optional[int] = None + patch_size: Sequence[int] = (16, 16) + width: int = 768 + depth: int = 12 + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + posemb: str = "learn" # Can also be "sincos2d" + rep_size: Union[int, bool] = False + dropout: float = 0.0 + pool_type: str = "gap" # Can also be "map" or "tok" + head_zeroinit: bool = True + scan: bool = False + # or "dots_with_no_batch_dims_saveable" for more speed (memory costly) + remat_policy: str = "nothing_saveable" + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, image, *, train=False): + out = {} + + image = jnp.asarray(image, self.dtype_mm) + + # Patch extraction + x = out["stem"] = nn.Conv( + self.width, self.patch_size, strides=self.patch_size, + padding="VALID", name="embedding", dtype=self.dtype_mm)(image) + + n, h, w, c = x.shape + x = jnp.reshape(x, [n, h * w, c]) + + # Add posemb before adding extra token. 
+ x = out["with_posemb"] = x + get_posemb( + self, self.posemb, (h, w), c, "pos_embedding", x.dtype) + + if self.pool_type == "tok": + cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype) + x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1) + + n, l, c = x.shape # pylint: disable=unused-variable + x = nn.Dropout(rate=self.dropout)(x, not train) + + x, out["encoder"] = Encoder( + depth=self.depth, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + dropout=self.dropout, + scan=self.scan, + remat_policy=self.remat_policy, + dtype_mm=self.dtype_mm, + name="Transformer")( + x, deterministic=not train) + encoded = out["encoded"] = x + + if self.pool_type == "map": + x = out["head_input"] = MAPHead( + num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) + elif self.pool_type == "gap": + x = out["head_input"] = jnp.mean(x, axis=1) + elif self.pool_type == "0": + x = out["head_input"] = x[:, 0] + elif self.pool_type == "tok": + x = out["head_input"] = x[:, 0] + encoded = encoded[:, 1:] + elif self.pool_type == "none": + pass + else: + raise ValueError(f"Unknown pool type: '{self.pool_type}'") + + x_2d = jnp.reshape(encoded, [n, h, w, -1]) + + if self.rep_size: + rep_size = self.width if self.rep_size is True else self.rep_size + hid = nn.Dense(rep_size, name="pre_logits") + # NOTE: In the past we did not include tanh in pre_logits. + # For few-shot, it should not matter much, as it whitens anyways. 
+ x_2d = nn.tanh(hid(x_2d)) + x = nn.tanh(hid(x)) + + out["pre_logits_2d"] = x_2d + out["pre_logits"] = x + + if self.num_classes: + kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {} + head = nn.Dense(self.num_classes, name="head", **kw) + x_2d = out["logits_2d"] = head(x_2d) + x = out["logits"] = head(x) + + return x, out + + +def Model(num_classes=None, *, variant=None, **kw): # pylint: disable=invalid-name + """Factory function, because linen really don't like what I'm doing!""" + return _Model(num_classes, **{**decode_variant(variant), **kw}) + + +def decode_variant(variant): + """Converts a string like "B" or "B/32" into a params dict.""" + if variant is None: + return {} + + v, patch = variant, {} + if "/" in variant: + v, patch = variant.split("/") + patch = {"patch_size": (int(patch), int(patch))} + + return { + # pylint:disable=line-too-long + # Reference: Table 2 of https://arxiv.org/abs/2106.04560. + "width": {"mu": 32, "Ti": 192, "S": 384, "M": 512, "B": 768, "L": 1024, "So400m": 1152, "H": 1280, "g": 1408, "g-opt": 1536, "G": 1664, "G-opt": 1536, "e": 1792}[v], + "depth": {"mu": 1, "Ti": 12, "S": 12, "M": 12, "B": 12, "L": 24, "So400m": 27, "H": 32, "g": 40, "g-opt": 40, "G": 48, "G-opt": 48, "e": 56}[v], + "mlp_dim": {"mu": 128, "Ti": 768, "S": 1536, "M": 2048, "B": 3072, "L": 4096, "So400m": 4304, "H": 5120, "g": 6144, "g-opt": 6144, "G": 8192, "G-opt": 8192, "e": 15360}[v], + "num_heads": {"mu": 2, "Ti": 3, "S": 6, "M": 8, "B": 12, "L": 16, "So400m": 16, "H": 16, "g": 16, "g-opt": 16, "G": 16, "G-opt": 16, "e": 16}[v], + # pylint:enable=line-too-long + **patch + } + + +def resample_posemb(old, new): + """This function implements "high-res finetuning" for transformer models.""" + # Rescale the grid of position embeddings. 
Param shape is (1,N,1024) + if old.shape == new.shape: + return old + + logging.info("ViT: resize %s to %s", old.shape, new.shape) + gs_old = int(np.sqrt(old.shape[1])) + gs_new = int(np.sqrt(new.shape[1])) + logging.info("ViT: grid-size from %s to %s", gs_old, gs_new) + grid = old.reshape(gs_old, gs_old, -1) + + zoom = (gs_new/gs_old, gs_new/gs_old, 1) + grid = scipy.ndimage.zoom(grid, zoom, order=1) + grid = grid.reshape(1, gs_new*gs_new, -1) + return grid + + +def fix_old_checkpoints(params): + """Fix small bwd incompat that can't be resolved with names in model def.""" + + params = flax.core.unfreeze( + flax.training.checkpoints.convert_pre_linen(params)) + + # Original ViT paper variant had posemb in a module: + if "posembed_input" in params["Transformer"]: + logging.info("ViT: Loading and fixing VERY old posemb") + posemb = params["Transformer"].pop("posembed_input") + params["pos_embedding"] = posemb["pos_embedding"] + + # Widely used version before 2022 had posemb in Encoder: + if "pos_embedding" in params["Transformer"]: + logging.info("ViT: Loading and fixing old posemb") + params["pos_embedding"] = params["Transformer"].pop("pos_embedding") + + # Old vit.py used to first concat [cls] token, then add posemb. + # This means a B/32@224px would have 7x7+1 posembs. This is useless and clumsy + # so we changed to add posemb then concat [cls]. We can recover the old + # checkpoint by manually summing [cls] token and its posemb entry. 
+ if "pos_embedding" in params: + pe = params["pos_embedding"] + if int(np.sqrt(pe.shape[1])) ** 2 + 1 == int(pe.shape[1]): + logging.info("ViT: Loading and fixing combined cls+posemb") + pe_cls, params["pos_embedding"] = pe[:, :1], pe[:, 1:] + if "cls" in params: + params["cls"] += pe_cls + + # MAP-head variants during ViT-G development had it inlined: + if "probe" in params: + params["MAPHead_0"] = { + k: params.pop(k) for k in + ["probe", "MlpBlock_0", "MultiHeadDotProductAttention_0", "LayerNorm_0"] + } + + return params + + +def pyloop_to_scan(params_pyloop): + """Converts a python for-loop ViT checkpoint to a lax.scan based one.""" + # On a high level, they are the same except that the for loop has separate + # array pytrees for each encoderblock, while the scan one has just one + # encoderblock pytree, with all block's params concatenated. + + params_scan = jax.tree.map(lambda x: x, params_pyloop) # Structural copy + t = params_scan["Transformer"] + + # Find highest index of encoderblocks in the checkpoint (they start at 0): + encoderblocks = {k for k in t if k.startswith("encoderblock_")} + depth = 1 + max({int(k.split("_")[-1]) for k in encoderblocks}) + + def stack(*values): + return np.stack(values) + + # Stack all encoderblocks into a single one: + t["encoderblock"] = jax.tree.map( + stack, *[t[f"encoderblock_{lyr}"] for lyr in range(depth)]) + + for lyr in range(depth): + del t[f"encoderblock_{lyr}"] + + return params_scan + + +def scan_to_pyloop(params_scan): + """Converts a lax.scan ViT checkpoint to a python for-loop based one.""" + # See comment in pyloop_to_scan. + + params_scan = jax.tree.map(lambda x: x, params_scan) # Structural copy + t = params_scan["Transformer"] + + # Find out how many encoderblocks there are + depth = len(t["encoderblock"]["LayerNorm_0"]["bias"]) + + # Create that many encoderblocks, each with their slice of their sub-pytree. 
+ for lyr in range(depth): + block = jax.tree.map(lambda x, lyr=lyr: x[lyr], t["encoderblock"]) + t[f"encoderblock_{lyr}"] = block + + del t["encoderblock"] + return params_scan + + +def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=invalid-name because we had to CamelCase above. + """Load init from checkpoint, both old model and this one. +Hi-res posemb.""" + init_file = VANITY_NAMES.get(init_file, init_file) + restored_params = utils.load_params(init_file) + + restored_params = fix_old_checkpoints(restored_params) + + # Detect attempts to load non-scan checkpoint into scan model. + if (model_cfg.get("scan") and + "encoderblock" not in restored_params["Transformer"]): + restored_params = pyloop_to_scan(restored_params) + if (not model_cfg.get("scan") + and "encoderblock" in restored_params["Transformer"]): + restored_params = scan_to_pyloop(restored_params) + + # possibly use the random init for some of the params (such as, the head). + restored_params = common.merge_params(restored_params, init_params, dont_load) + + # resample posemb if needed. + # TODO: Take this from model_cfg to avoid need for init_params. 
+ if init_params and "pos_embedding" in init_params: + restored_params["pos_embedding"] = resample_posemb( + old=restored_params["pos_embedding"], + new=init_params["pos_embedding"]) + + return restored_params + + +# Shortcut names for some canonical paper checkpoints: +VANITY_NAMES = { + # pylint: disable=line-too-long + # Recommended models from https://arxiv.org/abs/2106.10270 + # Many more models at https://github.com/google-research/vision_transformer + "howto-i21k-Ti/16": "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz", + "howto-i21k-S/32": "gs://vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-S/16": "gs://vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz", + "howto-i21k-B/32": "gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-B/16": "gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-B/8": "gs://vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-L/16": "gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz", + + # Better plain vit-s16 baselines from https://arxiv.org/abs/2205.01580 + "i1k-s16-90ep": "gs://big_vision/vit_s16_i1k_90ep.npz", + "i1k-s16-150ep": "gs://big_vision/vit_s16_i1k_150ep.npz", + "i1k-s16-300ep": "gs://big_vision/vit_s16_i1k_300ep.npz", + + # DeiT-3 checkpoints from https://github.com/facebookresearch/deit/blob/main/README_revenge.md + # First layer converted to take inputs in [-1,1] + "deit3_S_224_1k": "gs://big_vision/zoo/deit3/bv_deit_3_small_224_1k.npz", + "deit3_S_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_small_224_21k.npz", + "deit3_S_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_small_384_1k.npz", + "deit3_S_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_small_384_21k.npz", + "deit3_B_224_1k": 
"gs://big_vision/zoo/deit3/bv_deit_3_base_224_1k.npz", + "deit3_B_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_base_224_21k.npz", + "deit3_B_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_base_384_1k.npz", + "deit3_B_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_base_384_21k.npz", + "deit3_L_224_1k": "gs://big_vision/zoo/deit3/bv_deit_3_large_224_1k.npz", + "deit3_L_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_224_21k.npz", + "deit3_L_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_1k.npz", + "deit3_L_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_21k.npz", + + # SigLIP image encoder checkpoints from https://arxiv.org/abs/2303.15343 + "SigLIP B/16 224": "gs://big_vision/siglip/webli_en_b16_224_63724782.npz:img", + "SigLIP B/16 256": "gs://big_vision/siglip/webli_en_b16_256_60500360.npz:img", + "SigLIP B/16 384": "gs://big_vision/siglip/webli_en_b16_384_68578854.npz:img", + "SigLIP B/16 512": "gs://big_vision/siglip/webli_en_b16_512_68580893.npz:img", + "SigLIP L/16 256": "gs://big_vision/siglip/webli_en_l16_256_60552751.npz:img", + "SigLIP L/16 384": "gs://big_vision/siglip/webli_en_l16_384_63634585.npz:img", + "SigLIP So400m/14 224": "gs://big_vision/siglip/webli_en_so400m_224_57633886.npz:img", + "SigLIP So400m/14 384": "gs://big_vision/siglip/webli_en_so400m_384_58765454.npz:img", + "SigLIP B/16-i18n 256": "gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz:img", + # pylint: enable=line-too-long +} diff --git a/big_vision/optax.py b/big_vision/optax.py new file mode 100644 index 0000000000000000000000000000000000000000..10b856067449c42e35931a1bdad6edaed2030ab3 --- /dev/null +++ b/big_vision/optax.py @@ -0,0 +1,222 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gradient transformations and other optax utilities.""" + +import operator +import big_vision.utils as u +import jax +import jax.numpy as jnp +import optax + + +def find_states(opt_state, cls): + leaves = jax.tree.leaves( + opt_state, is_leaf=lambda node: isinstance(node, cls)) + return [leaf for leaf in leaves if isinstance(leaf, cls)] + + +def get_count(opt_state, jittable=False): + """Returns `ScaleByScheduleState.count` from `opt_state` as an integer.""" + counts = [ + state.count + for state in find_states(opt_state, optax.ScaleByScheduleState) + ] + if jittable: + return counts[0] + else: + counts = {int(c) for c in counts} + assert len(counts) == 1, f"Expected exactly 1 ScaleByScheduleState:{counts}" + return next(iter(counts)) + + +def replace_frozen(schedule, pytree, replacement, log=None): + """Replaces values matching frozen params in `pytree` with `replacement`.""" + if not isinstance(schedule, (list, tuple)): + return pytree + masks, scheds = _make_mask_trees(pytree, schedule, log=log) + frozen_mask, _, _ = _split_frozen(masks, scheds) + return jax.tree.map( + lambda v, f: replacement if f else v, pytree, frozen_mask) + + +def clip_by_per_example_global_norm( + max_norm: float, +) -> optax.GradientTransformation: + """Clips the norm of per-example gradients.""" + + def init_fn(params): + del params + return optax.EmptyState() + + def update_fn(updates, state, params=None): + del params + grads_flat, grads_treedef = jax.tree_util.tree_flatten(updates) + clipped, _ = optax.per_example_global_norm_clip(grads_flat, max_norm) + return 
jax.tree_util.tree_unflatten(grads_treedef, clipped), state + + return optax.GradientTransformation(init_fn, update_fn) + + +def make(config, params, *, sched_kw): + """Returns gradient transform and learning rate functions.""" + + # Global schedule. No schedule means frozen. + schedule = config.get("schedule", {}) + if not isinstance(schedule, (tuple, list)): + schedule = [(".*", schedule)] + masks, scheds = _make_mask_trees(params, schedule, "config.schedule") + frozen_mask, masks, scheds = _split_frozen(masks, scheds) + not_frozen_mask = jax.tree.map(operator.not_, frozen_mask) + def create_schedule(mult=1.0, **kw): + assert "base" not in kw, kw + return u.create_learning_rate_schedule(base=mult, **kw) + schedule_fns = [create_schedule(**sched_kw, **sched) for sched in scheds] + schedule_txs = [ + optax.masked(optax.scale_by_schedule(schedule_fn), mask) + for schedule_fn, mask in zip(schedule_fns, masks) + ] + [ + # Removes weight decay updates. Note that weight decay already has an + # independent mask (which cannot be combined easily with a second mask), + # so instead we multiply updates for frozen params with zero. + optax.masked(optax.set_to_zero(), frozen_mask) + ] + + # Gradient clipping. + if clip_norm := config.get("grad_clip_norm"): + if config.get("grad_clip_per_example"): + clip_tx = clip_by_per_example_global_norm(clip_norm) + else: + clip_tx = optax.clip_by_global_norm(clip_norm) + grad_clip_norm_tx = optax.masked(clip_tx, not_frozen_mask) + else: + grad_clip_norm_tx = optax.identity() + + # Optimizer updates. + tx_func = operator.attrgetter(config.optax_name)(optax) + opt_txs = [optax.masked(tx_func(**config.get("optax", {})), not_frozen_mask)] + assert "optim" not in config, "Deprecated option, use config.optax." + + # Learning rate multipliers. Defaults to 1.0. 
+ lr_mult_txs = [optax.scale(config.lr)] + if config.get("lr_mults"): + masks, mults = _make_mask_trees(params, config.lr_mults, "config.lr_mults") + assert all(mult > 0 for mult in mults), ( + f"Use schedule=None for parameter freezing instead of lr_mults={mults}") + lr_mult_txs += [ + optax.masked(optax.scale(mult), mask) + for mult, mask in zip(mults, masks) + ] + + # Weight decay. Defaults to 0.0. + # Weight decay is not gradient-based but instead uses "params side-input". + # Hence, weight decay is additive and independent of previous gradient-based + # updates. + assert "weight_decay" not in config, "Deprecated option. Use wd and schedule." + assert config.get("weight_decay_decouple", True), ( + "Coupled weight decay not supported anymore.") + if config.get("wd"): + wd_mults = config.get("wd_mults", [(".*/kernel$", 1.0)]) + masks, mults = _make_mask_trees(params, wd_mults, "config.wd_mults") + weight_decay_txs = [ + optax.add_decayed_weights(config.wd * mult, mask) + for mult, mask in zip(mults, masks) + ] + else: + weight_decay_txs = [] + + # Combine gradient updates and learning rate schedules. + return optax.chain( + grad_clip_norm_tx, + *opt_txs, + *lr_mult_txs, + *weight_decay_txs, + *schedule_txs, + optax.scale(-1.0)), schedule_fns + + +def _make_mask_trees(params, patterns_values, log): + patterns, values = zip(*patterns_values) + masks = u.make_mask_trees(params, patterns, log=log) + return masks, values + + +def _split_frozen(masks, scheds): + """Computes `frozen_mask` and updates `masks` and `scheds`.""" + # Specifying `None` as a scheduler freezes params. 
+ all_false = jax.tree.map(lambda *bools: not any(bools), *masks) + not_covered = [k for k, v in u.tree_flatten_with_names(all_false)[0] if v] + assert not not_covered, ( + f"All params must be covered (use `None` for freezing): {not_covered}") + frozen_masks = [ + mask for mask, sched in zip(masks, scheds) if sched is None] + frozen_mask = jax.tree.map( + lambda *bools: any(bools), *frozen_masks, + all_false) # `all_false` is required when `frozen_masks==[]`. + masks, scheds = zip(*( + (mask, sched) for mask, sched in zip(masks, scheds) if sched is not None)) + return frozen_mask, masks, scheds + + +############ Custom BigVision optimizers ####################################### +# Currently there's only one custom optimizer and we don't foresee new ones in +# the near future, we opt not to create a new optimizer folder/module for just +# one isolated case. If there will be more optimizers, we can consider moving +# them into individual files in a subfolder. + + +# A dummy object to allow for foo.bar access syntax, see +# https://stackoverflow.com/a/19476841/2366315 +optax.big_vision = type("", (), {})() + + +def scale_by_adafactor(min_dim_size_to_factor=32, + decay_rate=0.8, decay_offset=0, + beta2_cap=0.999, + clipping_threshold=None, + momentum=0.9, dtype_momentum=jnp.bfloat16, + eps=1e-30): + """The BigVision variant of Adafactor optimizer.""" + + def _decay_rate_pow(i, exponent): + """Second-order moment decay schedule.""" + t = jnp.array(i, jnp.float32) + 1.0 + return jnp.minimum(beta2_cap, 1.0 - t**(-exponent)) + + scale_by_rms = optax.scale_by_factored_rms( + factored=True, + decay_rate=decay_rate, + step_offset=decay_offset, + min_dim_size_to_factor=min_dim_size_to_factor, + epsilon=eps, + decay_rate_fn=_decay_rate_pow) + + clip = (optax.clip_by_block_rms(clipping_threshold) if clipping_threshold + else optax.identity()) + + mom = (optax.ema(momentum, debias=False, accumulator_dtype=dtype_momentum) + if momentum else optax.identity()) + + return 
optax.chain(scale_by_rms, clip, mom) + +optax.big_vision.scale_by_adafactor = scale_by_adafactor # pytype: disable=module-attr + + +# A few more aliases we use frequently: +def momentum_hp(momentum=0.9, dtype=jnp.bfloat16, nesterov=False): + """SGD-Momentum with half-precision accumulator.""" + return optax.trace(decay=momentum, accumulator_dtype=dtype, nesterov=nesterov) + +optax.big_vision.momentum_hp = momentum_hp # pytype: disable=module-attr +optax.big_vision.sgd = optax.identity # pytype: disable=module-attr diff --git a/big_vision/optax_test.py b/big_vision/optax_test.py new file mode 100644 index 0000000000000000000000000000000000000000..86f7bd9999079b393565ad5a718b4c1dbd815e79 --- /dev/null +++ b/big_vision/optax_test.py @@ -0,0 +1,341 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for optax.""" + +from absl.testing import absltest +from absl.testing import parameterized +from big_vision import optax as bv_optax +import chex +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np +import optax + + +class OptaxTest(parameterized.TestCase): + + def test_get_count(self): + params = jax.tree.map(jnp.array, {"a": 1.}) + tx = optax.masked( + optax.scale_by_schedule(lambda step: step), + {"a": True}, + ) + opt_state = tx.init(params) + self.assertEqual(bv_optax.get_count(opt_state), 0) + _, opt_state = tx.update(params, opt_state) + self.assertEqual(bv_optax.get_count(opt_state), 1) + + def test_split_frozen(self): + params = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 1., "bias": 2.}, + }) # pyformat: disable + sched1 = dict(decay_type="cosine") + sched2 = dict(decay_type="linear") + schedule = [ + (".*/kernel", sched1), + (".*/bias", sched2), + ] + masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule") + frozen_mask, masks, scheds = bv_optax._split_frozen(masks, scheds) + chex.assert_trees_all_equal( + frozen_mask, + {"Dense_0": {"kernel": False, "bias": False}}, + ) # pyformat: disable + chex.assert_trees_all_equal( + masks, + ( + {"Dense_0": {"kernel": True, "bias": False}}, + {"Dense_0": {"kernel": False, "bias": True}}, + ), + ) # pyformat: disable + self.assertEqual(scheds, (sched1, sched2)) + # freeze some + schedule = [ + (".*/bias", None), + ("Dense_0/.*", sched1), + (".*", None), + ] + masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule") + frozen_mask, masks, scheds = bv_optax._split_frozen(masks, scheds) + chex.assert_trees_all_equal( + frozen_mask, + {"Dense_0": {"kernel": False, "bias": True}}, + ) # pyformat: disable + chex.assert_trees_all_equal( + masks, + ({"Dense_0": {"kernel": True, "bias": False}},), + ) # pyformat: disable + self.assertEqual(scheds, (sched1,)) + # does not cover all params - fails + schedule = [ + (".*/kernel", None), + ] + 
masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule") + with self.assertRaisesRegex(AssertionError, "All params must be covered"): + _ = bv_optax._split_frozen(masks, scheds) + + def test_replace_frozen(self): + params = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 1., "bias": 2.}, + }) # pyformat: disable + schedule = [ + (".*/kernel", {}), + (".*", None), + ] + chex.assert_trees_all_equal( + bv_optax.replace_frozen(schedule, params, 0.), + {"Dense_0": {"kernel": 1., "bias": 0.}}, + ) # pyformat: disable + + def test_make_simple(self): + params = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 1., "bias": 2.}, + }) # pyformat: disable + + config = ml_collections.ConfigDict() + config.lr = 0.01 + config.schedule = dict(decay_type="linear") + config.optax_name = "scale" + config.optax = ml_collections.ConfigDict() + g_scale = 0.5 + config.optax.step_size = g_scale + + total_steps = 10 + sched_kw = dict(global_batch_size=1, total_steps=total_steps) + tx, (schedule_fn,) = bv_optax.make(config, params, sched_kw=sched_kw) + opt_state = tx.init(params) + grads = jax.tree.map(jnp.ones_like, params) + for step in range(total_steps): + updates, opt_state = tx.update(grads, opt_state) + self.assertEqual(bv_optax.get_count(opt_state), step + 1) + sched = schedule_fn(step) + np.testing.assert_almost_equal( + sched, 1.0 / total_steps * (total_steps - step)) + make_tx = lambda sched: lambda g: -sched * config.lr * g_scale * g + chex.assert_trees_all_close(updates, jax.tree.map(make_tx(sched), grads)) + + def test_make_wd(self): + params = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 1., "bias": 2., "other": 3.}, + }) # pyformat: disable + wds = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 2e-3, "bias": 5e-4, "other": 0.}, + }) # pyformat: disable + + config = ml_collections.ConfigDict() + config.lr = 0.01 + config.wd = 1e-3 + config.wd_mults = [ + (".*/kernel", 2.0), + (".*/bias", 0.5), + ] + config.schedule = dict(decay_type="linear") + 
config.optax_name = "scale" + config.optax = ml_collections.ConfigDict() + g_scale = 0.5 + config.optax.step_size = g_scale + + total_steps = 10 + sched_kw = dict(global_batch_size=1, total_steps=total_steps) + tx, (sched_fn,) = bv_optax.make(config, params, sched_kw=sched_kw) + opt_state = tx.init(params) + grads = jax.tree.map(jnp.ones_like, params) + for step in range(total_steps): + updates, opt_state = tx.update(grads, opt_state, params) + self.assertEqual(bv_optax.get_count(opt_state), step + 1) + sched = sched_fn(step) + np.testing.assert_almost_equal( + sched, 1.0 / total_steps * (total_steps - step)) + + def make_tx(sched): + def inner(p, g, wd): + return -sched * (config.lr * g_scale * g + p * wd) + return inner + + chex.assert_trees_all_close( + updates, jax.tree.map(make_tx(sched), params, grads, wds)) + + def test_make_clip_norm(self): + params = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 1., "bias": 2., "other": 3.}, + }) # pyformat: disable + + config = ml_collections.ConfigDict() + config.lr = 0.01 + config.schedule = dict(decay_type="linear") + config.optax_name = "scale" + config.grad_clip_norm = 1.0 + config.optax = ml_collections.ConfigDict() + g_scale = 0.5 + config.optax.step_size = g_scale + + total_steps = 10 + sched_kw = dict(global_batch_size=1, total_steps=total_steps) + tx, (sched_fn,) = bv_optax.make(config, params, sched_kw=sched_kw) + opt_state = tx.init(params) + + grads = jax.tree.map(jnp.ones_like, params) + gflat = jax.tree.leaves(grads) + l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in gflat])) + grad_clip_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g) + grads_scaled = jax.tree.map(lambda p: grad_clip_factor * p, grads) + + for step in range(total_steps): + updates, opt_state = tx.update(grads, opt_state) + self.assertEqual(bv_optax.get_count(opt_state), step + 1) + sched = sched_fn(step) + np.testing.assert_almost_equal( + sched, 1.0 / total_steps * (total_steps - step)) + make_tx = lambda sched: lambda g: -sched * 
config.lr * g_scale * g + chex.assert_trees_all_close(updates, + jax.tree.map(make_tx(sched), grads_scaled)) + + def test_make_multi(self): + params = jax.tree.map( + jnp.array, { + "Dense_0": {"kernel": 1.0, "bias": 2.0, "other": 3.0}, + "Dense_1": {"kernel": 4.0, "bias": 5.0, "other": 6.0}, + "Dense_2": {"kernel": 7.0, "bias": 8.0, "other": 9.0}, + "Dense_3": {"kernel": 10., "bias": 11., "other": 12.}, + }) # pyformat: disable + + # Manually specify lr + wd for computing expected values. + lrb = 0.01 + lr1 = 2.0 + lr2 = 0.5 + lr_mults = { + "Dense_0": {"kernel": lr1, "bias": lr1, "other": lr1}, + "Dense_1": {"kernel": lr2, "bias": lr2, "other": lr2}, + "Dense_2": {"kernel": 1.0, "bias": 1.0, "other": 1.0}, + "Dense_3": {"kernel": 1.0, "bias": 1.0, "other": 1.0}, + } # pyformat: disable + wdb = 1e-3 + wd1 = 10.0 + wd2 = 0.1 + wds = jax.tree.map( + jnp.array, { + "Dense_0": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.}, + "Dense_1": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.}, + "Dense_2": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.}, + "Dense_3": {"kernel": 0.0 * wdb, "bias": 0.0 * wdb, "other": 0.}, + }) # pyformat: disable + + config = ml_collections.ConfigDict() + config.lr = lrb + config.lr_mults = [ + ("Dense_0/.*", lr1), + ("Dense_1/.*", lr2), + ] + config.wd = wdb + config.wd_mults = [ + (".*/kernel", wd1), + (".*/bias", wd2), + ] + mult1 = 1.0 + mult2 = 0.1 + config.schedule = [ + ("Dense_0/.*", dict(decay_type="linear", mult=mult1, linear_end=mult1)), + ("Dense_[12]/.*", dict(decay_type="linear", mult=mult2)), + (".*", None), + ] + config.optax_name = "scale" + config.grad_clip_norm = 1.0 + config.optax = ml_collections.ConfigDict() + g_scale = 0.5 + config.optax.step_size = g_scale + + total_steps = 10 + sched_kw = dict(global_batch_size=1, total_steps=total_steps) + tx, (sched_fn1, + sched_fn2) = bv_optax.make(config, params, sched_kw=sched_kw) + opt_state = tx.init(params) + + # Manually specify schedules for computing 
expected values. + frozen_fn = lambda _: jnp.array(0.) + sched_fns = { + "Dense_0": {"kernel": sched_fn1, "bias": sched_fn1, "other": sched_fn1}, + "Dense_1": {"kernel": sched_fn2, "bias": sched_fn2, "other": sched_fn2}, + "Dense_2": {"kernel": sched_fn2, "bias": sched_fn2, "other": sched_fn2}, + "Dense_3": {"kernel": frozen_fn, "bias": frozen_fn, "other": frozen_fn}, + } # pyformat: disable + + grads = jax.tree.map(jnp.ones_like, params) + gflat, _ = jax.tree.flatten( + # Don't count frozen params towards gradient norm. + jax.tree.map(lambda g, sched_fn: {frozen_fn: 0}.get(sched_fn, g), + grads, sched_fns)) + l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in gflat])) + grad_clip_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g) + grads_scaled = jax.tree.map(lambda p: grad_clip_factor * p, grads) + + def make_tx(step): + def get_update(p, g, wd, sched_fn, lr_mult): + return -sched_fn(step) * (lrb * lr_mult * g_scale * g + p * wd) + return get_update + + for step in range(total_steps): + updates, opt_state = tx.update(grads, opt_state, params) + self.assertEqual(bv_optax.get_count(opt_state), step + 1) + sched1, sched2 = sched_fn1(step), sched_fn2(step) + np.testing.assert_almost_equal(sched1, mult1) + np.testing.assert_almost_equal(sched2, + mult2 * (total_steps - step) / total_steps) + chex.assert_trees_all_close( + updates, + jax.tree.map( + make_tx(step), params, grads_scaled, wds, sched_fns, lr_mults)) + + def test_frozen_no_state(self): + params = {"small": jnp.zeros([1]), "large": jnp.zeros([1000])} + config = ml_collections.ConfigDict() + config.lr = 0.01 + config.schedule = [ + ("small", dict(decay_type="cosine")), + ("large", None), + ] + config.optax_name = "scale_by_adam" + + sched_kw = dict(global_batch_size=1, total_steps=1) + tx, _ = bv_optax.make(config, params, sched_kw=sched_kw) + + opt_state = tx.init(params) + adam_state = bv_optax.find_states(opt_state, optax.ScaleByAdamState) + nbytes = sum( + jax.tree.flatten(jax.tree.map(lambda x: x.nbytes, 
adam_state))[0]) + self.assertLess(nbytes, 1_000) + + def test_adafactor(self): + params = {"Dense_0": {"kernel": jnp.zeros([1024, 1024])}} + + config = ml_collections.ConfigDict() + config.optax_name = "big_vision.scale_by_adafactor" + config.lr = 0.01 + config.schedule = dict(decay_type="linear") + sched_kw = dict(global_batch_size=1, total_steps=1) + + tx, _ = bv_optax.make(config, params, sched_kw=sched_kw) + + opt_state = tx.init(params) + adafactor_state = bv_optax.find_states(opt_state, optax.FactoredState) + n_state_params = sum( + jax.tree.flatten( + jax.tree.map(lambda x: np.prod( + x.shape if hasattr(x, "shape") else 0), adafactor_state))[0]) + self.assertEqual(n_state_params, 2 * 1024 + 2) + + +if __name__ == "__main__": + absltest.main() diff --git a/big_vision/pp/__init__.py b/big_vision/pp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/pp/__pycache__/__init__.cpython-310.pyc b/big_vision/pp/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36115ddb007ca9f814dcb5b4972e3cc8d330b1e3 Binary files /dev/null and b/big_vision/pp/__pycache__/__init__.cpython-310.pyc differ diff --git a/big_vision/pp/__pycache__/registry.cpython-310.pyc b/big_vision/pp/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d45a2d1e167833beb088c9270f13f1dca5cfac7 Binary files /dev/null and b/big_vision/pp/__pycache__/registry.cpython-310.pyc differ diff --git a/big_vision/pp/archive/__init__.py b/big_vision/pp/archive/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/pp/archive/autoaugment.py b/big_vision/pp/archive/autoaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..f1af14ec6a1125ee9c3ea426f9224153483fd8e6 --- /dev/null +++ 
b/big_vision/pp/archive/autoaugment.py @@ -0,0 +1,700 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AutoAugment and RandAugment policies for enhanced image preprocessing. + +AutoAugment Reference: https://arxiv.org/abs/1805.09501 +RandAugment Reference: https://arxiv.org/abs/1909.13719 + +This code is forked from +https://github.com/tensorflow/tpu/blob/11d0db15cf1c3667f6e36fecffa111399e008acd/models/official/efficientnet/autoaugment.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import dataclasses +import inspect +import math +import tensorflow.compat.v1 as tf +from tensorflow_addons import image as contrib_image + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. + + +@dataclasses.dataclass +class HParams: + """Parameters for AutoAugment and RandAugment.""" + cutout_const: int + translate_const: int + + +def policy_v0(): + """Autoaugment policy that was used in AutoAugment Paper.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. 
+ policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + return policy + + +def policy_vtest(): + """Autoaugment test policy for debugging.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. + policy = [ + [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)], + ] + return policy + + +def blend(image1, image2, factor): + """Blend image1 and image2 using 'factor'. + Factor can be above 0.0. A value of 0.0 means only image1 is used. + A value of 1.0 means only image2 is used. A value between 0.0 and + 1.0 means we linearly interpolate the pixel values between the two + images. A value greater than 1.0 "extrapolates" the difference + between the two pixel values, and we clip the results to values + between 0 and 255. 
+ Args: + image1: An image Tensor of type uint8. + image2: An image Tensor of type uint8. + factor: A floating point value above 0.0. + Returns: + A blended image Tensor of type uint8. + """ + if factor == 0.0: + return tf.convert_to_tensor(image1) + if factor == 1.0: + return tf.convert_to_tensor(image2) + + image1 = tf.to_float(image1) + image2 = tf.to_float(image2) + + difference = image2 - image1 + scaled = factor * difference + + # Do addition in float. + temp = tf.to_float(image1) + scaled + + # Interpolate + if factor > 0.0 and factor < 1.0: + # Interpolation means we always stay within 0 and 255. + return tf.cast(temp, tf.uint8) + + # Extrapolate: + # + # We need to clip and then cast. + return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8) + + +def cutout(image, pad_size, replace=0): + """Apply cutout (https://arxiv.org/abs/1708.04552) to image. + This operation applies a (2*pad_size x 2*pad_size) mask of zeros to + a random location within `img`. The pixel values filled in will be of the + value `replace`. The located where the mask will be applied is randomly + chosen uniformly over the whole image. + Args: + image: An image Tensor of type uint8. + pad_size: Specifies how big the zero mask that will be generated is that + is applied to the image. The mask will be of size + (2*pad_size x 2*pad_size). + replace: What pixel value to fill in the image in the area that has + the cutout mask applied to it. + Returns: + An image Tensor that is of type uint8. + """ + image_height = tf.shape(image)[0] + image_width = tf.shape(image)[1] + + # Sample the center location in the image where the zero mask will be applied. 
+ cutout_center_height = tf.random_uniform( + shape=[], minval=0, maxval=image_height, + dtype=tf.int32) + + cutout_center_width = tf.random_uniform( + shape=[], minval=0, maxval=image_width, + dtype=tf.int32) + + lower_pad = tf.maximum(0, cutout_center_height - pad_size) + upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) + left_pad = tf.maximum(0, cutout_center_width - pad_size) + right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) + + cutout_shape = [image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad)] + padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] + mask = tf.pad( + tf.zeros(cutout_shape, dtype=image.dtype), + padding_dims, constant_values=1) + mask = tf.expand_dims(mask, -1) + mask = tf.tile(mask, [1, 1, 3]) + image = tf.where( + tf.equal(mask, 0), + tf.ones_like(image, dtype=image.dtype) * replace, + image) + return image + + +def solarize(image, threshold=128): + # For each pixel in the image, select the pixel + # if the value is less than the threshold. + # Otherwise, subtract 255 from the pixel. + return tf.where(image < threshold, image, 255 - image) + + +def solarize_add(image, addition=0, threshold=128): + # For each pixel in the image less than threshold + # we add 'addition' amount to it and then clip the + # pixel value to be between 0 and 255. The value + # of 'addition' is between -128 and 128. + added_image = tf.cast(image, tf.int64) + addition + added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8) + return tf.where(image < threshold, added_image, image) + + +def color(image, factor): + """Equivalent of PIL Color.""" + degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)) + return blend(degenerate, image, factor) + + +def contrast(image, factor): + """Equivalent of PIL Contrast.""" + degenerate = tf.image.rgb_to_grayscale(image) + # Cast before calling tf.histogram. 
+ degenerate = tf.cast(degenerate, tf.int32) + + # Compute the grayscale histogram, then compute the mean pixel value, + # and create a constant image size of that value. Use that as the + # blending degenerate target of the original image. + hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) + mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 + degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8)) + return blend(degenerate, image, factor) + + +def brightness(image, factor): + """Equivalent of PIL Brightness.""" + degenerate = tf.zeros_like(image) + return blend(degenerate, image, factor) + + +def posterize(image, bits): + """Equivalent of PIL Posterize.""" + shift = 8 - bits + return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift) + + +def rotate(image, degrees, replace): + """Rotates the image by degrees either clockwise or counterclockwise. + Args: + image: An image Tensor of type uint8. + degrees: Float, a scalar angle in degrees to rotate all images by. If + degrees is positive the image will be rotated clockwise otherwise it will + be rotated counterclockwise. + replace: A one or three value 1D tensor to fill empty pixels caused by + the rotate operation. + Returns: + The rotated version of image. + """ + # Convert from degrees to radians. + degrees_to_radians = math.pi / 180.0 + radians = degrees * degrees_to_radians + + # In practice, we should randomize the rotation degrees by flipping + # it negatively half the time, but that's done on 'degrees' outside + # of the function. 
+ image = contrib_image.rotate(wrap(image), radians) + return unwrap(image, replace) + + +def translate_x(image, pixels, replace): + """Equivalent of PIL Translate in X dimension.""" + image = contrib_image.translate(wrap(image), [-pixels, 0]) + return unwrap(image, replace) + + +def translate_y(image, pixels, replace): + """Equivalent of PIL Translate in Y dimension.""" + image = contrib_image.translate(wrap(image), [0, -pixels]) + return unwrap(image, replace) + + +def shear_x(image, level, replace): + """Equivalent of PIL Shearing in X dimension.""" + # Shear parallel to x axis is a projective transform + # with a matrix form of: + # [1 level + # 0 1]. + image = contrib_image.transform( + wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + return unwrap(image, replace) + + +def shear_y(image, level, replace): + """Equivalent of PIL Shearing in Y dimension.""" + # Shear parallel to y axis is a projective transform + # with a matrix form of: + # [1 0 + # level 1]. + image = contrib_image.transform( + wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + return unwrap(image, replace) + + +def autocontrast(image): + """Implements Autocontrast function from PIL using TF ops. + Args: + image: A 3D uint8 tensor. + Returns: + The image after it has had autocontrast applied to it and will be of type + uint8. + """ + + def scale_channel(image): + """Scale the 2D image using the autocontrast rule.""" + # A possibly cheaper version can be done using cumsum/unique_with_counts + # over the histogram values, rather than iterating over the entire image. + # to compute mins and maxes. + lo = tf.to_float(tf.reduce_min(image)) + hi = tf.to_float(tf.reduce_max(image)) + + # Scale the image, making the lowest value 0 and the highest value 255. 
+ def scale_values(im): + scale = 255.0 / (hi - lo) + offset = -lo * scale + im = tf.to_float(im) * scale + offset + im = tf.clip_by_value(im, 0.0, 255.0) + return tf.cast(im, tf.uint8) + + result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image) + return result + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. + s1 = scale_channel(image[:, :, 0]) + s2 = scale_channel(image[:, :, 1]) + s3 = scale_channel(image[:, :, 2]) + image = tf.stack([s1, s2, s3], 2) + return image + + +def sharpness(image, factor): + """Implements Sharpness function from PIL using TF ops.""" + orig_image = image + image = tf.cast(image, tf.float32) + # Make image 4D for conv operation. + image = tf.expand_dims(image, 0) + # SMOOTH PIL Kernel. + kernel = tf.constant( + [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, + shape=[3, 3, 1, 1]) / 13. + # Tile across channel dimension. + kernel = tf.tile(kernel, [1, 1, 3, 1]) + strides = [1, 1, 1, 1] + with tf.device('/cpu:0'): + # Some augmentation that uses depth-wise conv will cause crashing when + # training on GPU. See ((internal link)) for details. + degenerate = tf.nn.depthwise_conv2d( + image, kernel, strides, padding='VALID', rate=[1, 1]) + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) + + # For the borders of the resulting image, fill in the values of the + # original image. + mask = tf.ones_like(degenerate) + padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) + padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) + result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) + + # Blend the final result. 
+ return blend(result, orig_image, factor) + + +def equalize(image): + """Implements Equalize function from PIL using TF ops.""" + def scale_channel(im, c): + """Scale the data in the channel to implement equalize.""" + im = tf.cast(im[:, :, c], tf.int32) + # Compute the histogram of the image channel. + histo = tf.histogram_fixed_width(im, [0, 255], nbins=256) + + # For the purposes of computing the step, filter out the nonzeros. + nonzero = tf.where(tf.not_equal(histo, 0)) + nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1]) + step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255 + + def build_lut(histo, step): + # Compute the cumulative sum, shifting by step // 2 + # and then normalization by step. + lut = (tf.cumsum(histo) + (step // 2)) // step + # Shift lut, prepending with 0. + lut = tf.concat([[0], lut[:-1]], 0) + # Clip the counts to be in range. This is done + # in the C code for image.point. + return tf.clip_by_value(lut, 0, 255) + + # If step is zero, return the original image. Otherwise, build + # lut from the full histogram and step and then index from it. + result = tf.cond(tf.equal(step, 0), + lambda: im, + lambda: tf.gather(build_lut(histo, step), im)) + + return tf.cast(result, tf.uint8) + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. + s1 = scale_channel(image, 0) + s2 = scale_channel(image, 1) + s3 = scale_channel(image, 2) + image = tf.stack([s1, s2, s3], 2) + return image + + +def invert(image): + """Inverts the image pixels.""" + image = tf.convert_to_tensor(image) + return 255 - image + + +def wrap(image): + """Returns 'image' with an extra channel set to all 1s.""" + shape = tf.shape(image) + extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype) + extended = tf.concat([image, extended_channel], 2) + return extended + + +def unwrap(image, replace): + """Unwraps an image produced by wrap. 
+ Where there is a 0 in the last channel for every spatial position, + the rest of the three channels in that spatial dimension are grayed + (set to 128). Operations like translate and shear on a wrapped + Tensor will leave 0s in empty locations. Some transformations look + at the intensity of values to do preprocessing, and we want these + empty pixels to assume the 'average' value, rather than pure black. + Args: + image: A 3D Image Tensor with 4 channels. + replace: A one or three value 1D tensor to fill empty pixels. + Returns: + image: A 3D image Tensor with 3 channels. + """ + image_shape = tf.shape(image) + # Flatten the spatial dimensions. + flattened_image = tf.reshape(image, [-1, image_shape[2]]) + + # Find all pixels where the last channel is zero. + alpha_channel = flattened_image[:, 3] + + replace = tf.concat([replace, tf.ones([1], image.dtype)], 0) + + # Where they are zero, fill them in with 'replace'. + flattened_image = tf.where( + tf.equal(alpha_channel, 0), + tf.ones_like(flattened_image, dtype=image.dtype) * replace, + flattened_image) + + image = tf.reshape(flattened_image, image_shape) + image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3]) + return image + + +NAME_TO_FUNC = { + 'AutoContrast': autocontrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x, + 'TranslateY': translate_y, + 'Cutout': cutout, +} + + +def _randomly_negate_tensor(tensor): + """With 50% prob turn the tensor negative.""" + should_flip = tf.cast(tf.floor(tf.random_uniform([]) + 0.5), tf.bool) + final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor) + return final_tensor + + +def _rotate_level_to_arg(level): + level = (level/_MAX_LEVEL) * 30. 
+ level = _randomly_negate_tensor(level) + return (level,) + + +def _shrink_level_to_arg(level): + """Converts level to ratio by which we shrink the image content.""" + if level == 0: + return (1.0,) # if level is zero, do not shrink the image + # Maximum shrinking ratio is 2.9. + level = 2. / (_MAX_LEVEL / level) + 0.9 + return (level,) + + +def _enhance_level_to_arg(level): + return ((level/_MAX_LEVEL) * 1.8 + 0.1,) + + +def _shear_level_to_arg(level): + level = (level/_MAX_LEVEL) * 0.3 + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def _translate_level_to_arg(level, translate_const): + level = (level/_MAX_LEVEL) * float(translate_const) + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def level_to_arg(hparams): + return { + 'AutoContrast': lambda level: (), + 'Equalize': lambda level: (), + 'Invert': lambda level: (), + 'Rotate': _rotate_level_to_arg, + 'Posterize': lambda level: (int((level/_MAX_LEVEL) * 4),), + 'Solarize': lambda level: (int((level/_MAX_LEVEL) * 256),), + 'SolarizeAdd': lambda level: (int((level/_MAX_LEVEL) * 110),), + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'Cutout': lambda level: (int((level/_MAX_LEVEL) * hparams.cutout_const),), + 'TranslateX': lambda level: _translate_level_to_arg( + level, hparams.translate_const), + 'TranslateY': lambda level: _translate_level_to_arg( + level, hparams.translate_const), + # pylint:enable=g-long-lambda + } + + +def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams): + """Return the function that corresponds to `name` and update `level` param.""" + func = NAME_TO_FUNC[name] + args = level_to_arg(augmentation_hparams)[name](level) + + # Check to see if prob is passed into function. 
This is used for operations + # where we alter bboxes independently. + # pytype:disable=wrong-arg-types + if 'prob' in inspect.getfullargspec(func).args: + args = tuple([prob] + list(args)) + # pytype:enable=wrong-arg-types + + # Add in replace arg if it is required for the function that is being called. + # pytype:disable=wrong-arg-types + if 'replace' in inspect.getfullargspec(func).args: + # Make sure replace is the final argument + assert 'replace' == inspect.getfullargspec(func).args[-1] + args = tuple(list(args) + [replace_value]) + # pytype:enable=wrong-arg-types + + return (func, prob, args) + + +def _apply_func_with_prob(func, image, args, prob): + """Apply `func` to image w/ `args` as input with probability `prob`.""" + assert isinstance(args, tuple) + + # If prob is a function argument, then this randomness is being handled + # inside the function, so make sure it is always called. + # pytype:disable=wrong-arg-types + if 'prob' in inspect.getfullargspec(func).args: + prob = 1.0 + # pytype:enable=wrong-arg-types + + # Apply the function with probability `prob`. + should_apply_op = tf.cast( + tf.floor(tf.random_uniform([], dtype=tf.float32) + prob), tf.bool) + augmented_image = tf.cond( + should_apply_op, + lambda: func(image, *args), + lambda: image) + return augmented_image + + +def select_and_apply_random_policy(policies, image): + """Select a random policy from `policies` and apply it to `image`.""" + policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32) + # Note that using tf.case instead of tf.conds would result in significantly + # larger graphs and would even break export for some larger policies. 
+ for (i, policy) in enumerate(policies): + image = tf.cond( + tf.equal(i, policy_to_select), + lambda selected_policy=policy: selected_policy(image), + lambda: image) + return image + + +def build_and_apply_nas_policy(policies, image, + augmentation_hparams): + """Build a policy from the given policies passed in and apply to image. + Args: + policies: list of lists of tuples in the form `(func, prob, level)`, `func` + is a string name of the augmentation function, `prob` is the probability + of applying the `func` operation, `level` is the input argument for + `func`. + image: tf.Tensor that the resulting policy will be applied to. + augmentation_hparams: Hparams associated with the NAS learned policy. + Returns: + A version of image that now has data augmentation applied to it based on + the `policies` pass into the function. + """ + replace_value = [128, 128, 128] + + # func is the string name of the augmentation function, prob is the + # probability of applying the operation and level is the parameter associated + # with the tf op. + + # tf_policies are functions that take in an image and return an augmented + # image. + tf_policies = [] + for policy in policies: + tf_policy = [] + # Link string name to the correct python function and make sure the correct + # argument is passed into that function. + for policy_info in policy: + policy_info = list(policy_info) + [replace_value, augmentation_hparams] + + tf_policy.append(_parse_policy_info(*policy_info)) + # Now build the tf policy that will apply the augmentation procedue + # on image. 
+ def make_final_policy(tf_policy_): + def final_policy(image_): + for func, prob, args in tf_policy_: + image_ = _apply_func_with_prob( + func, image_, args, prob) + return image_ + return final_policy + tf_policies.append(make_final_policy(tf_policy)) + + augmented_image = select_and_apply_random_policy( + tf_policies, image) + return augmented_image + + +def distort_image_with_autoaugment(image, augmentation_name): + """Applies the AutoAugment policy to `image`. + AutoAugment is from the paper: https://arxiv.org/abs/1805.09501. + Args: + image: `Tensor` of shape [height, width, 3] representing an image. + augmentation_name: The name of the AutoAugment policy to use. The available + options are `v0` and `test`. `v0` is the policy used for + all of the results in the paper and was found to achieve the best results + on the COCO dataset. `v1`, `v2` and `v3` are additional good policies + found on the COCO dataset that have slight variation in what operations + were used during the search procedure along with how many operations are + applied in parallel to a single image (2 vs 3). + Returns: + A tuple containing the augmented versions of `image`. + """ + available_policies = {'v0': policy_v0, + 'test': policy_vtest} + if augmentation_name not in available_policies: + raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name)) + + policy = available_policies[augmentation_name]() + # Hparams that will be used for AutoAugment. + augmentation_hparams = HParams( + cutout_const=100, translate_const=250) + + return build_and_apply_nas_policy(policy, image, augmentation_hparams) + + +def distort_image_with_randaugment(image, num_layers, magnitude): + """Applies the RandAugment policy to `image`. + RandAugment is from the paper https://arxiv.org/abs/1909.13719, + Args: + image: `Tensor` of shape [height, width, 3] representing an image. + num_layers: Integer, the number of augmentation transformations to apply + sequentially to an image. 
Represented as (N) in the paper. Usually best + values will be in the range [1, 3]. + magnitude: Integer, shared magnitude across all augmentation operations. + Represented as (M) in the paper. Usually best values are in the range + [5, 30]. + Returns: + The augmented version of `image`. + """ + replace_value = [128] * 3 + tf.logging.info('Using RandAug.') + augmentation_hparams = HParams( + cutout_const=40, translate_const=100) + available_ops = [ + 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', + 'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness', + 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'] + + for layer_num in range(num_layers): + op_to_select = tf.random_uniform( + [], maxval=len(available_ops), dtype=tf.int32) + random_magnitude = float(magnitude) + with tf.name_scope('randaug_layer_{}'.format(layer_num)): + for (i, op_name) in enumerate(available_ops): + prob = tf.random_uniform([], minval=0.2, maxval=0.8, dtype=tf.float32) + func, _, args = _parse_policy_info(op_name, prob, random_magnitude, + replace_value, augmentation_hparams) + image = tf.cond( + tf.equal(i, op_to_select), + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args), + # pylint:enable=g-long-lambda + lambda: image) + return image diff --git a/big_vision/pp/archive/randaug.py b/big_vision/pp/archive/randaug.py new file mode 100644 index 0000000000000000000000000000000000000000..e8acec830237ce8840a75e44788d1a425ec17289 --- /dev/null +++ b/big_vision/pp/archive/randaug.py @@ -0,0 +1,46 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RandAug depends on deprecated tfa.image package, now defunct.""" + +from big_vision.pp import registry +from big_vision.pp import utils +from big_vision.pp.archive import autoaugment + + +@registry.Registry.register("preprocess_ops.randaug") +@utils.InKeyOutKey() +def get_randaug(num_layers: int = 2, magnitude: int = 10): + """Creates a function that applies RandAugment. + + RandAugment is from the paper https://arxiv.org/abs/1909.13719, + + Args: + num_layers: Integer, the number of augmentation transformations to apply + sequentially to an image. Represented as (N) in the paper. Usually best + values will be in the range [1, 3]. + magnitude: Integer, shared magnitude across all augmentation operations. + Represented as (M) in the paper. Usually best values are in the range [5, + 30]. + + Returns: + a function that applies RandAugment. + """ + + def _randaug(image): + return autoaugment.distort_image_with_randaugment( + image, num_layers, magnitude + ) + + return _randaug diff --git a/big_vision/pp/autoaugment.py b/big_vision/pp/autoaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..6cc45f14e5d8c49cb54c649104851e0729ebb180 --- /dev/null +++ b/big_vision/pp/autoaugment.py @@ -0,0 +1,700 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AutoAugment and RandAugment policies for enhanced image preprocessing. + +AutoAugment Reference: https://arxiv.org/abs/1805.09501 +RandAugment Reference: https://arxiv.org/abs/1909.13719 + +This code is forked from +https://github.com/tensorflow/tpu/blob/11d0db15cf1c3667f6e36fecffa111399e008acd/models/official/efficientnet/autoaugment.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import dataclasses +import inspect +import math +import tensorflow.compat.v1 as tf +from tensorflow_addons import image as contrib_image + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. + + +@dataclasses.dataclass +class HParams: + """Parameters for AutoAugment and RandAugment.""" + cutout_const: int + translate_const: int + + +def policy_v0(): + """Autoaugment policy that was used in AutoAugment Paper.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. 
+ policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + return policy + + +def policy_vtest(): + """Autoaugment test policy for debugging.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. + policy = [ + [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)], + ] + return policy + + +def blend(image1, image2, factor): + """Blend image1 and image2 using 'factor'. + Factor can be above 0.0. A value of 0.0 means only image1 is used. + A value of 1.0 means only image2 is used. A value between 0.0 and + 1.0 means we linearly interpolate the pixel values between the two + images. A value greater than 1.0 "extrapolates" the difference + between the two pixel values, and we clip the results to values + between 0 and 255. 
+ Args: + image1: An image Tensor of type uint8. + image2: An image Tensor of type uint8. + factor: A floating point value above 0.0. + Returns: + A blended image Tensor of type uint8. + """ + if factor == 0.0: + return tf.convert_to_tensor(image1) + if factor == 1.0: + return tf.convert_to_tensor(image2) + + image1 = tf.to_float(image1) + image2 = tf.to_float(image2) + + difference = image2 - image1 + scaled = factor * difference + + # Do addition in float. + temp = tf.to_float(image1) + scaled + + # Interpolate + if factor > 0.0 and factor < 1.0: + # Interpolation means we always stay within 0 and 255. + return tf.cast(temp, tf.uint8) + + # Extrapolate: + # + # We need to clip and then cast. + return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8) + + +def cutout(image, pad_size, replace=0): + """Apply cutout (https://arxiv.org/abs/1708.04552) to image. + This operation applies a (2*pad_size x 2*pad_size) mask of zeros to + a random location within `img`. The pixel values filled in will be of the + value `replace`. The located where the mask will be applied is randomly + chosen uniformly over the whole image. + Args: + image: An image Tensor of type uint8. + pad_size: Specifies how big the zero mask that will be generated is that + is applied to the image. The mask will be of size + (2*pad_size x 2*pad_size). + replace: What pixel value to fill in the image in the area that has + the cutout mask applied to it. + Returns: + An image Tensor that is of type uint8. + """ + image_height = tf.shape(image)[0] + image_width = tf.shape(image)[1] + + # Sample the center location in the image where the zero mask will be applied. 
+ cutout_center_height = tf.random_uniform( + shape=[], minval=0, maxval=image_height, + dtype=tf.int32) + + cutout_center_width = tf.random_uniform( + shape=[], minval=0, maxval=image_width, + dtype=tf.int32) + + lower_pad = tf.maximum(0, cutout_center_height - pad_size) + upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) + left_pad = tf.maximum(0, cutout_center_width - pad_size) + right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) + + cutout_shape = [image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad)] + padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] + mask = tf.pad( + tf.zeros(cutout_shape, dtype=image.dtype), + padding_dims, constant_values=1) + mask = tf.expand_dims(mask, -1) + mask = tf.tile(mask, [1, 1, 3]) + image = tf.where( + tf.equal(mask, 0), + tf.ones_like(image, dtype=image.dtype) * replace, + image) + return image + + +def solarize(image, threshold=128): + # For each pixel in the image, select the pixel + # if the value is less than the threshold. + # Otherwise, subtract 255 from the pixel. + return tf.where(image < threshold, image, 255 - image) + + +def solarize_add(image, addition=0, threshold=128): + # For each pixel in the image less than threshold + # we add 'addition' amount to it and then clip the + # pixel value to be between 0 and 255. The value + # of 'addition' is between -128 and 128. + added_image = tf.cast(image, tf.int64) + addition + added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8) + return tf.where(image < threshold, added_image, image) + + +def color(image, factor): + """Equivalent of PIL Color.""" + degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)) + return blend(degenerate, image, factor) + + +def contrast(image, factor): + """Equivalent of PIL Contrast.""" + degenerate = tf.image.rgb_to_grayscale(image) + # Cast before calling tf.histogram. 
+ degenerate = tf.cast(degenerate, tf.int32) + + # Compute the grayscale histogram, then compute the mean pixel value, + # and create a constant image size of that value. Use that as the + # blending degenerate target of the original image. + hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) + mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 + degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8)) + return blend(degenerate, image, factor) + + +def brightness(image, factor): + """Equivalent of PIL Brightness.""" + degenerate = tf.zeros_like(image) + return blend(degenerate, image, factor) + + +def posterize(image, bits): + """Equivalent of PIL Posterize.""" + shift = 8 - bits + return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift) + + +def rotate(image, degrees, replace): + """Rotates the image by degrees either clockwise or counterclockwise. + Args: + image: An image Tensor of type uint8. + degrees: Float, a scalar angle in degrees to rotate all images by. If + degrees is positive the image will be rotated clockwise otherwise it will + be rotated counterclockwise. + replace: A one or three value 1D tensor to fill empty pixels caused by + the rotate operation. + Returns: + The rotated version of image. + """ + # Convert from degrees to radians. + degrees_to_radians = math.pi / 180.0 + radians = degrees * degrees_to_radians + + # In practice, we should randomize the rotation degrees by flipping + # it negatively half the time, but that's done on 'degrees' outside + # of the function. 
+ image = contrib_image.rotate(wrap(image), radians) + return unwrap(image, replace) + + +def translate_x(image, pixels, replace): + """Equivalent of PIL Translate in X dimension.""" + image = contrib_image.translate(wrap(image), [-pixels, 0]) + return unwrap(image, replace) + + +def translate_y(image, pixels, replace): + """Equivalent of PIL Translate in Y dimension.""" + image = contrib_image.translate(wrap(image), [0, -pixels]) + return unwrap(image, replace) + + +def shear_x(image, level, replace): + """Equivalent of PIL Shearing in X dimension.""" + # Shear parallel to x axis is a projective transform + # with a matrix form of: + # [1 level + # 0 1]. + image = contrib_image.transform( + wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + return unwrap(image, replace) + + +def shear_y(image, level, replace): + """Equivalent of PIL Shearing in Y dimension.""" + # Shear parallel to y axis is a projective transform + # with a matrix form of: + # [1 0 + # level 1]. + image = contrib_image.transform( + wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + return unwrap(image, replace) + + +def autocontrast(image): + """Implements Autocontrast function from PIL using TF ops. + Args: + image: A 3D uint8 tensor. + Returns: + The image after it has had autocontrast applied to it and will be of type + uint8. + """ + + def scale_channel(image): + """Scale the 2D image using the autocontrast rule.""" + # A possibly cheaper version can be done using cumsum/unique_with_counts + # over the histogram values, rather than iterating over the entire image. + # to compute mins and maxes. + lo = tf.to_float(tf.reduce_min(image)) + hi = tf.to_float(tf.reduce_max(image)) + + # Scale the image, making the lowest value 0 and the highest value 255. 
+ def scale_values(im): + scale = 255.0 / (hi - lo) + offset = -lo * scale + im = tf.to_float(im) * scale + offset + im = tf.clip_by_value(im, 0.0, 255.0) + return tf.cast(im, tf.uint8) + + result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image) + return result + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. + s1 = scale_channel(image[:, :, 0]) + s2 = scale_channel(image[:, :, 1]) + s3 = scale_channel(image[:, :, 2]) + image = tf.stack([s1, s2, s3], 2) + return image + + +def sharpness(image, factor): + """Implements Sharpness function from PIL using TF ops.""" + orig_image = image + image = tf.cast(image, tf.float32) + # Make image 4D for conv operation. + image = tf.expand_dims(image, 0) + # SMOOTH PIL Kernel. + kernel = tf.constant( + [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, + shape=[3, 3, 1, 1]) / 13. + # Tile across channel dimension. + kernel = tf.tile(kernel, [1, 1, 3, 1]) + strides = [1, 1, 1, 1] + with tf.device('/cpu:0'): + # Some augmentation that uses depth-wise conv will cause crashing when + # training on GPU. See ((internal link)) for details. + degenerate = tf.nn.depthwise_conv2d( + image, kernel, strides, padding='VALID', rate=[1, 1]) + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) + + # For the borders of the resulting image, fill in the values of the + # original image. + mask = tf.ones_like(degenerate) + padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) + padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) + result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) + + # Blend the final result. 
+ return blend(result, orig_image, factor) + + +def equalize(image): + """Implements Equalize function from PIL using TF ops.""" + def scale_channel(im, c): + """Scale the data in the channel to implement equalize.""" + im = tf.cast(im[:, :, c], tf.int32) + # Compute the histogram of the image channel. + histo = tf.histogram_fixed_width(im, [0, 255], nbins=256) + + # For the purposes of computing the step, filter out the nonzeros. + nonzero = tf.where(tf.not_equal(histo, 0)) + nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1]) + step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255 + + def build_lut(histo, step): + # Compute the cumulative sum, shifting by step // 2 + # and then normalization by step. + lut = (tf.cumsum(histo) + (step // 2)) // step + # Shift lut, prepending with 0. + lut = tf.concat([[0], lut[:-1]], 0) + # Clip the counts to be in range. This is done + # in the C code for image.point. + return tf.clip_by_value(lut, 0, 255) + + # If step is zero, return the original image. Otherwise, build + # lut from the full histogram and step and then index from it. + result = tf.cond(tf.equal(step, 0), + lambda: im, + lambda: tf.gather(build_lut(histo, step), im)) + + return tf.cast(result, tf.uint8) + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. + s1 = scale_channel(image, 0) + s2 = scale_channel(image, 1) + s3 = scale_channel(image, 2) + image = tf.stack([s1, s2, s3], 2) + return image + + +def invert(image): + """Inverts the image pixels.""" + image = tf.convert_to_tensor(image) + return 255 - image + + +def wrap(image): + """Returns 'image' with an extra channel set to all 1s.""" + shape = tf.shape(image) + extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype) + extended = tf.concat([image, extended_channel], 2) + return extended + + +def unwrap(image, replace): + """Unwraps an image produced by wrap. 
+ Where there is a 0 in the last channel for every spatial position, + the rest of the three channels in that spatial dimension are grayed + (set to 128). Operations like translate and shear on a wrapped + Tensor will leave 0s in empty locations. Some transformations look + at the intensity of values to do preprocessing, and we want these + empty pixels to assume the 'average' value, rather than pure black. + Args: + image: A 3D Image Tensor with 4 channels. + replace: A one or three value 1D tensor to fill empty pixels. + Returns: + image: A 3D image Tensor with 3 channels. + """ + image_shape = tf.shape(image) + # Flatten the spatial dimensions. + flattened_image = tf.reshape(image, [-1, image_shape[2]]) + + # Find all pixels where the last channel is zero. + alpha_channel = flattened_image[:, 3] + + replace = tf.concat([replace, tf.ones([1], image.dtype)], 0) + + # Where they are zero, fill them in with 'replace'. + flattened_image = tf.where( + tf.equal(alpha_channel, 0), + tf.ones_like(flattened_image, dtype=image.dtype) * replace, + flattened_image) + + image = tf.reshape(flattened_image, image_shape) + image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3]) + return image + + +NAME_TO_FUNC = { + 'AutoContrast': autocontrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x, + 'TranslateY': translate_y, + 'Cutout': cutout, +} + + +def _randomly_negate_tensor(tensor): + """With 50% prob turn the tensor negative.""" + should_flip = tf.cast(tf.floor(tf.random_uniform([]) + 0.5), tf.bool) + final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor) + return final_tensor + + +def _rotate_level_to_arg(level): + level = (level/_MAX_LEVEL) * 30. 
+ level = _randomly_negate_tensor(level) + return (level,) + + +def _shrink_level_to_arg(level): + """Converts level to ratio by which we shrink the image content.""" + if level == 0: + return (1.0,) # if level is zero, do not shrink the image + # Maximum shrinking ratio is 2.9. + level = 2. / (_MAX_LEVEL / level) + 0.9 + return (level,) + + +def _enhance_level_to_arg(level): + return ((level/_MAX_LEVEL) * 1.8 + 0.1,) + + +def _shear_level_to_arg(level): + level = (level/_MAX_LEVEL) * 0.3 + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def _translate_level_to_arg(level, translate_const): + level = (level/_MAX_LEVEL) * float(translate_const) + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def level_to_arg(hparams): + return { + 'AutoContrast': lambda level: (), + 'Equalize': lambda level: (), + 'Invert': lambda level: (), + 'Rotate': _rotate_level_to_arg, + 'Posterize': lambda level: (int((level/_MAX_LEVEL) * 4),), + 'Solarize': lambda level: (int((level/_MAX_LEVEL) * 256),), + 'SolarizeAdd': lambda level: (int((level/_MAX_LEVEL) * 110),), + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'Cutout': lambda level: (int((level/_MAX_LEVEL) * hparams.cutout_const),), + 'TranslateX': lambda level: _translate_level_to_arg( + level, hparams.translate_const), + 'TranslateY': lambda level: _translate_level_to_arg( + level, hparams.translate_const), + # pylint:enable=g-long-lambda + } + + +def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams): + """Return the function that corresponds to `name` and update `level` param.""" + func = NAME_TO_FUNC[name] + args = level_to_arg(augmentation_hparams)[name](level) + + # Check to see if prob is passed into function. 
This is used for operations + # where we alter bboxes independently. + # pytype:disable=wrong-arg-types + if 'prob' in inspect.getfullargspec(func).args: + args = tuple([prob] + list(args)) + # pytype:enable=wrong-arg-types + + # Add in replace arg if it is required for the function that is being called. + # pytype:disable=wrong-arg-types + if 'replace' in inspect.getfullargspec(func).args: + # Make sure replace is the final argument + assert 'replace' == inspect.getfullargspec(func).args[-1] + args = tuple(list(args) + [replace_value]) + # pytype:enable=wrong-arg-types + + return (func, prob, args) + + +def _apply_func_with_prob(func, image, args, prob): + """Apply `func` to image w/ `args` as input with probability `prob`.""" + assert isinstance(args, tuple) + + # If prob is a function argument, then this randomness is being handled + # inside the function, so make sure it is always called. + # pytype:disable=wrong-arg-types + if 'prob' in inspect.getfullargspec(func).args: + prob = 1.0 + # pytype:enable=wrong-arg-types + + # Apply the function with probability `prob`. + should_apply_op = tf.cast( + tf.floor(tf.random_uniform([], dtype=tf.float32) + prob), tf.bool) + augmented_image = tf.cond( + should_apply_op, + lambda: func(image, *args), + lambda: image) + return augmented_image + + +def select_and_apply_random_policy(policies, image): + """Select a random policy from `policies` and apply it to `image`.""" + policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32) + # Note that using tf.case instead of tf.conds would result in significantly + # larger graphs and would even break export for some larger policies. 
+ for (i, policy) in enumerate(policies): + image = tf.cond( + tf.equal(i, policy_to_select), + lambda selected_policy=policy: selected_policy(image), + lambda: image) + return image + + +def build_and_apply_nas_policy(policies, image, + augmentation_hparams): + """Build a policy from the given policies passed in and apply to image. + Args: + policies: list of lists of tuples in the form `(func, prob, level)`, `func` + is a string name of the augmentation function, `prob` is the probability + of applying the `func` operation, `level` is the input argument for + `func`. + image: tf.Tensor that the resulting policy will be applied to. + augmentation_hparams: Hparams associated with the NAS learned policy. + Returns: + A version of image that now has data augmentation applied to it based on + the `policies` pass into the function. + """ + replace_value = [128, 128, 128] + + # func is the string name of the augmentation function, prob is the + # probability of applying the operation and level is the parameter associated + # with the tf op. + + # tf_policies are functions that take in an image and return an augmented + # image. + tf_policies = [] + for policy in policies: + tf_policy = [] + # Link string name to the correct python function and make sure the correct + # argument is passed into that function. + for policy_info in policy: + policy_info = list(policy_info) + [replace_value, augmentation_hparams] + + tf_policy.append(_parse_policy_info(*policy_info)) + # Now build the tf policy that will apply the augmentation procedue + # on image. 
+ def make_final_policy(tf_policy_): + def final_policy(image_): + for func, prob, args in tf_policy_: + image_ = _apply_func_with_prob( + func, image_, args, prob) + return image_ + return final_policy + tf_policies.append(make_final_policy(tf_policy)) + + augmented_image = select_and_apply_random_policy( + tf_policies, image) + return augmented_image + + +def distort_image_with_autoaugment(image, augmentation_name): + """Applies the AutoAugment policy to `image`. + AutoAugment is from the paper: https://arxiv.org/abs/1805.09501. + Args: + image: `Tensor` of shape [height, width, 3] representing an image. + augmentation_name: The name of the AutoAugment policy to use. The available + options are `v0` and `test`. `v0` is the policy used for + all of the results in the paper and was found to achieve the best results + on the COCO dataset. `v1`, `v2` and `v3` are additional good policies + found on the COCO dataset that have slight variation in what operations + were used during the search procedure along with how many operations are + applied in parallel to a single image (2 vs 3). + Returns: + A tuple containing the augmented versions of `image`. + """ + available_policies = {'v0': policy_v0, + 'test': policy_vtest} + if augmentation_name not in available_policies: + raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name)) + + policy = available_policies[augmentation_name]() + # Hparams that will be used for AutoAugment. + augmentation_hparams = HParams( + cutout_const=100, translate_const=250) + + return build_and_apply_nas_policy(policy, image, augmentation_hparams) + + +def distort_image_with_randaugment(image, num_layers, magnitude): + """Applies the RandAugment policy to `image`. + RandAugment is from the paper https://arxiv.org/abs/1909.13719, + Args: + image: `Tensor` of shape [height, width, 3] representing an image. + num_layers: Integer, the number of augmentation transformations to apply + sequentially to an image. 
Represented as (N) in the paper. Usually best + values will be in the range [1, 3]. + magnitude: Integer, shared magnitude across all augmentation operations. + Represented as (M) in the paper. Usually best values are in the range + [5, 30]. + Returns: + The augmented version of `image`. + """ + replace_value = [128] * 3 + tf.logging.info('Using RandAug.') + augmentation_hparams = HParams( + cutout_const=40, translate_const=100) + available_ops = [ + 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', + 'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness', + 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'] + + for layer_num in range(num_layers): + op_to_select = tf.random_uniform( + [], maxval=len(available_ops), dtype=tf.int32) + random_magnitude = float(magnitude) + with tf.name_scope('randaug_layer_{}'.format(layer_num)): + for (i, op_name) in enumerate(available_ops): + prob = tf.random_uniform([], minval=0.2, maxval=0.8, dtype=tf.float32) + func, _, args = _parse_policy_info(op_name, prob, random_magnitude, + replace_value, augmentation_hparams) + image = tf.cond( + tf.equal(i, op_to_select), + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args), + # pylint:enable=g-long-lambda + lambda: image) + return image diff --git a/big_vision/pp/builder.py b/big_vision/pp/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..c612ddbbdd5fdbf831984ed9df39a9a7cfbe5d55 --- /dev/null +++ b/big_vision/pp/builder.py @@ -0,0 +1,81 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Preprocessing builder.""" + +from absl import logging +from big_vision.pp import registry +import tensorflow as tf + + +def get_preprocess_fn(pp_pipeline, log_data=True): + """Transform an input string into the preprocessing function. + + The minilanguage is as follows: + + fn1|fn2(arg, arg2,...)|... + + And describes the successive application of the various `fn`s to the input, + where each function can optionally have one or more arguments, which are + either positional or key/value, as dictated by the `fn`. + + The output preprocessing function expects a dictionary as input. This + dictionary should have a key "image" that corresponds to a 3D tensor + (height x width x channel). + + Args: + pp_pipeline: A string describing the pre-processing pipeline. If empty or + None, no preprocessing will be executed. + log_data: Whether to log the data before and after preprocessing. Can also + be a string to show in the log for debugging, for example dataset name. + + Returns: + preprocessing function. + + Raises: + ValueError: if preprocessing function name is unknown + """ + + names, ops = [], [] + if pp_pipeline: + for op_spec in pp_pipeline.split("|"): + if not op_spec: continue # Skip empty section instead of error. 
+ try: + ops.append(registry.Registry.lookup(f"preprocess_ops.{op_spec}")()) + names.append(registry.parse_name(op_spec)[0]) + except SyntaxError as err: + raise ValueError(f"Syntax error on: {op_spec}") from err + + def _preprocess_fn(data): + """The preprocessing function that is returned.""" + nonlocal log_data + + # Apply all the individual steps in sequence. + if log_data: + logging.info("Data before pre-processing (%s):\n%s", log_data, data) + for name, op in zip(names, ops): + with tf.name_scope(name): + data = op(data) + + # Validate input + if not isinstance(data, dict): + raise ValueError("Argument `data` must be a dictionary, " + "not %s" % str(type(data))) + + if log_data: + logging.info("Data after pre-processing (%s):\n%s", log_data, data) + log_data = False # For eager&pygrain: only log first one of each pipeline. + return data + + return _preprocess_fn diff --git a/big_vision/pp/builder_test.py b/big_vision/pp/builder_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a75cc05417a7f3ace70f97583d2e7b1ae4c432 --- /dev/null +++ b/big_vision/pp/builder_test.py @@ -0,0 +1,72 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class BuilderTest(tf.test.TestCase):
  """Tests for `builder.get_preprocess_fn`."""

  def testSingle(self):
    """A single `resize` op produces an image of the requested size."""
    pp_fn = builder.get_preprocess_fn("resize(256)")
    x = np.random.randint(0, 256, [640, 480, 3])
    image = pp_fn({"image": x})["image"]
    self.assertEqual(image.numpy().shape, (256, 256, 3))

  def testEmpty(self):
    """Empty sections between `|` separators are skipped."""
    pp_fn = builder.get_preprocess_fn("||inception_crop|||resize(256)||")

    # Typical image input
    x = np.random.randint(0, 256, [640, 480, 3])
    image = pp_fn({"image": x})["image"]
    self.assertEqual(image.numpy().shape, (256, 256, 3))

  def testPreprocessingPipeline(self):
    """A longer pipeline yields the final crop size and value range."""
    pp_str = ("inception_crop|resize(256)|resize((256, 256))|"
              "central_crop((80, 120))|flip_lr|value_range(0,1)|"
              "value_range(-1,1)")
    pp_fn = builder.get_preprocess_fn(pp_str)

    # Typical image input
    x = np.random.randint(0, 256, [640, 480, 3])
    image = pp_fn({"image": x})["image"]
    self.assertEqual(image.numpy().shape, (80, 120, 3))
    self.assertLessEqual(np.max(image.numpy()), 1)
    self.assertGreaterEqual(np.min(image.numpy()), -1)

  def testNumArgsException(self):
    """Each listed spec must make building or calling the function fail."""

    x = np.random.randint(0, 256, [640, 480, 3])
    # NOTE(review): `x` is passed as a bare array rather than `{"image": x}`,
    # and any `BaseException` is accepted, so this presumably passes even for
    # specs with a valid argument count - confirm this is intended.
    for pp_str in [
        "inception_crop(1)",
        "resize()",
        # The comma after the next entry matters: without it Python joins the
        # two adjacent literals into the single spec
        # "resize(1, 1, 1)flip_lr(1)" and neither case is tested on its own.
        "resize(1, 1, 1)",
        "flip_lr(1)",
        "central_crop()",
    ]:
      with self.assertRaises(BaseException):
        builder.get_preprocess_fn(pp_str)(x)
@Registry.register("preprocess_ops.value_range")
@utils.InKeyOutKey()
def get_value_range(vmin=-1, vmax=1, in_min=0, in_max=255.0, clip_values=False):
  """Linearly rescales values from [in_min, in_max] to [vmin, vmax].

  The input range bounds `in_min`/`in_max` can be equal-size lists in order to
  rescale the individual channels independently.

  Args:
    vmin: A scalar. Output min value.
    vmax: A scalar. Output max value.
    in_min: A scalar or a list of input min values to scale. If a list, the
      length should match the number of channels in the image.
    in_max: A scalar or a list of input max values to scale. If a list, the
      length should match the number of channels in the image.
    clip_values: Whether to clip the output values to [vmin, vmax].

  Returns:
    A function to rescale the values.
  """

  def _value_range(image):
    """Rescales `image`; the result is float32."""
    lower = tf.constant(in_min, tf.float32)
    upper = tf.constant(in_max, tf.float32)
    # First map in_min -> 0 and in_max -> 1, then stretch to [vmin, vmax].
    unit = (tf.cast(image, tf.float32) - lower) / (upper - lower)
    rescaled = vmin + unit * (vmax - vmin)
    if not clip_values:
      return rescaled
    return tf.clip_by_value(rescaled, vmin, vmax)

  return _value_range
@Registry.register("preprocess_ops.onehot")
def get_onehot(depth,
               key="labels",
               key_result=None,
               multi=True,
               on=1.0,
               off=0.0):
  """One-hot encodes the input.

  Args:
    depth: Length of the one-hot vector (how many classes).
    key: Key of the data to be one-hot encoded.
    key_result: Key under which to store the result (same as `key` if None).
    multi: If there are multiple labels, whether to merge them into the same
      "multi-hot" vector (True) or keep them as an extra dimension (False).
    on: Value to fill in for the positive label (default: 1).
    off: Value to fill in for negative labels (default: 0).

  Returns:
    A function that stores the encoding of `data[key]` in the data dictionary
    and returns that dictionary.
  """

  def _onehot(data):
    # Both scatter and one_hot expect int64 labels.
    labels = tf.cast(data[key], tf.int64)
    if labels.shape.rank > 0 and multi:
      # According to the original authors' measurements, scattering is
      # significantly more efficient for several labels than tf.one_hot
      # followed by tf.reduce_max.
      hits = tf.scatter_nd(
          labels[:, None], tf.ones(tf.shape(labels)[0]), (depth,))
      # Clip to [0, 1] so that a repeated label still counts once, then map
      # {0, 1} onto {off, on}.
      encoded = tf.clip_by_value(hits, 0, 1) * (on - off) + off
    else:
      encoded = tf.one_hot(labels, depth, on_value=on, off_value=off)
    data[key_result or key] = encoded
    return data

  return _onehot
@Registry.register("preprocess_ops.concat")
def get_concat(inkeys, outkey=None, axis=-1):
  """Concatenates the values stored under `inkeys` along `axis`.

  Args:
    inkeys: Keys of the values to concatenate, in concatenation order.
    outkey: Key to store the result under. If not given, the result is stored
      under the first of `inkeys`.
    axis: Axis along which to concatenate.

  Returns:
    A function that adds the concatenation to the data dictionary and returns
    that dictionary.
  """

  def _concat(data):
    merged = tf.concat([data[name] for name in inkeys], axis)
    data[outkey or inkeys[0]] = merged
    return data

  return _concat
@Registry.register("preprocess_ops.pad_to_shape")
@utils.InKeyOutKey()
def get_pad_to_shape(shape, pad_value=0, where="after"):
  """Pads a tensor with `pad_value` up to the given `shape`.

  Args:
    shape: Target shape, one entry per dimension of the tensor. A `None` entry
      leaves the corresponding dimension unpadded.
    pad_value: Value to pad with; it is converted to the tensor's dtype.
    where: On which side of each dimension the padding goes: "before",
      "after", or "both" (split in half, the extra element going last).

  Returns:
    A function that pads a tensor and sets its static shape to `shape`.
  """

  def _pads(cur, tgt):
    """Returns the [before, after] padding amounts for one dimension."""
    if tgt is None:
      return [0, 0]
    missing = tgt - cur
    half = missing // 2
    per_side = {
        "before": [missing, 0],
        "after": [0, missing],
        "both": [half, missing - half],
    }
    return per_side[where]

  def _pad_to_shape(x):
    assert len(x.shape.as_list()) == len(shape)
    # Current sizes are read dynamically, so they may be unknown statically.
    current = tf.shape(x)
    paddings = [_pads(cur=current[i], tgt=tgt) for i, tgt in enumerate(shape)]
    fill = tf.constant(pad_value, x.dtype)
    padded = tf.pad(x, paddings, constant_values=fill)
    padded.set_shape(shape)
    return padded

  return _pad_to_shape
@Registry.register("preprocess_ops.choice")
def get_choice(n="single", key=None, fewer_ok=False, inkey=None, outkey=None):
  """Chooses the same `n` random entries of all `keys`.

  Args:
    n: how many entries to randomly sample (without repeat). Possible values:
      - int: that many entries (or fewer if there's fewer, see `fewer_ok`.)
      - "single": The string "single" only chooses one and drops the leading
        dim.
      - [min, max]: A pair means randomly take between min/max examples (incl.).
    key: str or list of str: See Note.
    fewer_ok: if False, fail when there are fewer than `n` (or `min`) elements
      to choose from; for an int `n` the result then gets the static leading
      dimension `n`. If True, allow it; the leading dimension of the result
      then has unknown static shape.
    inkey: str or list of str: See Note.
    outkey: str or list of str: See Note.

  Note:
    If key/inkey/outkey is a list, then the same random entries are chosen for
    all of the keys. Other than that, they function the same as InKeyOutKey.

    The outkey can also contain the placeholder `{key}` that'll be replaced
    by the inkey name.

  Examples:
    choice(key="alt_text/text")
    choice(n=128, key=["patches", "positions"])
    choice(inkey=["questions_i18n", "answers_i18n"], outkey=["q", "a"])

  Returns:
    The pp op.
  """

  # Normalize keys:
  inkeys = utils.maybe_repeat(inkey or key, 1)
  outkeys = utils.maybe_repeat(outkey or key, 1)
  outkeys = [ok.format(key=ik) for ok, ik in zip(outkeys, inkeys)]

  # Let's DRY on this condition and give it a name.
  is_varlen = isinstance(n, (list, tuple))
  min_n = n[0] if is_varlen else 1 if n == "single" else n
  # The leading dimension is only guaranteed to be exactly `n` for a fixed int
  # `n` whose availability is asserted below, i.e. when `fewer_ok` is False.
  # Pinning the shape with `fewer_ok=True` would make the op fail in exactly
  # the case that flag is meant to allow.
  has_static_n = not is_varlen and not fewer_ok

  def _choice(data):
    nitems = tf.shape(data[inkeys[0]])[0]

    # Sanity check that all keys have same leading dimension, and that is at
    # least as large as the minimum requested output.
    lengths = [tf.shape(data[k])[0] for k in inkeys]
    checks = [tf.debugging.assert_equal(length, nitems) for length in lengths]
    if not fewer_ok:  # Since we check for all-same, a single suffices here.
      checks.append(tf.debugging.assert_greater_equal(nitems, min_n))
    with tf.control_dependencies(checks):
      nitems = tf.identity(nitems)

    if n == "single":
      index = tf.random.uniform([], 0, nitems, dtype=tf.int32)
    else:
      # Subsample by shuffling and taking first n, but...
      indices = tf.random.shuffle(tf.range(nitems))
      end = n
      if is_varlen:
        end = tf.random.uniform([], n[0], n[1] + 1, dtype=tf.int32)
      # ...keep the order while subsampling (it might have a meaning, eg boxes)
      indices = tf.sort(indices[:end])

    for ik, ok in zip(inkeys, outkeys):
      if n == "single":
        result = data[ik][index]
      else:
        result = tf.gather(data[ik], indices, axis=0)
        if has_static_n:  # Give static shape when we can.
          result = tf.ensure_shape(result, [n] + [None] * (result.ndim - 1))
      data[ok] = result

    return data
  return _choice
+ """ + # Normalize keys: + inkeys = utils.maybe_repeat(inkey or key, 1) + outkeys = utils.maybe_repeat(outkey or key, 1) + outkeys = [ok.format(key=ik) for ok, ik in zip(outkeys, inkeys)] + + # TODO: Ideally the data pipeline should provide us with an epoch + # counter. For now count how often we see a given example id and don't worry + # on memory consumption. Counter returns 0 the first time an example is seen. + counter = collections.defaultdict(lambda: -1) + def _seen_count(example_id): + example_id = example_id.item() + counter[example_id] += 1 + return counter[example_id] + + # We need a seed to deterministically decide on a shuffled sequence and use + # the number of times an example was seen to iterate through it. The seed + # should be different for every instance of a create preprocessing function + # but it has to be fixed for each instance. + seed = tf.random.uniform( + [2], minval=tf.int32.min, maxval=tf.int32.max, dtype=tf.int32) + + def _choice(data): + nitems = tf.shape(data[inkeys[0]])[0] + + # Sanity check that all keys have same leading dimension. + checks = [ + tf.debugging.assert_equal(tf.shape(data[k])[0], nitems) + for k in inkeys + ] + with tf.control_dependencies(checks): + nitems = tf.identity(nitems) + + # Using the seed, example id and the number of times an example was seen + # pick an `index` such that items are only repeated after all items are seen + # an equal number of times. E.g. it could return indexes from this sequence: + # [0, 1, 2, 1, 2, 0, 2, 0, 1, 0, 2, 1, ...]. 
+ count = tf.numpy_function( + _seen_count, (data["_id"],), Tout=tf.int64, stateful=True) + count = tf.cast(count, tf.int32) + nitems = tf.cast(nitems, tf.int32) + shuffle_epoch = count // nitems + shuffle_offset = count % nitems + + example_seed = tf.random.fold_in(seed, data["_id"]) + shuffle_seed = tf.random.fold_in(example_seed, shuffle_epoch) + shuffle = tf.random.experimental.stateless_shuffle( + tf.range(nitems), seed=shuffle_seed) + index = shuffle[shuffle_offset] + + # Select item[index] for all keys. + for ik, ok in zip(inkeys, outkeys): + data[ok] = data[ik][index] + return data + + return _choice diff --git a/big_vision/pp/ops_general_test.py b/big_vision/pp/ops_general_test.py new file mode 100644 index 0000000000000000000000000000000000000000..89f616e1690c6e83aff818cf0fff540dcad073fd --- /dev/null +++ b/big_vision/pp/ops_general_test.py @@ -0,0 +1,236 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ops_general.""" + +import copy + +import big_vision.pp.ops_general as pp +import numpy as np +import tensorflow as tf + + +class PreprocessOpsTest(tf.test.TestCase): + + def tfrun(self, ppfn, data): + # Run once as standalone, as could happen eg in colab. + yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()} + + # And then once again as part of tfdata pipeline. + # You'd be surprised how much these two differ! 
def test_get_keep_drop(self):
  """`keep` retains only the listed keys, `drop` removes exactly those."""
  data = {"image": 1, "labels": 2, "something": 3}
  listed = ("image", "labels")

  for kept in self.tfrun(pp.get_keep(*listed), data):
    self.assertAllEqual(set(kept.keys()), {"image", "labels"})

  for remaining in self.tfrun(pp.get_drop(*listed), data):
    self.assertAllEqual(set(remaining.keys()), {"something"})
def test_pad_to_shape(self):
  """Padding all-ones inputs with -1 gives the expected element sum."""
  desired_shape = (8, 10)
  n_desired = np.prod(desired_shape)
  for input_shape in [(8, 4), (8, 3), (8, 10), (8, 1)]:
    n_ones = np.prod(input_shape)
    # Each original entry adds +1 and each of the (n_desired - n_ones) padded
    # entries adds -1, hence the sum below.
    expected_sum = 2 * n_ones - n_desired
    data = {"x": tf.ones(input_shape, dtype=tf.float32)}
    pad_op = pp.get_pad_to_shape(desired_shape, pad_value=-1, key="x")
    for out in self.tfrun(pad_op, data):
      self.assertEqual(tf.reduce_sum(out["x"]), expected_sum)
def test_choice(self):
  """The default call (n="single") returns one of the available entries."""
  data = self._data_for_choice()

  def pick(inkey):
    # Runs the op directly on `data`; the pick lands under "choice".
    return pp.get_choice(inkey=inkey, outkey="choice")(data)["choice"]

  # With a single candidate the outcome is fully determined.
  self.assertEqual(pick("one_f32"), 0.42)
  self.assertEqual(pick("one_str"), "Hi")
  # With two candidates either of them may come back.
  self.assertIn(pick("two_f32"), [3.14, 0.42])
  self.assertIn(pick("two_str"), ["Hi", "Lucas"])
def test_choice_n(self):
  """With n < number of entries, the result is one of the entries."""
  data = self._data_for_choice()
  for k in ("two_f32", "two_str"):
    op = pp.get_choice(n=1, key=[k])
    for out in self.tfrun(op, data):
      self.assertIn(out[k], data[k])

  # Vectors need an element-wise comparison against each candidate row.
  for out in self.tfrun(pp.get_choice(n=1, key=["two_vec"]), data):
    picked = out["two_vec"][0]
    is_first = tf.reduce_all(picked == data["two_vec"][0])
    is_second = tf.reduce_all(picked == data["two_vec"][1])
    self.assertTrue(tf.logical_or(is_first, is_second))
def test_choice_n_range(self):
  """With n=[1, 2], the first, the second, or both entries come back."""
  data = self._data_for_choice()
  for k in ("two_f32", "two_str", "two_vec"):
    # All outcomes accepted by this test: either single entry, or both in
    # their original order.
    candidates = (data[k][0:1], data[k][1:2], data[k][0:2])
    for out in self.tfrun(pp.get_choice(n=[1, 2], key=[k]), data):
      matches = [tf.reduce_all(out[k] == cand) for cand in candidates]
      self.assertTrue(tf.reduce_any(matches))
@Registry.register("preprocess_ops.decode")
@utils.InKeyOutKey()
def get_decode(channels=3, precise=False):
  """Decodes an encoded image string, see tf.io.decode_image.

  Args:
    channels: see tf.io.decode_image.
    precise: if False, use default TF image decoding algorithm.
      If True, change DCT method for JPEG decoding to match PIL/cv2/PyTorch.

  Returns:
    A function that decodes an encoded image.
  """

  def _decode(image):
    if not precise:
      return tf.io.decode_image(
          image, channels=channels, expand_animations=False)
    # Per the original authors, decode_jpeg also supports png.
    return tf.image.decode_jpeg(
        image, channels=channels, dct_method="INTEGER_ACCURATE")

  return _decode
def _resize_factor(image, factor, method="area", antialias=True):
  """Scales height and width of `image` by `factor`, keeping the aspect ratio.

  Args:
    image: The image to resize; its first two dimensions are height and width.
    factor: The (float) scale factor applied to both height and width.
    method: resize method, see tf.image.resize.
    antialias: see tf.image.resize.

  Returns:
    The resized image, with the same dtype as the input.
  """

  def _scaled(side):
    # Round to the nearest integer number of pixels.
    return tf.cast(tf.round(tf.cast(side, tf.float32) * factor), tf.int32)

  in_shape = tf.shape(image)
  new_hw = (_scaled(in_shape[0]), _scaled(in_shape[1]))

  out_dtype = image.dtype
  limits = tf.type_spec_from_value(image).dtype
  resized = tf.image.resize(image, new_hw, method=method, antialias=antialias)
  # Clip to the range of the input dtype before casting back to it.
  return tf.cast(tf.clip_by_value(resized, limits.min, limits.max), out_dtype)
@Registry.register("preprocess_ops.resize_long")
@utils.InKeyOutKey()
def get_resize_long(longer_size, method="area", antialias=True):
  """Resizes the longer side to `longer_size` keeping aspect ratio.

  Args:
    longer_size: an integer, the new size of the longer side of an input
      image.
    method: the resize method. `area` is a meaningful, bwd-compat default.
    antialias: see tf.image.resize. Ideally set to True for all new configs.

  Returns:
    A function that resizes an image and preserves its aspect ratio.
  """

  def _resize_long(image):  # pylint: disable=missing-docstring
    dims = tf.shape(image)
    longest = tf.cast(tf.maximum(dims[0], dims[1]), tf.float32)
    # Scale factor that maps the longer side exactly onto `longer_size`.
    factor = tf.cast(longer_size, tf.float32) / longest
    return _resize_factor(image, factor, method=method, antialias=antialias)

  return _resize_long
@Registry.register("preprocess_ops.decode_jpeg_and_inception_crop")
@utils.InKeyOutKey()
def get_decode_jpeg_and_inception_crop(size=None, area_min=5, area_max=100,
                                       ratio_min=0.75, ratio_max=1.33,
                                       method="bilinear", antialias=False):
  """Decodes a jpeg string and takes an inception-style crop in one step.

  An inception-style crop is a random crop (random size and aspect ratio), as
  used for training Inception models, see
  https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf.

  The crop window is sampled from the jpeg header's shape, so only the cropped
  region is decoded.

  Args:
    size: Resize image to [size, size] after crop. No resize if falsy.
    area_min: minimal crop area, in percent of the image area.
    area_max: maximal crop area, in percent of the image area.
    ratio_min: minimal aspect ratio.
    ratio_max: maximal aspect ratio.
    method: resize method, see tf.image.resize docs for options.
    antialias: see tf.image.resize. Ideally set to True for all new configs.

  Returns:
    A function that applies the decode + inception crop.
  """
  area_range = (area_min / 100, area_max / 100)
  aspect_ratio_range = (ratio_min, ratio_max)

  def _inception_crop(image_data):  # pylint: disable=missing-docstring
    jpeg_shape = tf.image.extract_jpeg_shape(image_data)
    no_boxes = tf.zeros([0, 0, 4], tf.float32)
    begin, crop_size, _ = tf.image.sample_distorted_bounding_box(
        jpeg_shape,
        no_boxes,
        area_range=area_range,
        aspect_ratio_range=aspect_ratio_range,
        min_object_covered=0,  # Don't enforce a minimum area.
        use_image_if_no_bounding_boxes=True)

    # Window layout is [offset_y, offset_x, height, width]; the trailing
    # (channel) entries of `begin` and `crop_size` are dropped.
    crop_window = tf.concat([begin[:2], crop_size[:2]], axis=0)
    image = tf.image.decode_and_crop_jpeg(image_data, crop_window, channels=3)

    if not size:
      return image
    return get_resize(size, method, antialias)({"image": image})["image"]

  return _inception_crop
@Registry.register("preprocess_ops.clip_value_range")
@utils.InKeyOutKey()
def get_clip_value_range():
  """Standardizes an image with fixed per-channel mean and stddev.

  Works like `vgg_value_range`, but with hard-coded statistics that are
  expressed on the 0..255 pixel scale. Judging by the op name these are the
  statistics used by CLIP models -- NOTE(review): confirm the provenance.

  Returns:
    A function that casts the image to float32, subtracts the mean and divides
    by the stddev.
  """
  # Build float32 constants once, like `get_vgg_value_range` does, instead of
  # keeping Python tuples that are re-converted to tensors on every call.
  mean = tf.constant(
      (0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255), tf.float32)
  std = tf.constant(
      (0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255), tf.float32)

  def _clip_value_range(image):
    return (tf.cast(image, tf.float32) - mean) / std
  return _clip_value_range
class PreprocessOpsTest(tf.test.TestCase):
  """Checks output shapes/values of the image preprocessing ops."""

  def tfrun(self, ppfn, data):
    """Yields `ppfn(data)` computed standalone, then via a tf.data pipeline."""
    # Standalone execution first, as could happen eg in colab.
    standalone = ppfn(copy.deepcopy(data))
    yield {key: np.array(value) for key, value in standalone.items()}

    # Then the same op as part of a tfdata pipeline.
    # You'd be surprised how much these two differ!
    pipeline = tf.data.Dataset.from_tensors(copy.deepcopy(data)).map(ppfn)
    yield from pipeline.as_numpy_iterator()

  def test_resize(self):
    for out in self.tfrun(pp.get_resize([120, 80]), get_image_data()):
      self.assertEqual(out["image"].shape, (120, 80, 3))

  def test_resize_small(self):
    for out in self.tfrun(pp.get_resize_small(240), get_image_data()):
      self.assertEqual(out["image"].shape, (320, 240, 3))

  def test_resize_long(self):
    for out in self.tfrun(pp.get_resize_long(320), get_image_data()):
      self.assertEqual(out["image"].shape, (320, 240, 3))

  def test_inception_crop(self):
    for out in self.tfrun(pp.get_inception_crop(), get_image_data()):
      self.assertEqual(out["image"].shape[-1], 3)

  def test_decode_jpeg_and_inception_crop(self):
    # Encode a random image as jpeg bytes; the op expects an encoded string.
    buf = io.BytesIO()
    plt.imsave(buf, get_image_data()["image"].numpy(), format="jpg")
    encoded = {"image": tf.cast(buf.getvalue(), tf.string)}
    for out in self.tfrun(pp.get_decode_jpeg_and_inception_crop(), encoded):
      self.assertEqual(out["image"].shape[-1], 3)

  def test_random_crop(self):
    for out in self.tfrun(pp.get_random_crop([120, 80]), get_image_data()):
      self.assertEqual(out["image"].shape, (120, 80, 3))

  def test_central_crop(self):
    for out in self.tfrun(pp.get_central_crop([20, 80]), get_image_data()):
      self.assertEqual(out["image"].shape, (20, 80, 3))

  def test_random_flip_lr(self):
    original = get_image_data()
    for out in self.tfrun(pp.get_random_flip_lr(), original):
      reference = original["image"].numpy()
      # The output is either untouched or mirrored along the width axis.
      self.assertTrue(
          np.all(reference == out["image"]) or
          np.all(reference == out["image"][:, ::-1]))
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text-centric preprocessing ops. + +All preprocessing ops should return a data processing functors. A data +is represented as a dictionary of (TF) tensors. The functors output a modified +dictionary. + +A commonly used key for the tokenized output is "labels". +""" +import functools +import importlib + +from absl import logging +from big_vision.datasets.imagenet import class_names as imagenet_class_names +from big_vision.pp import ops_general +from big_vision.pp import tokenizer as bv_tok +from big_vision.pp import utils +from big_vision.pp.registry import Registry +import tensorflow as tf + +from tensorflow.io import gfile + +import sentencepiece +SPProcessor = sentencepiece.SentencePieceProcessor + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' +import sentencepiece.sentencepiece_model_pb2 +del os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] +SPModelProto = sentencepiece.sentencepiece_model_pb2.ModelProto + + +# TODO: b/lbeyer - softly introduce and move to new tokenizer API. 
def create_tokenizer(model="c4_en", add_eos=True, add_bos=False):
  """Creates a tokenizer which can be used in tfds.

  Args:
    model: a key of `KNOWN_TOKENIZERS`, or a path to a sentencepiece model
      file (used as-is when it is not a known key).
    add_eos: forwarded to `tensorflow_text.SentencepieceTokenizer`.
    add_bos: forwarded to `tensorflow_text.SentencepieceTokenizer`.

  Returns:
    A `tensorflow_text.SentencepieceTokenizer` built from the model file.
  """
  logging.info("Creating tokenizer: %s", model)
  model_path = KNOWN_TOKENIZERS.get(model, model)
  with gfile.GFile(model_path, "rb") as model_file:
    serialized_model = model_file.read()

  # tensorflow_text is imported lazily so that it stays an optional
  # dependency for the users of this file.
  import tensorflow_text
  return tensorflow_text.SentencepieceTokenizer(
      model=serialized_model, add_eos=add_eos, add_bos=add_bos
  )
@Registry.register("preprocess_ops.tokenize")
@utils.InKeyOutKey(indefault=None, outdefault="labels")
def get_pp_tokenize(
    max_len,
    eos,
    model="c4_en",
    lower=True,
    sample_if_multi=True,
    pad_value="",
    add_bos=False
):
  """Tokenizes a text.

  Let's assume max_len=3, the end-of-sentence (EOS) token has id 1 and
  id("a")=2, then we have

  1. `eos="none", pad_value=0`:
     - "a" -> [2, 0, 0]
     - "aa" -> [2, 2, 0]
     - "aaa" -> [2, 2, 2]

  2. `eos="yes", pad_value=0`:
     - "a" -> [2, 1, 0]
     - "aa" -> [2, 2, 1]
     - "aaa" -> [2, 2, 2]

     This is usually used with generative models that need to learn when to
     properly predict the EOS token (when the sentence is finished) and when
     to abstain (when the sentence is truncated).

  3. `eos="sticky", pad_value=0`:
     - "a" -> [2, 1, 0]
     - "aa" -> [2, 2, 1]
     - "aaa" -> [2, 2, 1]

  4. `eos="sticky", pad_value=1`:
     - "a" -> [2, 1, 1]
     - "aa" -> [2, 2, 1]
     - "aaa" -> [2, 2, 1]

     This is traditionally used with contrastive models that use the last
     token for embeddings, similarly to "cls" tokens in BERT-style models.

  Args:
    max_len: maximum length of the tokenized text.
    eos: Whether to add an EOS (end of sentence) token and whether to keep it
      when the sequence is longer than `max_len - 1`. See examples above for
      details. Valid values: "none", "yes", "sticky".
    model: a path to the pretrained sentencepiece model (or a key of
      `KNOWN_TOKENIZERS`, see `create_tokenizer`).
    lower: lowercase the text before tokenizing.
    sample_if_multi: If there's more than one, randomly pick one if this is
      True; otherwise pick all texts and keep the input's batch shape in result.
    pad_value: which token to pad the sequence with. If a string, it is
      converted to a token id with the tokenizer's `string_to_id`. Note that
      there is no guarantee to have any padding at the end of the sentence, if
      the sentence is longer than `max_len`.
    add_bos: adds beginning of sentence symbol.

  Returns:
    an op that outputs tokenized text.

  Raises:
    ValueError: if `eos` is not one of "yes", "none", "sticky".
  """

  if eos not in ("yes", "none", "sticky"):
    raise ValueError(f"Invalid value for eos: '{eos}'.")

  tokenizer = create_tokenizer(model, add_eos=eos != "none", add_bos=add_bos)

  if isinstance(pad_value, str):
    pad_value = tokenizer.string_to_id(pad_value)

  def _pp_tokenize(txt):
    if sample_if_multi and tf.convert_to_tensor(txt).ndim:
      # TODO: I wish this code-path could die.
      # Note the trailing space: the two literals are concatenated, and without
      # it the message read "...removed.Call `choice`...".
      logging.warning("sample_if_multi is deprecated and will be removed. "
                      "Call `choice` (and maybe `setdefault`) instead.")
      # Pick one of the texts, falling back to "" via `setdefault`.
      txt = ops_general.get_choice(key="t")(
          ops_general.get_setdefault("t", "")({"t": txt}))["t"]

    if lower:
      txt = tf.strings.lower(txt) if sample_if_multi else tf.map_fn(
          tf.strings.lower, txt)

    return tokenize(
        txt,
        tokenizer,
        max_len,
        pad_value=pad_value,
        force_eos=eos == "sticky",  # "sticky" keeps EOS even when truncating.
        multi_text=not sample_if_multi)

  return _pp_tokenize
def _iterable(x):
  """Returns whether `x` is a collection of sequences, not a single sequence.

  Args:
    x: a `tf.RaggedTensor`, an array-like exposing `ndim` (np, jnp), or a
      list/tuple.

  Returns:
    True for a `tf.RaggedTensor`, for anything with `ndim > 1`, and for a
    non-empty list/tuple whose first element is not an int/float. False
    otherwise; in particular an empty list/tuple counts as a single (empty)
    sequence.
  """
  if isinstance(x, tf.RaggedTensor):
    return True
  if getattr(x, "ndim", 0) > 1:  # np, jnp
    return True
  if isinstance(x, (list, tuple)):
    # Guard on truthiness first: indexing `x[0]` on an empty list/tuple used
    # to raise IndexError.
    return bool(x) and not isinstance(x[0], (int, float))
  return False
  def to_int_tf_op(self, text, *, bos=False, eos=False):
    """TF-op counterpart of `to_int`, implemented with `tf.py_function`.

    Args:
      text: a string or a batch of strings; converted with
        `tf.convert_to_tensor`.
      bos: forwarded to `to_int`.
      eos: forwarded to `to_int`.

    Returns:
      For a scalar input, the int32 result of `tf.py_function` holding the
      token ids. Otherwise the result of `tf.py_function` declared with an
      int32 `tf.RaggedTensorSpec` whose leading dimension is the input's.
    """
    text = tf.convert_to_tensor(text)
    if text.ndim == 0:
      # Scalar string: decode the bytes and tokenize in Python.
      def fn(txt):
        string = txt.numpy().decode()
        return tf.constant(self.to_int(string, bos=bos, eos=eos), tf.int32)
      return tf.py_function(fn, [text], tf.int32)
    else:
      # Batch of strings: token lists may differ in length, hence ragged.
      # NOTE(review): `.tolist()` is iterated as a flat list of bytes, so this
      # presumably expects a rank-1 input -- confirm against callers.
      def fn(txt):
        strings = [s.decode() for s in txt.numpy().tolist()]
        toks = self.to_int(strings, bos=bos, eos=eos)
        return tf.ragged.constant(toks)
      out_type = tf.RaggedTensorSpec([tf.shape(text)[0], None], tf.int32)
      return tf.py_function(fn, [text], Tout=out_type)
class PyToTfWrapper:
  """Exposes a tokenizer's `to_{int,str}_tf_op()` through `to_{int,str}()`.

  This lets the same assertions run against both the Python and the TF-op
  code paths of a tokenizer.
  """

  def __init__(self, tok):
    self.tok = tok
    # Mirror the attributes that the tests read from the tokenizer.
    self.bos_token, self.eos_token = tok.bos_token, tok.eos_token
    self.vocab_size = tok.vocab_size

  def to_int(self, text, *, bos=False, eos=False):
    """Runs `to_int_tf_op` and converts the result to (nested) lists."""
    result = self.tok.to_int_tf_op(text, bos=bos, eos=eos)
    if not isinstance(result, tf.RaggedTensor):
      return result.numpy().tolist()
    return [row.numpy().tolist() for row in result]

  def to_str(self, tokens, stop_at_eos=True):
    """Runs `to_str_tf_op` and decodes the result to str (or list of str)."""
    ragged_tokens = tf.ragged.constant(tokens)
    result = self.tok.to_str_tf_op(ragged_tokens, stop_at_eos=stop_at_eos)
    if result.ndim != 0:
      return [item.numpy().decode() for item in result]
    return result.numpy().decode()
+ tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data)) + for npdata in tfdata.map(ppfn).as_numpy_iterator(): + yield npdata + + def testtok(self): + # https://github.com/google/sentencepiece/blob/master/python/test/test_model.model + return "test_model.model" # Should we just commit it? It's 200kB + + def test_get_pp_clip_i1k_label_names(self): + op = pp.get_pp_clip_i1k_label_names() + labels = op({"label": tf.constant([0, 1])})["labels"].numpy().tolist() + self.assertAllEqual(labels, ["tench", "goldfish"]) + + @parameterized.parameters((b"Hello world ScAlAr!", b"hello world scalar!"), + (["Decoded Array!"], ["decoded array!"]), + ([b"aA", "bB"], [b"aa", "bb"])) + def test_get_lower(self, inputs, expected_output): + op = pp.get_lower() + out = op({"text": tf.constant(inputs)}) + self.assertAllEqual(out["text"].numpy(), np.array(expected_output)) + + @parameterized.named_parameters( + ("py", False), + ("tf", True), + ) + def test_sentencepiece_tokenizer(self, wrap_tok): + tok = pp.SentencepieceTokenizer(self.testtok()) + if wrap_tok: + tok = PyToTfWrapper(tok) + self.assertEqual(tok.vocab_size, 1000) + bos, eos = tok.bos_token, tok.eos_token + self.assertEqual(bos, 1) + self.assertEqual(eos, 2) + # Note: test model does NOT have a token (similar to e.g. "mistral"). 
+ # `.to_int()` wraps `.to_int_tf_ops` which is thus also tested + self.assertEqual(tok.to_int("blah"), [80, 180, 60]) + self.assertEqual(tok.to_int("blah", bos=True), [bos, 80, 180, 60]) + self.assertEqual(tok.to_int("blah", eos=True), [80, 180, 60, eos]) + self.assertEqual( + tok.to_int("blah", bos=True, eos=True), [bos, 80, 180, 60, eos] + ) + self.assertEqual( + tok.to_int(["blah", "blah blah"]), + [[80, 180, 60], [80, 180, 60, 80, 180, 60]], + ) + # inverse of above + # `.to_str()` wraps `.to_str_tf_ops` which is thus also tested + self.assertEqual(tok.to_str([80, 180, 60]), "blah") + self.assertEqual(tok.to_str([1, 80, 180, 60]), "blah") + self.assertEqual(tok.to_str([80, 180, 60, 2]), "blah") + self.assertEqual( + tok.to_str([[80, 180, 60], [80, 180, 60, 80, 180, 60]]), + ["blah", "blah blah"], + ) + + def test_sentencepiece_tokenizer_tf_op_ndarray_input(self): + tok = pp.SentencepieceTokenizer(self.testtok()) + bos, eos = tok.bos_token, tok.eos_token + arr = np.array([[bos, 80, 180, 60, eos]] * 2, dtype=np.int32) + self.assertEqual(tok.to_str_tf_op(arr).numpy().tolist(), [b"blah"] * 2) + + def test_sentencepiece_tokenizer_tokensets(self): + tok = pp.SentencepieceTokenizer(self.testtok(), tokensets=["loc"]) + self.assertEqual(tok.vocab_size, 2024) + self.assertEqual( + tok.to_int("blah"), [80, 180, 60, 1000, 2023] + ) + + def test_sentencepiece_stop_at_eos(self): + tok = pp.SentencepieceTokenizer(self.testtok()) + self.assertEqual(tok.to_str([80, 180, 60], stop_at_eos=False), "blah") + eos = tok.eos_token + self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=False), "blah") + self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=True), "b") + self.assertEqual( + tok.to_str([[80, eos, 180, 60], [80, 180, eos, 60]], stop_at_eos=True), + ["b", "bla"] + ) + + def test_sentencepiece_extra_tokens(self): + tok = pp.SentencepieceTokenizer(self.testtok()) + self.assertEqual(tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), "blah") + tok = 
pp.SentencepieceTokenizer( + self.testtok(), tokensets=["sp_extra_tokens"] + ) + self.assertEqual(tok.vocab_size, 1001) # Also added the token. + self.assertEqual( + tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), " blah" + ) + + +@Registry.register("tokensets.sp_extra_tokens") +def _get_sp_extra_tokens(): + # For sentencepiece, adding these tokens will make them visible when decoding. + # If a token is not found (e.g. "" is not found in "mistral"), then it is + # added to the vocabulary, increasing the vocab_size accordingly. + return ["", "", ""] + + +if __name__ == "__main__": + tf.test.main() diff --git a/big_vision/pp/proj/clippo/download_unifont.sh b/big_vision/pp/proj/clippo/download_unifont.sh new file mode 100644 index 0000000000000000000000000000000000000000..44251749e0c8171318f49e75b9b725d29d06c53a --- /dev/null +++ b/big_vision/pp/proj/clippo/download_unifont.sh @@ -0,0 +1,21 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@Registry.register("preprocess_ops.render_unifont")
@utils.InKeyOutKey(indefault="texts", outdefault="image")
def get_pp_render_text(image_size: int, font_size: int = 16, max_chars=768,
                       background_brightness=127, text_brightness=0,
                       lower=True, monospace=False, spacing=1, min_width=4,
                       resize_method="area"):
  """Renders text as image, using binary Unifont.

  Largely based on Jeffrey Sorensen's text rendering implementation.

  The two Unifont `.hex` files are read from fixed paths under
  `big_vision/pp/proj/clippo/` when this function is called.

  Args:
    image_size: Width/height of output image.
    font_size: Font size to use. Recommended to leave at 16, as this requires
      no resizing, and is safe.
    max_chars: Maximum input characters to render, to make faster. Falsy
      disables the limit.
    background_brightness: gray level of background pixels. It is added to a
      single-channel image that is converted to RGB afterwards, so this is
      used as a scalar.
    text_brightness: gray level of text pixels (scalar, see above).
    lower: whether to lowercase.
    monospace: if False, text characters are horizontally trimmed according to
      `spacing` and `min_width` args.
    spacing: # pixels between each letter.
    min_width: Minimum width of each letter. Useful to make sure e.g. spaces and
      full stops aren't collapsed to nothing.
    resize_method: resize method to use if font_size != 16.

  Returns:
    Function which renders text as an image.
  """
  bit_embedding = np.zeros((0x200000, 32), dtype=np.uint8)
  # Maps the number of hex characters of a glyph to the columns it fills:
  # 64 hex chars (32 bytes) fill all columns, 32 hex chars (16 bytes) fill the
  # even columns only.
  colpattern = {64: range(32),
                32: sorted(tuple(range(0, 32, 4)) + tuple(range(2, 32, 4)))}

  def load_hex(path, num_digits):
    """Fills `bit_embedding` from a `CODE:HEXBITS` file; CODE has num_digits."""
    with tf.io.gfile.GFile(path) as f:
      for line in f:
        row = int(line[:num_digits], 16)
        # Strip the line terminator explicitly (rather than dropping the last
        # character) so a final line without newline is parsed correctly too.
        hexbits = line[num_digits + 1:].rstrip("\r\n")
        bit_embedding[row, colpattern[len(hexbits)]] = bytearray.fromhex(
            hexbits)

  load_hex("big_vision/pp/proj/clippo/unifont-9.0.06.hex", 4)
  load_hex("big_vision/pp/proj/clippo/unifont_upper-9.0.06.hex", 6)

  params = tf.constant(bit_embedding, dtype=tf.uint8)

  def trim_letter(letter):
    """Remove white space based on the letter size."""
    v = tf.reduce_max(letter, axis=0)
    has_pixels = tf.reshape(tf.where(v), (-1,), name="RS5")
    no_pixels = tf.equal(tf.reduce_max(v), 0)
    # Empty glyphs (e.g. space) have no pixel columns to index into.
    first = tf.cond(no_pixels, lambda: tf.constant(0, tf.int64),
                    lambda: has_pixels[0])
    last = tf.cond(no_pixels, lambda: tf.constant(0, tf.int64),
                   lambda: has_pixels[-1])

    first = tf.maximum(first - spacing, 0)
    last = tf.maximum(last + spacing, first + min_width)
    return tf.RaggedTensor.from_tensor(tf.transpose(letter[:, first:last]))

  def to_image(rendered, width, height=None):
    """Makes a nice square image from a long string of rendered characters."""
    height = height or width
    max_letter_width = tf.reduce_max(rendered.row_lengths(1))
    row_lengths = tf.cast(tf.cumsum(rendered.row_lengths(1)), tf.float32)
    div = tf.cast(width - max_letter_width, tf.float32)  # For rounding errors.
    # Assign each letter to a text row based on its cumulative width.
    row_idx = tf.cast(tf.floor(row_lengths / div), tf.int64)
    row_idx = tf.RaggedTensor.from_value_rowids(tf.range(tf.shape(rendered)[0]),
                                                row_idx)
    trimmed = tf.gather(rendered, row_idx, axis=0)
    trimmed = trimmed.merge_dims(1, 2)
    trimmed = trimmed.to_tensor(default_value=0)
    trimmed = tf.transpose(trimmed, (0, 2, 1))
    trimmed = tf.reshape(trimmed, (-1, tf.shape(trimmed)[-1]), name="RS4")
    trimmed = trimmed[:height]

    wpad = width - tf.shape(trimmed)[1]
    hpad = height - tf.shape(trimmed)[0]
    padded = tf.pad(trimmed, [[0, hpad], [0, wpad]])
    tf.assert_equal(tf.shape(padded), tf.constant((height, width)))
    # `padded` is laid out as (rows, columns), i.e. (height, width). The static
    # shape used to be given as (width, height), which is only right for
    # square images.
    return tf.ensure_shape(padded, (height, width))

  def render(text):
    if lower:
      text = tf.strings.lower(text)
    text = tf.reshape(text, (-1,))[0]  # Only the first text is rendered.
    ids = tf.strings.unicode_decode(text, "UTF-8")
    if max_chars:
      ids = ids[:max_chars]
    embed = tf.nn.embedding_lookup(params, ids)  # Get the letters
    # Each letter is 32 uint8s, but we want binary 16x16 grid.
    # The following does that in a rather hard to parse way.
    vertical = tf.reshape(embed, [1, -1])
    repl = tf.reshape(tf.transpose(tf.tile(vertical, multiples=[8, 1])), [-1])
    ones = tf.ones_like(repl)
    index = tf.cumsum(ones, exclusive=True)
    sevens = tf.cast(tf.fill(tf.shape(repl), 7), tf.uint8)
    moded = tf.bitwise.bitwise_and(index, sevens)
    shifted = tf.bitwise.right_shift(repl,
                                     tf.bitwise.bitwise_xor(moded, sevens))
    anded = tf.bitwise.bitwise_and(shifted, ones)
    # And finally, letters; binary, 0 = background, 1 = letter.
    letters = tf.reshape(anded, [tf.shape(ids)[0], 16, 16])

    if font_size != 16:
      logging.warning("The unifont text rendering function is highly optimized "
                      "for font size 16; using font size %i might lead to "
                      "suboptimal rendering and might degrade performance.",
                      font_size)
      letters = tf.image.resize(letters[..., None], (font_size, font_size),
                                method=resize_method, antialias=True)
      letters = tf.squeeze(letters, axis=-1)

    if monospace:
      letters = tf.RaggedTensor.from_tensor(tf.transpose(letters, (0, 2, 1)))
    else:
      letters = tf.RaggedTensor.from_tensor(letters)
      signature = tf.RaggedTensorSpec(shape=(None, font_size), ragged_rank=1,
                                      dtype=letters.dtype)
      letters = tf.map_fn(trim_letter, letters, fn_output_signature=signature)

    img = to_image(letters, image_size)[..., None]  # A nice square image.
    img *= (text_brightness - background_brightness)  # Rescale value range.
    img += background_brightness

    return tf.image.grayscale_to_rgb(tf.cast(img, tf.uint8))

  return render
"""BERT-related preprocessing ops (using WordPiece tokenizer)."""

from big_vision.pp import utils
from big_vision.pp.registry import Registry
import tensorflow as tf
import tensorflow_text


# Internally using
# BasicTokenizer
# https://github.com/tensorflow/text/blob/df5250d6cf1069990df4bf55154867391ab5381a/tensorflow_text/python/ops/bert_tokenizer.py#L67
# WordpieceTokenizer
# https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/wordpiece_tokenizer.py
def _create_bert_tokenizer(vocab_path):
  """Returns cls_token id and tokenizer to use in a tf.Dataset.map function.

  Args:
    vocab_path: Path to the WordPiece vocabulary file, one token per line.

  Returns:
    A `(cls_token, tokenizer)` tuple: the zero-based line index of "[CLS]" in
    the vocabulary file, and a lower-casing `tensorflow_text.BertTokenizer`
    emitting int32 ids.

  Raises:
    ValueError: If the vocabulary file contains no "[CLS]" line.
  """
  # Create tokenizer inside a tf.init_scope so the vocab is only loaded from
  # disk once per dataset iterator (see: http://(internal link)).
  # TODO: Make a local copy of vocab if creating many iterators.
  with tf.init_scope():
    tokenizer = tensorflow_text.BertTokenizer(
        vocab_path,
        token_out_type=tf.int32,
        lower_case=True,
    )

  # The vocabulary is read a second time only to locate the [CLS] id.
  with tf.io.gfile.GFile(vocab_path) as f:
    vocab = f.read().split("\n")
  cls_token = vocab.index("[CLS]")

  return cls_token, tokenizer


@Registry.register("preprocess_ops.bert_tokenize")
@utils.InKeyOutKey(indefault=None, outdefault="labels")
def get_pp_bert_tokenize(vocab_path, max_len, sample_if_multi=True):
  """Extracts tokens with tensorflow_text.BertTokenizer.

  Args:
    vocab_path: Path to a file containing the vocabulary for the WordPiece
      tokenizer. It's the "vocab.txt" file in the zip file downloaded from
      the original repo https://github.com/google-research/bert
    max_len: Number of tokens after tokenization.
    sample_if_multi: Whether the first text should be taken (if set to `False`),
      or whether a random text should be tokenized.

  Returns:
    A preprocessing Op.
  """

  cls_token, tokenizer = _create_bert_tokenizer(vocab_path)

  def _pp_bert_tokenize(labels):
    """Tokenizes one text out of `labels` into ids prefixed with [CLS]."""

    # Flatten and append "" so that indexing works even for empty inputs.
    labels = tf.reshape(labels, (-1,))
    labels = tf.concat([labels, [""]], axis=0)
    if sample_if_multi:
      num_texts = tf.maximum(tf.shape(labels)[0] - 1, 1)  # Don't sample "".
      txt = labels[tf.random.uniform([], 0, num_texts, dtype=tf.int32)]
    else:
      txt = labels[0]  # Always works, since we append "" earlier on.

    token_ids = tokenizer.tokenize(txt[None])
    # One slot of `max_len` is reserved for the [CLS] id prepended below.
    padded_token_ids, mask = tensorflow_text.pad_model_inputs(
        token_ids, max_len - 1)
    del mask  # Recovered from zero padding in model.
    count = tf.shape(padded_token_ids)[0]
    padded_token_ids = tf.concat(
        [tf.fill([count, 1], cls_token), padded_token_ids], axis=1)
    # Drop the batch axis that was added by `txt[None]`.
    return padded_token_ids[0]

  return _pp_bert_tokenize
"""Tests for bert_ops."""

import tempfile

from big_vision import input_pipeline
import big_vision.pp.builder as pp_builder
import big_vision.pp.ops_general  # pylint: disable=unused-import
from big_vision.pp.proj.flaxformer import bert_ops  # pylint: disable=unused-import
import tensorflow as tf


# BERT vocabulary for testing.
# The token id is the position in this list: [PAD]=0, [UNK]=1, more=2, than=3,
# one=4, [CLS]=5, [SEP]=6.
_BERT_VOCAB = [
    "[PAD]",
    "[UNK]",
    "more",
    "than",
    "one",
    "[CLS]",
    "[SEP]",
]


def _create_ds(pp_str, tensor_slices, num_examples):
  """Builds a single-batch inference dataset preprocessed with `pp_str`."""
  return input_pipeline.make_for_inference(
      tf.data.Dataset.from_tensor_slices(tensor_slices),
      num_ex_per_process=[num_examples],
      preprocess_fn=pp_builder.get_preprocess_fn(pp_str),
      batch_size=num_examples,
  )[0]


class BertOpsTest(tf.test.TestCase):
  """Checks the `bert_tokenize` preprocessing op end to end."""

  def test_tokenize(self):
    """Tokenizes three texts and compares against hand-computed ids."""
    inkey = "texts"
    vocab_path = f"{tempfile.mkdtemp()}/vocab.txt"
    with open(vocab_path, "w") as f:
      f.write("\n".join(_BERT_VOCAB))
    pp_str = (
        f"bert_tokenize(inkey='{inkey}', vocab_path='{vocab_path}', max_len=5)"
        f"|keep('labels')"
    )
    tensor_slices = {
        inkey: tf.ragged.constant([["one more"], ["more than one"], [""]])
    }
    ds = _create_ds(pp_str, tensor_slices, 3)
    # Every row starts with [CLS]=5 and is zero-padded ([PAD]=0) to max_len=5.
    self.assertAllEqual(
        next(iter(ds))["labels"],
        [[5, 4, 2, 0, 0], [5, 2, 3, 4, 0], [5, 0, 0, 0, 0]],
    )


if __name__ == "__main__":
  tf.test.main()
"""GIVT-specific preprocessing ops."""

from big_vision.pp import registry
from big_vision.pp import utils
import tensorflow as tf


@registry.Registry.register("preprocess_ops.bin_nyu_depth")
@utils.InKeyOutKey(indefault="labels", outdefault="labels")
def get_bin_nyu_depth(min_depth=0.001, max_depth=10.0, num_bins=256):
  """Binning of NYU depth for UViM in preprocessing rather than model.

  Args:
    min_depth: Depth value mapped to the lower edge of bin 0.
    max_depth: Depth value mapped to the upper edge of the last bin.
    num_bins: Number of bins; outputs lie in `[0, num_bins - 1]`.

  Returns:
    An op that turns a depth tensor into int32 bin indices.
  """

  def _bin_depth(labels):  # pylint: disable=missing-docstring
    # Normalize to [0, 1] over the depth range, then stretch over the bins.
    scaled = (labels - min_depth) / (max_depth - min_depth) * num_bins
    bin_idx = tf.cast(tf.floor(scaled), tf.int32)
    # Out-of-range depths are clamped into the first/last bin.
    return tf.maximum(tf.minimum(bin_idx, num_bins - 1), 0)

  return _bin_depth
"""pp ops."""

import functools
import string

from big_vision.pp import ops_text
from big_vision.pp import utils
from big_vision.pp.registry import Registry
import big_vision.pp.tokenizer as bv_tok
import numpy as np
import tensorflow as tf


@Registry.register('tokenizers.gemma')
def get_tokenizer_gemma(
    tokensets=(),
    model='gs://big_vision/gemma_tokenizer.model',
):
  """Returns a SentencePiece tokenizer for `model` with extra `tokensets`."""
  # See (internal link) for colab playground.
  return ops_text.SentencepieceTokenizer(model=model, tokensets=tokensets)


@functools.cache
def tokenize_constant(model, text, bos='no', eos='no', length=None):
  """Tokenize a constant string, with memoization.

  Args:
    model: Tokenizer name, resolved through `bv_tok.get_tokenizer`.
    text: The string to tokenize.
    bos: 'yes' to prepend a BOS token, 'no' otherwise.
    eos: 'yes' or 'sticky' to append an EOS token, 'no' otherwise. With
      'sticky', the last token is kept when truncating to `length`.
    length: If given, the tokens are truncated or padded (with the tokenizer's
      pad token) to exactly this many entries.

  Returns:
    The tokens. Results are cached per argument tuple, so callers share the
    returned object and must not mutate it.
  """
  assert eos in ('no', 'yes', 'sticky')
  assert bos in ('no', 'yes')
  tokenizer = bv_tok.get_tokenizer(model)
  tokens = tokenizer.to_int(
      text, bos=bos == 'yes', eos=eos in ('yes', 'sticky'))

  if length is None:
    return tokens

  if len(tokens) > length:
    if eos == 'sticky':
      # Keep the final (EOS) token in the last slot.
      return np.r_[tokens[:length-1], tokens[-1]]
    else:
      return tokens[:length]
  else:
    return np.pad(tokens, [(0, length - len(tokens))],
                  constant_values=tokenizer.pad_token)


@Registry.register('preprocess_ops.tolen')
@utils.InKeyOutKey(indefault=None, outdefault=None, with_data=True)
def get_tolen(length, *, sticky_end=False, pad_value=None, pad_key=None):
  """Gets token to a fixed length.

  Args:
    length: Target length. If falsy, the input is returned unchanged.
    sticky_end: If True, truncation keeps the last element of the input in the
      last output position.
    pad_value: Constant used for padding.
    pad_key: If set, the pad value is read from `data[pad_key]` instead (taking
      the first element when that entry is 1-D).

  Returns:
    An op mapping a 1-D tensor to a tensor of static shape `[length]`.
  """
  def _tolen(x, data):
    if not length:
      return x

    xlen = tf.shape(x)[0]

    if sticky_end:
      trunc_fn = lambda: tf.concat([x[:length - 1], x[-1:]], axis=0)
    else:
      trunc_fn = lambda: x[:length]

    # Potentially get the pad value from a data key (to be tokenizer agnostic).
    pad_value_ = pad_value
    if pad_key:
      pad_value_ = data[pad_key]
      # If coming from a previous tokenization op, it's probably 1D; take first.
      if getattr(pad_value_, 'ndim', 0) == 1:
        pad_value_ = pad_value_[0]
    assert pad_value_ is not None, 'Need either pad_value or pad_key.'

    pad_fn = lambda: tf.pad(x, [(0, length - xlen)], constant_values=pad_value_)
    # Only one branch runs: truncate when long enough, pad otherwise.
    out = tf.cond(xlen >= length, trunc_fn, pad_fn)
    out.set_shape([length])
    return out
  return _tolen


@Registry.register('preprocess_ops.tok')
def get_tokenize(model, length=None, *, bos='no', eos='no',
                 text=None, key=None, inkey=None, outkey=None):
  """Tokenizes and optionally truncates/pads a string.

  Args:
    model: Tokenizer name, resolved through `bv_tok.get_tokenizer`.
    length: Optional fixed output length (see `get_tolen`).
    bos: 'yes' to prepend a BOS token, 'no' otherwise.
    eos: 'yes' or 'sticky' to append an EOS token; 'sticky' keeps it when
      truncating.
    text: If given, this constant string is tokenized instead of a data entry.
    key: Used for both input and output unless `inkey`/`outkey` override it.
    inkey: Key of the string to tokenize; must not be combined with `text`.
    outkey: Key under which the tokens are stored.

  Returns:
    An op that writes the tokens to `data[outkey or key]`.
  """

  assert eos in ('no', 'yes', 'sticky')
  assert bos in ('no', 'yes')
  outkey_ = outkey or key
  inkey_ = inkey or key

  if text is not None:
    assert inkey is None, 'Either inkey or text, not both.'
    # Constant text is tokenized once, here, outside of the data pipeline.
    tokens = tokenize_constant(model, text, bos=bos, eos=eos, length=length)
    def _pp_tokenize_text(data):
      data[outkey_] = tokens
      return data
    return _pp_tokenize_text

  tokenizer = bv_tok.get_tokenizer(model)

  def _pp_tokenize(data):
    assert getattr(data[inkey_], 'ndim', 0) == 0, (
        f'Can only tokenize single string ({inkey_}, {data[inkey_].ndim}-D)')

    toks = tokenizer.to_int_tf_op(
        data[inkey_], bos=bos == 'yes', eos=eos in ('yes', 'sticky'))

    # Reuse the `tolen` op on a temporary one-entry dict to fix the length.
    tolen = get_tolen(
        length, sticky_end=eos == 'sticky',
        pad_value=bv_tok.get_tokenizer(model).pad_token,
        key='tmp',
    )
    toks = tolen({'tmp': toks})['tmp']

    data[outkey_] = toks
    return data
  return _pp_tokenize


@Registry.register('preprocess_ops.masked_concat')
def get_masked_concat(keys, outkey='text', **masks):
  """Concatenates `keys` along axis 0 and builds per-element mask tensors.

  Args:
    keys: Data keys to concatenate, in order.
    outkey: Key receiving the concatenation.
    **masks: For every mask name, one value per entry of `keys`; the mask is
      that value repeated over the shape of the corresponding entry.

  Returns:
    An op writing `data[outkey]` and one `data[mask_name]` per mask.
  """
  assert all(len(keys) == len(m) for m in masks.values()), (keys, masks)
  def _masked_concat(data):
    data[outkey] = tf.concat([data[k] for k in keys], axis=0)
    for mask_name, mask_vals in masks.items():
      m = [tf.fill(tf.shape(data[k]), v) for k, v in zip(keys, mask_vals)]
      data[mask_name] = tf.concat(m, axis=0)
    return data
  return _masked_concat


@Registry.register('preprocess_ops.strfmt')
def get_strfmt(template, outkey='text'):
  """Formats a string template with content from the data dict.

  Args:
    template: A `str.format`-style template whose `{name}` fields refer to
      data keys. Format specs and conversions are not supported.
    outkey: Key receiving the formatted string.

  Returns:
    An op writing the formatted string to `data[outkey]`.
  """

  def _template(data):
    outputs = []
    parts = string.Formatter().parse(template)
    for (literal_text, field_name, format_spec, conversion) in parts:
      # For now, we keep it simple and don't support fancy format specs.
      # But we can add support to that via py_func as soon as we need it.
      assert not format_spec and not conversion
      outputs.append(tf.constant(literal_text))
      if field_name:
        value = data[field_name]
        # Convert any non-strings (numbers, vectors) to a string.
        if tf.convert_to_tensor(value).dtype != tf.string:
          value = tf.strings.format('{}', value, summarize=-1)
        outputs.append(value)
    data[outkey] = tf.strings.join(outputs)
    return data

  return _template


@Registry.register('preprocess_ops.strjoin')
@utils.InKeyOutKey()
def get_strjoin(glue):
  """Returns an op joining a string tensor into one string using `glue`."""
  def _strjoin(x):
    return tf.strings.reduce_join(x, separator=glue)
  return _strjoin


@Registry.register('preprocess_ops.majority')
@utils.InKeyOutKey()
def get_majority():
  """Returns an op picking the most frequent value of a 1-D tensor."""
  def _majority(x):
    val, _, count = tf.unique_with_counts(x)  # Sadly, stablesorted.
    return val[tf.argmax(count)]
  return _majority


@Registry.register('preprocess_ops.getidx')
def getidx(inkey, index_key, outkey=None):
  """Indexes a tensor and stores result in outkey.

  Args:
    inkey: Key of the tensor to index.
    index_key: Key of the index to apply.
    outkey: Key receiving the result; defaults to overwriting `inkey`.

  Returns:
    An op writing `data[inkey][data[index_key]]` to `data[outkey or inkey]`.
  """
  def _getidx(data):
    idx = data[index_key]
    array = data[inkey]
    data[outkey or inkey] = array[idx]
    return data
  return _getidx
"""pp ops."""

import math

from big_vision.pp import utils
from big_vision.pp.registry import Registry
import tensorflow as tf


@Registry.register("preprocess_ops.resize_r")
@utils.InKeyOutKey()
def get_resize_r(size):
  """Like standard `resize` but randomize some of its parameters.

  Args:
    size: Target size, expanded to a pair via `utils.maybe_repeat(size, 2)`.

  Returns:
    An op resizing an image with a randomly drawn (method, antialias)
    combination and casting the result back to the input dtype.
  """
  size = utils.maybe_repeat(size, 2)

  # Sadly TF won't let us pass symbolic arguments, so we need to pre-create all
  # variants of function calls we'd like to randomize over...
  resize_fns = [
      lambda x, m=m, a=a: tf.image.resize(x, size, method=m, antialias=a)
      for m in ["bilinear", "bicubic", "lanczos3", "area", "mitchellcubic"]
      for a in [True, False]
  ]

  def _resize_r(image):
    """Resizes image to a given size."""
    dtype = image.dtype
    tf_dtype = tf.type_spec_from_value(image).dtype
    ifn = tf.random.uniform((), 0, len(resize_fns), tf.int32)
    image = tf.switch_case(ifn, [lambda fn=fn: fn(image) for fn in resize_fns])
    # Some kernels overshoot the input range; clip before casting back.
    return tf.cast(tf.clip_by_value(image, tf_dtype.min, tf_dtype.max), dtype)

  return _resize_r


@Registry.register("preprocess_ops.random_jpeg")
@utils.InKeyOutKey()
def get_random_jpeg(p):
  """With probability `p`, randomly encode-decode as jpeg.

  Args:
    p: Probability in [0, 1] of applying a jpeg round trip. The probability
      mass is split evenly between the INTEGER_FAST and INTEGER_ACCURATE DCT
      methods; the jpeg quality is drawn uniformly from [75, 96).

  Returns:
    An op that returns the image unchanged or jpeg-degraded.

  Raises:
    ValueError: If `p` is outside [0, 1].
  """
  if not 0 <= p <= 1:
    raise ValueError(f"p must be in [0, 1], got {p}.")

  fns = [
      lambda x: tf.image.adjust_jpeg_quality(
          x, dct_method="INTEGER_FAST",
          jpeg_quality=tf.random.uniform((), 75, 96, dtype=tf.int32),
      ),
      lambda x: tf.image.adjust_jpeg_quality(
          x, dct_method="INTEGER_ACCURATE",
          jpeg_quality=tf.random.uniform((), 75, 96, dtype=tf.int32),
      ),
  ]
  # Branch probabilities: identity first, then one entry per jpeg variant.
  probs = [1 - p] + [p / len(fns)] * len(fns)

  def _random_jpeg(image):
    """Returns `image`, possibly after a random-quality jpeg round trip."""
    funcs = [lambda: image] + [lambda fn=fn: fn(image) for fn in fns]
    # Zero-probability branches are dropped instead of being given a logit:
    # math.log(0) raises, which made p=0 and p=1 unusable.
    active = [(prob, fn) for prob, fn in zip(probs, funcs) if prob > 0]
    if len(active) == 1:  # Only reachable for p == 0: always the identity.
      return active[0][1]()
    logits = [math.log(prob) for prob, _ in active]
    fn_idx = tf.random.categorical([logits], 1, dtype=tf.int32)[0, 0]
    return tf.switch_case(fn_idx, [fn for _, fn in active])

  return _random_jpeg


# ---- big_vision/pp/proj/paligemma/sciqa_ops.py ----
"""pp ops."""


@Registry.register('preprocess_ops.sci_qa_choices_shuffle')
def sci_qa_choices_shuffle(
    choice_str_inkey='choices',
    ans_inkey='answer',
    indexed_choices_outkey='indexed_choices',
    indexed_answer_outkey='indexed_answer',
):
  """Random shuffle the sci_qa's choice on the fly.

  Args:
    choice_str_inkey: the original choice list from sciqa,
      e.g. ['apple', 'banana', ...]. At most 26 choices are supported, one per
      letter of the alphabet.
    ans_inkey: the original answer index from sciqa, e.g. 1.
    indexed_choices_outkey: shuffled choices, each prefixed with its letter and
      joined into one string, e.g. "(A) banana, (B) apple".
    indexed_answer_outkey: letter of the answer after shuffling, e.g.
      1 (original) -> 2 (shuffled) -> 'B' (alphabet index).

  Returns:
    An op writing the two output keys into the data dict.
  """
  def _template(data):
    alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    abc_tensor = tf.constant([f'({a})' for a in alphabet])
    abcans_tensor = tf.constant(list(alphabet))
    choices = data[choice_str_inkey]
    # tf.shape (rather than len()) also works when the number of choices is
    # only known at run time and the function is traced without AutoGraph.
    indices = tf.range(tf.shape(choices)[0])
    # Shuffle the indices
    shuffled_indices = tf.random.shuffle(indices)
    # Use the shuffled indices to shuffle the tensor
    shuffled_tensor = tf.gather(choices, shuffled_indices)

    # One "(X)" prefix per choice, in alphabetical (unshuffled) order.
    abc_tensor = tf.gather(abc_tensor, indices)

    data[indexed_choices_outkey] = tf.strings.reduce_join(
        tf.strings.join([abc_tensor, shuffled_tensor], separator=' '),
        separator=', ',
    )

    answer_tensor = data[ans_inkey]
    # Position at which the original answer index ended up after shuffling.
    new_ans_indice = tf.where(tf.equal(shuffled_indices, answer_tensor))
    new_ans_indice = tf.gather(abcans_tensor, new_ans_indice)
    data[indexed_answer_outkey] = tf.strings.reduce_join(new_ans_indice)
    return data

  return _template
"""Segmentation-related ops."""

import functools

from big_vision.pp import registry
import numpy as np
import tensorflow as tf

from tensorflow.io import gfile


_KNOWN_MODELS = {
    'oi': 'gs://big_vision/paligemma/vae-oid.npz',
}


@functools.cache
def get_checkpoint(model):
  """Loads (once per `model`) the mask-encoder weights as a dict of arrays.

  Args:
    model: A key of `_KNOWN_MODELS` or a path to an `.npz` file.

  Returns:
    A dict mapping parameter names to numpy arrays.
  """
  with gfile.GFile(_KNOWN_MODELS.get(model, model), 'rb') as f:
    return dict(np.load(f))


@registry.Registry.register('preprocess_ops.refcoco_mask2str')
def get_refcoco_mask2str(model='oi'):
  """Returns op for tokenizing a mask.

  The op reads `objects/mask`, `objects/refs/sentence` and `objects/bbox`,
  and writes `prefix` (the sentence) and `suffix` (four location tokens for
  the box followed by the segmentation tokens of the cropped mask).

  Args:
    model: Mask-encoder checkpoint, see `get_checkpoint`.
  """

  # Vocabulary entries: <seg000>..<seg127> and <loc0000>..<loc1023>.
  seg_tokens = tf.constant(['<seg%03d>' % i for i in range(128)])
  loc_tokens = tf.constant(['<loc%04d>' % i for i in range(1024)])
  checkpoint = get_checkpoint(model)

  def refcoco_mask2str(data):

    # The results of ensure_shape are used below so that the checks are
    # actually attached to the tensors being processed.
    mask = tf.ensure_shape(
        data['objects/mask'], [None, None, 3])  # requires choice()
    sentence = tf.ensure_shape(
        data['objects/refs/sentence'], [])  # requires choice()
    bbox = tf.ensure_shape(data['objects/bbox'], [4])  # requires choice()

    # bbox is used as normalized [y1, x1, y2, x2]; convert to pixel indices.
    h = tf.cast(tf.shape(mask)[0], tf.float32)
    w = tf.cast(tf.shape(mask)[1], tf.float32)
    y1 = tf.cast(tf.round(h * bbox[0]), tf.int32)
    x1 = tf.cast(tf.round(w * bbox[1]), tf.int32)
    y2 = tf.cast(tf.round(h * bbox[2]), tf.int32)
    x2 = tf.cast(tf.round(w * bbox[3]), tf.int32)

    assert mask.dtype == tf.uint8, mask.dtype
    # Crop to the box, keep the first channel, resize to the encoder input
    # size and rescale to [0, 1].
    mask = tf.image.resize(
        mask[None, y1:y2, x1:x2, :1],
        [64, 64],
        method='bilinear',
        antialias=True,
    ) / 255.0

    mask_indices = encode_to_codebook_indices(checkpoint, mask)[0]
    mask_string = tf.strings.reduce_join(tf.gather(seg_tokens, mask_indices))

    binned_loc = tf.cast(tf.round(bbox * 1023), tf.int32)
    binned_loc = tf.clip_by_value(binned_loc, 0, 1023)
    loc_string = tf.strings.reduce_join(tf.gather(loc_tokens, binned_loc))

    data['prefix'] = sentence
    data['suffix'] = tf.strings.join([loc_string, mask_string])

    return data

  return refcoco_mask2str


# Based on https://arxiv.org/abs/2301.02229.

NUM_DOWNSAMPLE_LAYERS = 4
NUM_RES_BLOCKS = 2


def encode_to_codebook_indices(checkpoint, masks):
  """Encode a batch of binary segmentation masks into 16 tokens each.

  Based on code from https://arxiv.org/abs/2301.02229

  Args:
    checkpoint: model weights from PyTorch model.
    masks: Must be in range `[0..1]`, and of shape `[None, 64, 64, 1]`.

  Returns:
    A tensor of shape `[None, 16]` with elements in `range(128)`.
  """

  # We require that the input masks are already resized to 64x64.
  x = tf.ensure_shape(masks, [None, 64, 64, 1])
  x = _norm(x)

  for n in range(NUM_DOWNSAMPLE_LAYERS):
    x = _conv_tf(
        checkpoint, x, strides=2, padding='SAME', layer_name=f'encoder.{2*n}'
    )
    x = tf.nn.relu(x)

  for n in range(NUM_RES_BLOCKS):
    x = _resblock_tf(checkpoint, x, layer_name=f'encoder.{8+n}.net')

  x = _conv_tf(
      checkpoint, x, strides=1, padding='SAME', layer_name='encoder.10'
  )

  return _get_codebook_indices(checkpoint, x)


def _norm(x):
  """Maps values from [0, 1] to [-1, 1]."""
  return 2.0 * (x - 0.5)


def _conv_tf(checkpoint, x, strides, padding, layer_name):
  """Applies the conv layer `layer_name` (weights + bias) from `checkpoint`."""
  kernel = checkpoint[layer_name + '.weight']
  # Reorder the stored kernel axes (0, 1, 2, 3) -> (2, 3, 1, 0), the layout
  # tf.nn.conv2d expects.
  kernel = np.transpose(kernel, (2, 3, 1, 0))
  bias = checkpoint[layer_name + '.bias']
  return tf.nn.conv2d(x, kernel, strides=strides, padding=padding) + bias


def _resblock_tf(checkpoint, x, layer_name):
  """Apply a residual block of the mask encoder."""
  original_x = x
  x = _conv_tf(
      checkpoint, x, padding='SAME', strides=1, layer_name=layer_name + '.0'
  )
  x = tf.nn.relu(x)
  x = _conv_tf(
      checkpoint, x, padding='SAME', strides=1, layer_name=layer_name + '.2'
  )
  x = tf.nn.relu(x)
  x = _conv_tf(
      checkpoint, x, padding='SAME', strides=1, layer_name=layer_name + '.4'
  )
  return x + original_x


def _get_codebook_indices(checkpoint, encoder_output):
  """Returns, per mask, the 16 nearest-codebook-entry indices."""
  embeddings = checkpoint['_vq_vae._embedding']
  flat_input = tf.reshape(encoder_output, [-1, embeddings.shape[1]])
  # Squared euclidean distance, expanded as |a|^2 + |b|^2 - 2ab.
  distances = (
      tf.reduce_sum(flat_input**2, axis=1, keepdims=True)
      + tf.reduce_sum(embeddings**2, axis=1)
      - 2 * tf.matmul(flat_input, embeddings.T)
  )
  indices = tf.argmin(distances, axis=1)
  return tf.reshape(indices, [-1, 16])


# ---- big_vision/pp/proj/paligemma/video.py ----
"""Preprocessing for videos."""

from big_vision.pp import utils
from big_vision.pp.registry import Registry


@Registry.register('preprocess_ops.video_decode')
def video_decode(res):
  """Decodes the jpeg frames in `episodic_images` into `image`.

  Args:
    res: Side length; every frame is resized to `(res, res)`.

  Returns:
    An op writing the float32 frames, rescaled to [-1, 1], to `data['image']`.
  """

  def _pp_per_image(img):
    # decode
    return tf.image.resize(tf.io.decode_jpeg(img), (res, res))

  def _pp(data):
    images = data['episodic_images']
    # resize
    images = tf.map_fn(_pp_per_image, images, fn_output_signature=tf.float32)
    # rescale
    images = 2 * (images / 255.) - 1.0
    data['image'] = images
    return data

  return _pp


@Registry.register('preprocess_ops.video_ensure_shape')
def video_ensure_shape(key, shape):
  """Returns an op asserting (and statically setting) `data[key]`'s shape."""
  def _video_ensure_shape(data):
    data[key] = tf.ensure_shape(data[key], shape)
    return data

  return _video_ensure_shape


@Registry.register('preprocess_ops.video_replicate_img')
def video_replicate_img(replicas, num_frames):
  """Ensure that for short videos, we have the correct number of frames.

  We replicate and select: the frames are tiled `replicas` times along the
  first axis and the first `num_frames` of the result are kept.

  Args:
    replicas: num_replicas before selection. Should be less than num_frames.
    num_frames: number of frames

  Returns:
    _replicate_img: preprocessing function
  """

  def _replicate_img(data):
    # visual analogies + query
    image = data['image']
    image = tf.tile(image, [replicas, 1, 1, 1])
    data['image'] = image[:num_frames]
    return data

  return _replicate_img


@Registry.register('preprocess_ops.video_choice')
@utils.InKeyOutKey()
def video_choice(empty_fallback=None):
  """Randomly takes one entry out of a tensor after flattening.

  Args:
    empty_fallback: Value returned for empty inputs; defaults to a zero of the
      input dtype.
  """

  def _choice(x):
    x = tf.reshape(x, (-1,))  # Ensure it's a 1D array

    # Append the fallback value so we gracefully handle empty cases.
    x0 = tf.zeros(1, x.dtype) if empty_fallback is None else [empty_fallback]
    x = tf.concat([x, x0], axis=0)

    num_choices = tf.maximum(tf.shape(x)[0] - 1, 1)  # Don't sample x0.
    return x[tf.random.uniform([], 0, num_choices, dtype=tf.int32)]

  return _choice


@Registry.register('preprocess_ops.stack_images')
def stack_images(inkeys=(), outkey='image'):
  """Returns an op stacking `data[k]` for k in `inkeys` into `data[outkey]`."""

  def _pp(data):
    images = tf.stack([data[inkey] for inkey in inkeys])
    data[outkey] = images
    return data

  return _pp
"""Widgetcap pp ops."""

from big_vision.pp.registry import Registry
import tensorflow as tf


@Registry.register("preprocess_ops.draw_bbox")
def get_draw_bbox(image_key="image", bbox_key="bbox"):
  """Draw a single bounding box.

  Args:
    image_key: Key of the image to draw on; it is overwritten with the result,
      which is float32.
    bbox_key: Key of the box; must hold exactly four values.

  Returns:
    An op drawing the box in red (255, 0, 0) onto the image.
  """

  def _draw_bbox(data):
    """Draw a single bounding box."""
    image = tf.cast(data[image_key], tf.float32)
    # draw_bounding_boxes works on batches, so add a batch axis of size one.
    image = tf.image.draw_bounding_boxes(
        tf.expand_dims(image, 0),
        tf.reshape(data[bbox_key], [1, 1, 4]),
        tf.constant([255, 0, 0], dtype=tf.float32, shape=[1, 3]),
    )
    # Remove only the batch axis added above. A bare tf.squeeze would also
    # drop any other size-1 dimension (e.g. a single-channel image) and
    # discards the static rank when height/width are unknown.
    data[image_key] = tf.squeeze(image, axis=0)
    return data

  return _draw_bbox
"""Preprocessing ops."""
from big_vision.pp import utils
from big_vision.pp.registry import Registry
import numpy as np
import tensorflow as tf


@Registry.register("preprocess_ops.rgb_to_grayscale_to_rgb")
@utils.InKeyOutKey(indefault="image", outdefault="image")
def get_rgb_to_grayscale_to_rgb():
  """Converts RGB to grayscale and back, i.e. removes color but keeps 3 channels."""
  def _rgb_to_grayscale_to_rgb(image):
    return tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
  return _rgb_to_grayscale_to_rgb


@Registry.register("preprocess_ops.nyu_eval_crop")
def get_nyu_eval_crop():
  """Crops labels and image to valid eval area.

  Both `data["labels"]` (480x640x1) and `data["image"]` (480x640x3) are cut
  to the same 426x560 window; the shapes are asserted before slicing.

  Returns:
    A function mapping a data dict to the same (mutated) data dict.
  """
  # crop_h = slice(45, 471)
  # crop_w = slice(41, 601)
  # NOTE(review): the slice documented above starts at row 45, whereas the
  # value used below is 54 (rows 54..480). Left untouched because evaluation
  # numbers depend on it -- please confirm which one is intended.
  crop_h_start = 54
  crop_h_size = 426
  crop_w_start = 41
  crop_w_size = 560

  def _pp(data):
    tf.debugging.assert_equal(tf.shape(data["labels"]), (480, 640, 1))
    tf.debugging.assert_equal(tf.shape(data["image"]), (480, 640, 3))
    data["labels"] = tf.slice(data["labels"],
                              [crop_h_start, crop_w_start, 0],
                              [crop_h_size, crop_w_size, -1])
    data["image"] = tf.slice(data["image"],
                             [crop_h_start, crop_w_start, 0],
                             [crop_h_size, crop_w_size, -1])
    return data
  return _pp


@Registry.register("preprocess_ops.nyu_depth")
@utils.InKeyOutKey(indefault="depth", outdefault="labels")
def get_nyu_depth():
  """Preprocesses NYU depth data: casts to float32 and appends a channel dim."""
  def _pp(depth):
    return tf.expand_dims(tf.cast(depth, tf.float32), -1)
  return _pp


@Registry.register("preprocess_ops.coco_panoptic")
def get_coco_panoptic_pp():
  """COCO-panoptic: produces a mask with labels and a mask with instance ids.

  Reads `data["panoptic_image"]` and `data["panoptic_objects"]` and writes
  `data["instances"]` and `data["semantics"]`, both of shape [h, w, 1].
  Pixels covered by the i-th listed object get instance value i + 1 and
  semantic value label + 1; pixels matching no listed object get 0 in both.

  Returns:
    COCO panoptic preprocessing op.
  """
  def _coco_panoptic(data):
    instance_ids = tf.cast(data["panoptic_objects"]["id"], tf.int32)
    instance_labels = tf.cast(data["panoptic_objects"]["label"], tf.int32)

    # Convert image with ids split in 3 channels into an integer id.
    id_mask = tf.einsum(
        "hwc,c->hw",
        tf.cast(data["panoptic_image"], tf.int32),
        tf.constant([1, 256, 256**2], tf.int32))

    # Broadcast into N boolean masks one per instance_id.
    n_masks = tf.cast(
        id_mask[:, :, None] == instance_ids[None, None, :], tf.int32)

    # Merge into a semantic and an instance id mask.
    # Number instances starting at 1 (0 is treated specially by make_canonical).
    instance_idx = tf.range(tf.shape(instance_ids)[-1])
    instances = tf.einsum("hwc,c->hw", n_masks, instance_idx + 1)
    semantics = tf.einsum("hwc,c->hw", n_masks, instance_labels + 1)

    data["instances"] = instances[:, :, None]
    data["semantics"] = semantics[:, :, None]
    return data

  return _coco_panoptic


@Registry.register("preprocess_ops.make_canonical")
@utils.InKeyOutKey(indefault="labels", outdefault="labels")
def get_make_canonical(random=False, main_sort_axis="y"):
  """Renumbers instance ids by the position of each instance's center of mass.

  Positive ids in the last channel are replaced by 1..N. The order is given
  by a sort key that weights the center coordinate along `main_sort_axis`
  more heavily than the other one, so with the default "y" instances are
  ordered mainly top-to-bottom, and with "x" mainly left-to-right. Id 0 is
  kept as is; negative ids become -1.

  Args:
    random: if True, assign the new ids in random order instead.
    main_sort_axis: "y" or "x", the dominant axis of the ordering.

  Returns:
    The op, operating on an integer mask whose last channel holds the ids.
  """
  # By convention, instances are in the last channel.
  def _make_canonical(image):
    """Op."""
    instimg = image[..., -1]

    # Compute binary instance masks. Note, we do not touch 0 and neg. ids.
    ids = tf.unique(tf.reshape(instimg, [-1])).y
    ids = ids[ids > 0]
    n_masks = tf.cast(
        instimg[None, :, :] == ids[:, None, None], tf.int32)

    if not random:
      # Mean (row, col) coordinate of the nonzero pixels of one mask.
      f = lambda x: tf.reduce_mean(tf.cast(tf.where(x), tf.float32), axis=0)
      # `fn_output_signature` is the non-deprecated spelling of `dtype`.
      centers = tf.map_fn(
          f, tf.cast(n_masks, tf.int64), fn_output_signature=tf.float32)
      centers = tf.reshape(centers, (tf.shape(centers)[0], 2))
      major = {"y": 0, "x": 1}[main_sort_axis]
      perm = tf.argsort(
          centers[:, 1 - major] +
          tf.cast(tf.shape(instimg)[major], tf.float32) * centers[:, major])
      n_masks = tf.gather(n_masks, perm)
    else:
      n_masks = tf.random.shuffle(n_masks)

    idx = tf.range(tf.shape(ids)[0])
    can_mask = tf.einsum("chw,c->hw", n_masks, idx + 2) - 1
    # Now, all 0 and neg. ids have collapsed to -1. Thus, we recover 0 id from
    # the original mask.
    can_mask = tf.where(instimg == 0, 0, can_mask)
    return tf.concat([image[..., :-1], can_mask[..., None]], axis=-1)

  return _make_canonical


@Registry.register("preprocess_ops.inception_box")
def get_inception_box(
    *, area=(0.05, 1.0), aspect=(0.75, 1.33), min_obj_cover=0.0,
    outkey="box", inkey="image"):
  """Creates an inception style bounding box which can be used to crop.

  Args:
    area: passed as `area_range` to `tf.image.sample_distorted_bounding_box`.
    aspect: passed as `aspect_ratio_range`.
    min_obj_cover: passed as `min_object_covered`. When it is falsy, no
      object boxes are used; otherwise `data["objects"]["bbox"]` is read.
    outkey: key receiving the `(begin, size)` pair derived from the box.
    inkey: key of the image whose shape is sampled against.

  Returns:
    A function mapping a data dict to the same (mutated) data dict.
  """
  def _inception_box(data):
    _, _, box = tf.image.sample_distorted_bounding_box(
        tf.shape(data[inkey]),
        area_range=area,
        aspect_ratio_range=aspect,
        min_object_covered=min_obj_cover,
        bounding_boxes=(data["objects"]["bbox"][None, :, :]
                        if min_obj_cover else tf.zeros([0, 0, 4])),
        use_image_if_no_bounding_boxes=True)
    # bbox is [[[y0,x0,y1,x1]]]
    data[outkey] = (box[0, 0, :2], box[0, 0, 2:] - box[0, 0, :2])
    return data
  return _inception_box


@Registry.register("preprocess_ops.crop_box")
@utils.InKeyOutKey(with_data=True)
def get_crop_box(*, boxkey="box"):
  """Crops an image according to bounding box in `boxkey`.

  `data[boxkey]` is a `(begin, size)` pair that gets multiplied by the image
  height/width, i.e. it is expressed as fractions of the image size.
  """
  def _crop_box(image, data):
    shape = tf.shape(image)[:-1]
    begin, size = data[boxkey]
    begin = tf.cast(begin * tf.cast(shape, tf.float32), tf.int32)
    size = tf.cast(size * tf.cast(shape, tf.float32), tf.int32)
    # Keep all channels: start at channel 0 and take everything (-1).
    begin = tf.concat([begin, tf.constant((0,))], axis=0)
    size = tf.concat([size, tf.constant((-1,))], axis=0)
    crop = tf.slice(image, begin, size)
    # Unfortunately, the above operation loses the depth-dimension. So we need
    # to restore it the manual way.
    crop.set_shape([None, None, image.shape[-1]])
    return crop
  return _crop_box


@Registry.register("preprocess_ops.randu")
def get_randu(key):
  """Creates a random uniform float [0, 1) in `key`."""
  def _randu(data):
    data[key] = tf.random.uniform([])
    return data
  return _randu


@Registry.register("preprocess_ops.det_fliplr")
@utils.InKeyOutKey(with_data=True)
def get_det_fliplr(*, randkey="fliplr"):
  """Flips an image horizontally based on `randkey`.

  The image is flipped iff `data[randkey] > 0.5`; combine with `randu` to
  apply the very same random decision to several keys.
  """
  # NOTE: we could unify this with regular flip when randkey=None.
  def _det_fliplr(orig_image, data):
    flip_image = tf.image.flip_left_right(orig_image)
    # Select arithmetically (flip is 0 or 1) so that the dtype is preserved.
    flip = tf.cast(data[randkey] > 0.5, orig_image.dtype)
    return flip_image * flip + orig_image * (1 - flip)
  return _det_fliplr


@Registry.register("preprocess_ops.strong_hash")
@utils.InKeyOutKey(indefault="tfds_id", outdefault="tfds_id")
def get_strong_hash():
  """Preprocessing that hashes a string."""
  def _strong_hash(string):
    return tf.strings.to_hash_bucket_strong(
        string,
        np.iinfo(int).max, [3714561454027272724, 8800639020734831960])
  return _strong_hash


# --- Patch scaffolding of the next file, kept verbatim from the dump ---
# diff --git a/big_vision/pp/proj/uvim/pp_ops_test.py b/big_vision/pp/proj/uvim/pp_ops_test.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..78c065e372788ee77f03c53aadb444614a391ebc
# --- /dev/null
# +++ b/big_vision/pp/proj/uvim/pp_ops_test.py
# @@ -0,0 +1,128 @@
# +# Copyright 2022 Big Vision Authors.
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for pp_ops."""
import copy

from big_vision.pp.proj.uvim import pp_ops as pp
import numpy as np
import tensorflow as tf


def get_image_data(dtype=tf.uint8):
  """Returns `{"image": ...}` holding a random 640x320x3 image of `dtype`."""
  img = tf.random.uniform((640, 320, 3), 0, 255, tf.int32)  # Can't ask uint8!?
  return {"image": tf.cast(img, dtype)}


class PreprocessOpsTest(tf.test.TestCase):
  """Runs every op both eagerly and inside a tf.data pipeline."""

  def tfrun(self, ppfn, data=None):
    """Yields the output of `ppfn(data)`, first eager, then through tf.data.

    Args:
      ppfn: the preprocessing function under test.
      data: input dict; defaults to an empty dict. It is deep-copied before
        each run, so the caller's dict is never mutated.

    Yields:
      Dicts of numpy values, one per execution mode.
    """
    data = {} if data is None else data

    # Run once as standalone, as could happen eg in colab.
    yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()}

    if not data:  # tf.data doesn't like a completely empty dict...
      data = {"dummy": 0.0}

    # And then once again as part of tfdata pipeline.
    # You'd be surprised how much these two differ!
    tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data))
    for npdata in tfdata.map(ppfn).as_numpy_iterator():
      yield npdata

  def test_randu(self):
    for output in self.tfrun(pp.get_randu("flip")):
      self.assertEqual(output["flip"].shape, ())
      self.assertAllGreaterEqual(output["flip"], 0.0)
      self.assertAllLessEqual(output["flip"], 1.0)

  def test_det_flip_lr(self):
    # Test both dtypes to make sure it can be applied correctly to both.
    for dtype in [tf.uint8, tf.float32]:
      image_data = get_image_data(dtype)
      for out in self.tfrun(pp.get_det_fliplr(randkey="rand"),
                            {"rand": 0.1, **image_data}):
        self.assertTrue(np.all(image_data["image"] == out["image"]))
        self.assertEqual(out["image"].dtype, dtype)
      for out in self.tfrun(pp.get_det_fliplr(randkey="rand"),
                            {"rand": 0.6, **image_data}):
        self.assertTrue(np.all(image_data["image"][:, ::-1, :] == out["image"]))
        self.assertEqual(out["image"].dtype, dtype)

  def test_inception_box(self):
    for out in self.tfrun(pp.get_inception_box(), get_image_data()):
      self.assertEqual(out["box"][0].shape, (2,))
      self.assertEqual(out["box"][1].shape, (2,))

  def test_crop_box(self):
    data = get_image_data()
    data["box"] = (tf.constant([0.5, 0.4]), tf.constant([0.25, 0.3]))
    for out in self.tfrun(pp.get_crop_box(), data):
      self.assertEqual(out["image"].shape, (160, 96, 3))
      self.assertAllEqual(
          data["image"][320:320 + 160, 128:128 + 96],
          out["image"])

  def test_make_canonical(self):
    orig = np.array([
        [1, 0, 3, 3, -1],
        [1, 0, 3, 3, -1],
        [1, 0, 2, 2, 2],
        [1, 0, 0, -1, -1]
    ], np.int32)[:, :, None]
    expected = np.array([
        [2, 0, 1, 1, -1],
        [2, 0, 1, 1, -1],
        [2, 0, 3, 3, 3],
        [2, 0, 0, -1, -1]
    ], np.int32)[:, :, None]
    for out in self.tfrun(pp.get_make_canonical(), {"labels": orig}):
      self.assertTrue(np.all(out["labels"] == expected))

    # Test it only affects last channel.
    for out in self.tfrun(pp.get_make_canonical(),
                          {"labels": tf.tile(orig, (1, 1, 3))}):
      self.assertAllEqual(out["labels"][..., 0], orig[..., 0])
      self.assertAllEqual(out["labels"][..., 1], orig[..., 0])
      self.assertAllEqual(out["labels"][..., 2], expected[..., 0])

  def test_nyu_depth(self):
    image = tf.zeros((5, 7, 3), dtype=tf.uint8)
    depth = tf.zeros((5, 7), dtype=tf.float16)
    data = {
        "image": image,
        "depth": depth
    }
    output = pp.get_nyu_depth()(data)
    self.assertEqual(output["image"].shape, (5, 7, 3))
    self.assertEqual(output["image"].dtype, tf.uint8)
    self.assertEqual(output["labels"].shape, (5, 7, 1))
    self.assertEqual(output["labels"].dtype, tf.float32)

  def test_nyu_eval_crop(self):
    image = tf.zeros((480, 640, 3), dtype=tf.uint8)
    depth = tf.zeros((480, 640), dtype=tf.float16)
    data = {
        "image": image,
        "depth": depth
    }
    data = pp.get_nyu_depth()(data)
    output = pp.get_nyu_eval_crop()(data)
    self.assertEqual(output["image"].shape, (426, 560, 3))
    self.assertEqual(output["image"].dtype, tf.uint8)
    self.assertEqual(output["labels"].shape, (426, 560, 1))
    self.assertEqual(output["labels"].dtype, tf.float32)


if __name__ == "__main__":
  tf.test.main()


# --- Patch scaffolding of the next file, kept verbatim from the dump ---
# diff --git a/big_vision/pp/registry.py b/big_vision/pp/registry.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..f5c7d996d756be16ba68a5fcb143f23129e1249d
# --- /dev/null
# +++ b/big_vision/pp/registry.py
# @@ -0,0 +1,163 @@
# +# Copyright 2024 Big Vision Authors.
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
# +# You may obtain a copy of the License at
# +#
# +#     http://www.apache.org/licenses/LICENSE-2.0
# +#
# +# Unless required by applicable law or agreed to in writing, software
# +# distributed under the License is distributed on an "AS IS" BASIS,
# +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Global Registry for big_vision pp ops.

Author: Joan Puigcerver (jpuigcerver@)
"""

import ast
import contextlib
import functools


def parse_name(string_to_parse):
  """Parses input to the registry's lookup function.

  Args:
    string_to_parse: can be either an arbitrary name or function call
      (optionally with positional and keyword arguments).
      e.g. "multiclass", "resnet50_v2(filters_factor=8)".

  Returns:
    A tuple of input name, argument tuple and a keyword argument dictionary.
    Examples:
      "multiclass" -> ("multiclass", (), {})
      "resnet50_v2(9, filters_factor=4)" ->
          ("resnet50_v2", (9,), {"filters_factor": 4})

  Raises:
    SyntaxError: if the string is not a syntactically valid expression.
    ValueError: if the expression is neither a (dotted) name nor a call, or
      if an argument is not a Python literal.

  Author: Joan Puigcerver (jpuigcerver@)
  """
  expr = ast.parse(string_to_parse, mode="eval").body  # pytype: disable=attribute-error
  if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)):
    raise ValueError(
        "The given string should be a name or a call, but a {} was parsed from "
        "the string {!r}".format(type(expr), string_to_parse))

  # Notes:
  # name="some_name" -> type(expr) = ast.Name
  # name="module.some_name" -> type(expr) = ast.Attribute
  # name="some_name()" -> type(expr) = ast.Call
  # name="module.some_name()" -> type(expr) = ast.Call

  if isinstance(expr, (ast.Name, ast.Attribute)):
    # Surrounding whitespace is not part of the name. The call branch below
    # already ignores it, so strip it here too; otherwise "op " would never
    # match the registry key "op".
    return string_to_parse.strip(), (), {}

  def _get_func_name(expr):
    if isinstance(expr, ast.Attribute):
      return _get_func_name(expr.value) + "." + expr.attr
    elif isinstance(expr, ast.Name):
      return expr.id
    else:
      raise ValueError(
          "Type {!r} is not supported in a function name, the string to parse "
          "was {!r}".format(type(expr), string_to_parse))

  def _get_func_args_and_kwargs(call):
    # Only literals are accepted as arguments; nothing is ever executed.
    args = tuple([ast.literal_eval(arg) for arg in call.args])
    kwargs = {
        kwarg.arg: ast.literal_eval(kwarg.value) for kwarg in call.keywords
    }
    return args, kwargs

  func_name = _get_func_name(expr.func)
  func_args, func_kwargs = _get_func_args_and_kwargs(expr)

  return func_name, func_args, func_kwargs


class Registry(object):
  """Implements global Registry.

  Authors: Joan Puigcerver (jpuigcerver@), Alexander Kolesnikov (akolesnikov@)
  """

  # Process-wide mapping from registered name to registered item.
  _GLOBAL_REGISTRY = {}

  @staticmethod
  def global_registry():
    """Returns the (mutable) dict backing the registry."""
    return Registry._GLOBAL_REGISTRY

  @staticmethod
  def register(name, replace=False):
    """Creates a function that registers its input.

    Args:
      name: the key under which the decorated item is stored.
      replace: whether an already registered `name` may be overwritten.

    Returns:
      A decorator that stores its argument under `name` and returns it
      unchanged. The decorator raises KeyError if `name` is already taken
      and `replace` is False.
    """

    def _register(item):
      if name in Registry.global_registry() and not replace:
        raise KeyError("The name {!r} was already registered.".format(name))

      Registry.global_registry()[name] = item
      return item

    return _register

  @staticmethod
  def lookup(lookup_string, kwargs_extra=None):
    """Lookup a name in the registry.

    Args:
      lookup_string: a name or call expression, see `parse_name`.
      kwargs_extra: optional dict of keyword arguments that are added to, and
        take precedence over, those given in `lookup_string`.

    Returns:
      A `functools.partial` of the registered item with the parsed arguments
      bound.

    Raises:
      ValueError: if `lookup_string` cannot be parsed.
      SyntaxError: if `lookup_string` is not valid expression syntax.
      KeyError: if the parsed name is not registered.
    """

    try:
      name, args, kwargs = parse_name(lookup_string)
    except ValueError as e:
      raise ValueError(f"Error parsing:\n{lookup_string}") from e
    if kwargs_extra:
      kwargs.update(kwargs_extra)
    item = Registry.global_registry()[name]
    return functools.partial(item, *args, **kwargs)

  @staticmethod
  def knows(lookup_string):
    """Returns whether the name in `lookup_string` is registered."""
    try:
      name, _, _ = parse_name(lookup_string)
    except ValueError as e:
      raise ValueError(f"Error parsing:\n{lookup_string}") from e
    return name in Registry.global_registry()


@contextlib.contextmanager
def temporary_ops(**kw):
  """Registers specified pp ops for use in a `with` block.

  Example use:

    with pp_registry.temporary_ops(
        pow=lambda alpha: lambda d: {k: v**alpha for k, v in d.items()}):
      pp = pp_builder.get_preprocess_fn("pow(alpha=2.0)|pow(alpha=0.5)")
      features = pp(features)

  Args:
    **kw: Names are preprocess string function names to be used to specify the
      preprocess function. Values are functions that can be called with params
      (e.g. the `alpha` param in above example) and return functions to be used
      to transform features.

  Yields:
    Nothing; the ops are registered for the duration of the `with` block and
    removed again afterwards.

  Raises:
    AssertionError: if one of the names is already registered.
  """
  reg = Registry.global_registry()
  kw = {f"preprocess_ops.{k}": v for k, v in kw.items()}
  for k in kw:
    # Raised explicitly instead of via `assert`, so that the guard is still
    # active under `python -O`; otherwise an existing op would be overwritten
    # and then deleted on exit.
    if k in reg:
      raise AssertionError(f"The name {k!r} was already registered.")
  for k, v in kw.items():
    reg[k] = v
  try:
    yield
  finally:
    for k in kw:
      del reg[k]


# --- Patch scaffolding of the next file, kept verbatim from the dump ---
# diff --git a/big_vision/pp/registry_test.py b/big_vision/pp/registry_test.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..2296e7de91ce0495bade59e8e65417384507e58e
# --- /dev/null
# +++ b/big_vision/pp/registry_test.py
# @@ -0,0 +1,128 @@
# +# Copyright 2024 Big Vision Authors.
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
# +# You may obtain a copy of the License at
# +#
# +#     http://www.apache.org/licenses/LICENSE-2.0
# +#
# +# Unless required by applicable law or agreed to in writing, software
# +# distributed under the License is distributed on an "AS IS" BASIS,
# +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# +# See the License for the specific language governing permissions and
# +# limitations under the License.
"""Tests for registry."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from unittest import mock

from absl.testing import absltest
from big_vision.pp import registry


class RegistryTest(absltest.TestCase):
  """Checks name parsing, registration and lookup of the global registry."""

  def setUp(self):
    super().setUp()
    # Every test gets its own, empty registry dict: this keeps the tests
    # isolated from each other and safe to run concurrently.
    self.addCleanup(mock.patch.stopall)
    self.global_registry = dict()
    patcher = mock.patch.object(
        registry.Registry, "global_registry",
        return_value=self.global_registry)
    self.mocked_method = patcher.start()

  def test_parse_name(self):
    # (input string, expected name, expected args, expected kwargs)
    well_formed = (
        ("f", "f", (), {}),
        ("f()", "f", (), {}),
        ("func(a=0,b=1,c='s')", "func", (), {"a": 0, "b": 1, "c": "s"}),
        ("func(1,'foo',3)", "func", (1, "foo", 3), {}),
        ("func(1,'2',a=3,foo='bar')", "func", (1, "2"),
         {"a": 3, "foo": "bar"}),
    )
    for text, want_name, want_args, want_kwargs in well_formed:
      name, args, kwargs = registry.parse_name(text)
      self.assertEqual(name, want_name)
      self.assertEqual(args, want_args)
      self.assertEqual(kwargs, want_kwargs)

    name, args, kwargs = registry.parse_name("foo.bar.func(a=0,b=(1),c='s')")
    self.assertEqual(name, "foo.bar.func")
    self.assertEqual(kwargs, dict(a=0, b=1, c="s"))

    for broken in ("func(0", "func(a=0,,b=0)", "func(a=0,b==1,c='s')"):
      with self.assertRaises(SyntaxError):
        registry.parse_name(broken)
    with self.assertRaises(ValueError):
      registry.parse_name("func(a=0,b=undefined_name,c='s')")

  def test_register(self):
    # pylint: disable=unused-variable
    @registry.Registry.register("func1")
    def func1():
      pass

    self.assertLen(registry.Registry.global_registry(), 1)

  def test_lookup_function(self):

    @registry.Registry.register("func1")
    def func1(arg1, arg2, arg3):  # pylint: disable=unused-variable
      return arg1, arg2, arg3

    lookup = registry.Registry.lookup

    self.assertTrue(callable(lookup("func1")))

    # (lookup string, call args, call kwargs, expected result)
    working_calls = (
        ("func1", (1, 2, 3), {}, (1, 2, 3)),
        ("func1(arg3=9)", (1, 2), {}, (1, 2, 9)),
        ("func1(arg2=9,arg1=99)", (), dict(arg3=3), (99, 9, 3)),
        ("func1(arg2=9,arg1=99)", (), dict(arg1=1, arg3=3), (1, 9, 3)),
        ("func1(1)", (1, 2), {}, (1, 1, 2)),
        ("func1(1)", (), dict(arg3=3, arg2=2), (1, 2, 3)),
        ("func1(1, 2)", (3,), {}, (1, 2, 3)),
        ("func1(1, 2)", (), dict(arg3=3), (1, 2, 3)),
        ("func1(1, arg2=2)", (), dict(arg3=3), (1, 2, 3)),
        ("func1(1, arg3=2)", (), dict(arg2=3), (1, 3, 2)),
        ("func1(1, arg3=2)", (3,), {}, (1, 3, 2)),
    )
    for text, call_args, call_kwargs, want in working_calls:
      self.assertEqual(lookup(text)(*call_args, **call_kwargs), want)

    # (lookup string, call args, call kwargs) that cannot be bound.
    failing_calls = (
        ("func1(1, arg2=2)", (3,), {}),
        ("func1(1, arg3=3)", (), dict(arg3=3)),
        ("func1(1, arg3=3)", (), dict(arg1=3)),
    )
    for text, call_args, call_kwargs in failing_calls:
      with self.assertRaises(TypeError):
        lookup(text)(*call_args, **call_kwargs)

    with self.assertRaises(SyntaxError):
      lookup("func1(arg1=1, 3)")(arg2=3)


if __name__ == "__main__":
  absltest.main()


# --- Patch scaffolding of the next file, kept verbatim from the dump ---
# (the hunk header is continued at the start of the following line)
# diff --git a/big_vision/pp/tokenizer.py b/big_vision/pp/tokenizer.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..681494e436aacd48d5d720e07d2df1a80c704eb2
# --- /dev/null
# +++ b/big_vision/pp/tokenizer.py
# @@ -0,0
# (tail of the hunk header started on the previous line of the dump)
# +1,103 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The tokenizer API for big_vision, and central registration place."""
import functools
import importlib
from typing import Protocol

from absl import logging
from big_vision.pp import registry
import big_vision.utils as u
import numpy as np


class Tokenizer(Protocol):
  """Just to unify on the API as we now have many different ones."""

  def to_int(self, text, *, bos=False, eos=False):
    """Tokenizes `text` into a list of integer tokens.

    Args:
      text: can be a single string, or a list of strings.
      bos: Whether a beginning-of-sentence token should be prepended.
      eos: Whether an end-of-sentence token should be appended.

    Returns:
      List or list-of-list of tokens.
    """

  def to_int_tf_op(self, text, *, bos=False, eos=False):
    """Same as `to_int()`, but as TF ops to be used in pp."""

  def to_str(self, tokens, *, stop_at_eos=True):
    """Inverse of `to_int()`.

    Args:
      tokens: list of tokens, or list of lists of tokens.
      stop_at_eos: remove everything that may come after the first EOS.

    Returns:
      A string (if `tokens` is a list of tokens), or a list of strings.
      Note that most tokenizers strip select few control tokens like
      eos/bos/pad/unk from the output string.
    """

  def to_str_tf_op(self, tokens, *, stop_at_eos=True):
    """Same as `to_str()`, but as TF ops to be used in pp."""

  @property
  def pad_token(self):
    """Token id of padding token."""

  @property
  def eos_token(self):
    """Token id of end-of-sentence token."""

  @property
  def bos_token(self):
    """Token id of beginning-of-sentence token."""

  @property
  def vocab_size(self):
    """Returns the size of the vocabulary."""


@functools.cache
def get_tokenizer(name):
  """Returns the tokenizer registered as `tokenizers.<name>`, cached per name.

  If no such tokenizer is registered yet, the module `big_vision.pp.<name>`
  (with any call arguments stripped from `name`) is imported first, giving it
  the chance to register itself.

  Args:
    name: registry lookup string, i.e. a name or a call such as "foo(arg=1)".

  Returns:
    The tokenizer instance created by calling the registered item.
  """
  with u.chrono.log_timing(f"z/secs/tokenizer/{name}"):
    if not registry.Registry.knows(f"tokenizers.{name}"):
      raw_name, *_ = registry.parse_name(name)
      logging.info("Tokenizer %s not registered, "
                   "trying import big_vision.pp.%s", name, raw_name)
      importlib.import_module(f"big_vision.pp.{raw_name}")

    return registry.Registry.lookup(f"tokenizers.{name}")()


def get_extra_tokens(tokensets):
  """Returns the de-duplicated union of the tokens of the given token sets.

  Args:
    tokensets: iterable of names registered as `tokensets.<name>`.

  Returns:
    A list of unique tokens. NOTE: `np.unique` returns them sorted
    lexicographically, so the concatenation order survives only if it is
    already sorted (as it is for zero-padded tokens of "loc" followed by
    "seg"). The elements are numpy strings.
  """
  extra_tokens = []
  for tokenset in tokensets:
    extra_tokens.extend(registry.Registry.lookup(f"tokensets.{tokenset}")())
  return list(np.unique(extra_tokens))  # Sorts and dedups. Dups make no sense.


@registry.Registry.register("tokensets.loc")
def _get_loc1024(n=1024):
  """Returns the `n` location tokens "<loc0000>", "<loc0001>", ..."""
  # NOTE(review): the template was restored -- the dump had an empty f-string
  # here, presumably because the angle-bracket token was stripped as markup.
  return [f"<loc{i:04d}>" for i in range(n)]


@registry.Registry.register("tokensets.seg")
def _get_seg(n=128):
  """Returns the `n` segmentation tokens "<seg000>", "<seg001>", ..."""
  # NOTE(review): template restored, see `_get_loc1024`.
  return [f"<seg{i:03d}>" for i in range(n)]


# --- Patch scaffolding of the next file, kept verbatim from the dump ---
# diff --git a/big_vision/pp/utils.py b/big_vision/pp/utils.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..3ee834560246549c71f0a6d9785694fd1507ca9b
# --- /dev/null
# +++ b/big_vision/pp/utils.py
# @@ -0,0 +1,53 @@
# +# Copyright 2024 Big Vision Authors.
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Preprocessing utils."""

from collections import abc


def maybe_repeat(arg, n_reps):
  """Returns `arg` unchanged if it is a non-string sequence, else a tuple.

  Args:
    arg: any value. Strings count as scalars here, not as sequences.
    n_reps: how often a scalar `arg` is repeated.

  Returns:
    `arg` itself for sequences (no length check is made), otherwise the tuple
    `(arg,) * n_reps`.
  """
  is_real_sequence = (
      isinstance(arg, abc.Sequence) and not isinstance(arg, str))
  return arg if is_real_sequence else (arg,) * n_reps


class InKeyOutKey(object):
  """Decorator for preprocessing ops, which adds `inkey` and `outkey` arguments.

  The decorated op-getter additionally accepts the keyword arguments `key`,
  `inkey` and `outkey`; a truthy `key` overrides both of the others. The op
  it returns works on a data dict: it reads the input key, applies the
  original op, and stores the result under the output key.

  Note: Only supports single-input single-output ops.
  """

  def __init__(self, indefault="image", outdefault="image", with_data=False):
    self.indefault = indefault
    self.outdefault = outdefault
    self.with_data = with_data

  def __call__(self, orig_get_pp_fn):

    def get_ikok_pp_fn(*args, key=None,
                       inkey=self.indefault, outkey=self.outdefault, **kw):
      value_fn = orig_get_pp_fn(*args, **kw)
      source_key = key or inkey
      target_key = key or outkey

      def _ikok_pp_fn(data):
        value = data[source_key]
        if self.with_data:
          # Optionally allow the function to get the full data dict as aux
          # input.
          result = value_fn(value, data=data)
        else:
          result = value_fn(value)
        data[target_key] = result
        return data

      return _ikok_pp_fn

    return get_ikok_pp_fn


# --- Patch scaffolding of the next file, kept verbatim from the dump ---
# diff --git a/big_vision/pp/utils_test.py b/big_vision/pp/utils_test.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..beec18cef62a9638ed143229d7aedc5e218a70b6
# --- /dev/null
# +++ b/big_vision/pp/utils_test.py
# @@ -0,0 +1,53 @@
# +# Copyright 2024 Big Vision Authors.
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for preprocessing utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from big_vision.pp import utils
import tensorflow.compat.v1 as tf


class UtilsTest(tf.test.TestCase):
  """Covers `maybe_repeat` and the `InKeyOutKey` decorator."""

  def test_maybe_repeat(self):
    # (expected, value, repetitions): scalars are repeated, sequences kept.
    cases = (
        ((1, 1, 1), 1, 3),
        ((1, 2), (1, 2), 2),
        ([1, 2], [1, 2], 2),
    )
    for expected, value, n_reps in cases:
      self.assertEqual(expected, utils.maybe_repeat(value, n_reps))

  def test_inkeyoutkey(self):
    @utils.InKeyOutKey()
    def get_pp_fn(shift, scale=0):
      def _pp_fn(x):
        return scale * x + shift
      return _pp_fn

    # Distinct input and output keys: the input stays, the output is added.
    ppfn = get_pp_fn(1, 2, inkey="k_in", outkey="k_out")  # pylint: disable=unexpected-keyword-arg
    result = ppfn({"k_in": 2, "other": 3})
    self.assertEqual({"k_in": 2, "k_out": 5, "other": 3}, result)

    # Same key for input and output: the value is overwritten in place.
    ppfn = get_pp_fn(1, inkey="k", outkey="k")  # pylint: disable=unexpected-keyword-arg
    result = ppfn({"k": 6, "other": 3})
    self.assertEqual({"k": 1, "other": 3}, result)

    # No keys given: both default to "image".
    ppfn = get_pp_fn(5, 2)
    result = ppfn({"other": 6, "image": 3})
    self.assertEqual({"other": 6, "image": 11}, result)


if __name__ == "__main__":
  tf.test.main()


# --- Patch scaffolding of the next file, kept verbatim from the dump ---
# (the list of requirements continues on the following line of the dump)
# diff --git a/big_vision/requirements.txt b/big_vision/requirements.txt
# new file mode 100644
# index 0000000000000000000000000000000000000000..f910f346b52191617afdc0cebb070cb6c735d38e
# --- /dev/null
# +++ b/big_vision/requirements.txt
# @@ -0,0 +1,18 @@
# +numpy>=1.26
# +absl-py
# +git+https://github.com/google/CommonLoopUtils
# +distrax
# +editdistance
# +einops
# +flax
# +optax
# +git+https://github.com/google/flaxformer
+git+https://github.com/akolesnikoff/panopticapi.git@mute +overrides +protobuf +sentencepiece +tensorflow-cpu +tfds-nightly +tensorflow-text +tensorflow-gan +pycocoevalcap diff --git a/big_vision/run_tpu.sh b/big_vision/run_tpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..3c3da2e44e7d2829a00188f5e6177ea9d6e3ba4d --- /dev/null +++ b/big_vision/run_tpu.sh @@ -0,0 +1,35 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/bin/bash + +if [ ! -d "bv_venv" ] +then + sudo apt-get update + sudo apt install -y python3-venv + python3 -m venv bv_venv + . bv_venv/bin/activate + + pip install -U pip # Yes, really needed. + # NOTE: doesn't work when in requirements.txt -> cyclic dep + pip install "jax[tpu]>=0.4.25" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + pip install -r big_vision/requirements.txt +else + . bv_venv/bin/activate +fi + +if [ $# -ne 0 ] +then + env TFDS_DATA_DIR=$TFDS_DATA_DIR BV_JAX_INIT=1 python3 -m "$@" +fi diff --git a/big_vision/sharding.py b/big_vision/sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..be76cb3a1f6b8bc0e494515bac2a54528a53494c --- /dev/null +++ b/big_vision/sharding.py @@ -0,0 +1,197 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Big vision sharding utilities.""" + +from absl import logging + +from big_vision.pp.registry import Registry +import big_vision.utils as u +import flax.linen as nn +import jax +import numpy as np + + +NamedSharding = jax.sharding.NamedSharding +P = jax.sharding.PartitionSpec + + +def _replicated(mesh): + return NamedSharding(mesh, P()) + + +def _shard_along_axis(mesh, i, axis_name): + return NamedSharding(mesh, P(*((None,) * i + (axis_name,)))) + + +def infer_sharding(params, strategy, mesh): + """Infers `params` sharding based on strategy. + + Args: + params: a pytree of arrays. + strategy: sharding strategy. + mesh: jax device mesh. + + Returns: + A pytree with shardings, that has the same shape as the `tree` argument. + """ + patterns, tactics = zip(*strategy) + + x_with_names, tree_def = u.tree_flatten_with_names(params) + names = tree_def.unflatten(list(zip(*x_with_names))[0]) + + # Follows big_vision conventions: each variable is matched at most once, + # early patterns get matching priority. + mask_trees = u.make_mask_trees(params, patterns) + + specs = jax.tree.map(lambda x: (None,) * x.ndim, params) + + for mask_tree, tactic in zip(mask_trees, tactics): + for op_str in tactic.split("|"): + op = Registry.lookup(f"shardings.{op_str}")() + specs = jax.tree.map( + lambda x, n, match, spec, op=op: op(spec, mesh, n, x) + if match else spec, + params, names, mask_tree, specs, + is_leaf=lambda v: isinstance(v, nn.Partitioned)) + + # Two-level tree_map to prevent it from doing traversal inside the spec. 
+ specs = jax.tree.map(lambda _, spec: P(*spec), nn.unbox(params), specs) + return jax.tree.map(lambda spec: NamedSharding(mesh, spec), specs) + + +# Sharding rules +# +# Each rule needs to be added to the registry, can accept custom args, and +# returns a function that updates the current spec. The arguments are: +# 1. Variable name +# 2. Variable itself (or placeholder with .shape and .dtype properties) +# 3. The current sharing spec. + + +@Registry.register("shardings.replicate") +def replicate(): + """Full replication sharding rule. + + Note full replication is deafult, so this can be skipped and useful to + explicitly state in the config that certrain parameters are replicated. + TODO: can be generalized to support replication over a sub-mesh. + + Returns: + A function that updates the sharding spec. + """ + def _update_spec(cur_spec, mesh, name, x): + del x, mesh + if not all(axis is None for axis in cur_spec): + raise ValueError(f"Inconsistent sharding instructions: " + f"parameter {name} has spec {cur_spec}, " + f"so it can't be fully replicated.") + return cur_spec + return _update_spec + + +@Registry.register("shardings.fsdp") +def fsdp(axis, min_size_to_shard_mb=4): + """FSDP sharding rule. + + Shards the largest dimension that is not sharded already and is divisible + by the total device count. + + Args: + axis: mesh axis name for FSDP, or a collection of names. + min_size_to_shard_mb: minimal tensor size to bother with sharding. + + Returns: + A function that updates the sharding spec. + """ + axis = axis if isinstance(axis, str) else tuple(axis) + axis_tuple = axis if isinstance(axis, tuple) else (axis,) + def _update_spec(cur_spec, mesh, name, x): + shape = x.shape + axis_size = np.prod([mesh.shape[a] for a in axis_tuple]) + + if np.prod(shape) * x.dtype.itemsize <= min_size_to_shard_mb * (2 ** 20): + return cur_spec + + # Partition along largest axis that is divisible and not taken. 
+ idx = np.argsort(shape)[::-1] + for i in idx: + if shape[i] % axis_size == 0: + if cur_spec[i] is None: + return cur_spec[:i] + (axis,) + cur_spec[i+1:] + + logging.info("Failed to apply `fsdp` rule to the parameter %s:%s, as all " + "its dimensions are not divisible by the requested axis: " + "%s:%i, or already occupied by other sharding rules: %s", + name, shape, axis, axis_size, cur_spec) + return cur_spec + return _update_spec + + +@Registry.register("shardings.logical_partitioning") +def logical_partitioning(): + """Manual sharding based on Flax's logical partitioning annotations. + + Uses logical sharding annotations added in model code with + `nn.with_logical_partitioning`. Respects logical to mesh name mapping rules + (typically defined in the dynamic context using + `with nn.logical_axis_rules(rules): ...`). + + Returns: + A function that outputs the sharding spec of `nn.LogicallyPartitioned` boxed + specs. + """ + def _update_spec(cur_spec, mesh, name, x): + del x, name, mesh + if isinstance(cur_spec, nn.LogicallyPartitioned): + return nn.logical_to_mesh_axes(cur_spec.names) + return cur_spec + return _update_spec + + +@Registry.register("shardings.shard_dim") +def shard_dim(axis, dim, ignore_ndim_error=False): + """Shards the given dimension along the given axis. + + Args: + axis: mesh axis name for sharding. + dim: dimension to shard (can be negative). + ignore_ndim_error: if True, a warning error is logged instead of raising an + exception when the given dimension is not compatible with the number of + dimensions of the array. + + Returns: + A function that updates the sharding spec. 
+ """ + def _update_spec(cur_spec, mesh, name, x): + del mesh, x + # Valid dims are -ndim <= dim < ndim; `abs(dim) >= ndim` wrongly rejected + # dim == -ndim, which the normalization below maps to index 0. + if not -len(cur_spec) <= dim < len(cur_spec): + msg = f"Cannot shard_dim({axis}, {dim}): name={name} cur_spec={cur_spec}" + if ignore_ndim_error: + logging.warning(msg) + return cur_spec + else: + raise ValueError(msg) + pos_dim = dim + if pos_dim < 0: + pos_dim += len(cur_spec) + if cur_spec[pos_dim] is not None: + raise ValueError( + f"Already sharded: shard_dim({axis}, {dim}):" + f" name={name} cur_spec={cur_spec}" + ) + new_spec = cur_spec[:pos_dim] + (axis,) + cur_spec[pos_dim + 1 :] + return new_spec + + return _update_spec diff --git a/big_vision/tools/download_tfds_datasets.py b/big_vision/tools/download_tfds_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..b64c33d51a7cb8df063d15e828e30b07007ff6b0 --- /dev/null +++ b/big_vision/tools/download_tfds_datasets.py @@ -0,0 +1,44 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Download and prepare TFDS datasets for the big_vision codebase. + +By default this python script covers cifar10, cifar100, oxford_iiit_pet, +oxford_flowers102 and imagenet_v2.
+ +If you want to integrate other public or custom datasets, please follow: +https://www.tensorflow.org/datasets/catalog/overview +""" + +from absl import app +import tensorflow_datasets as tfds + + +def main(argv): + if len(argv) > 1 and "download_tfds_datasets.py" in argv[0]: + datasets = argv[1:] + else: + datasets = [ + "cifar10", + "cifar100", + "oxford_iiit_pet", + "oxford_flowers102", + "imagenet_v2", + ] + for d in datasets: + tfds.load(name=d, download=True) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/tools/eval_only.py b/big_vision/tools/eval_only.py new file mode 100644 index 0000000000000000000000000000000000000000..abdde4a6c0aa656a2e8ec76ce645982a2a6723b3 --- /dev/null +++ b/big_vision/tools/eval_only.py @@ -0,0 +1,146 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Script that loads a model and only runs evaluators.""" + +from functools import partial +import importlib + +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.evaluators.common as eval_common +import big_vision.utils as u +from clu import parameter_overview +import flax +import flax.jax_utils as flax_utils +import jax +import jax.numpy as jnp +from ml_collections import config_flags +from tensorflow.io import gfile + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() + + +def main(argv): + del argv + + config = flags.FLAGS.config + workdir = flags.FLAGS.workdir + logging.info("Workdir: %s", workdir) + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", ["ops_general", "ops_image"]): + importlib.import_module(f"big_vision.pp.{m}") + + # These functions do more stuff internally, for OSS release we mock them by + # trivial alternatives in order to minimize disruptions in the code. + xid, wid = -1, -1 + def write_note(note): + if jax.process_index() == 0: + logging.info("NOTE: %s", note) + + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + u.chrono.inform(measure=mw.measure, write_note=write_note) + + write_note(f"Initializing {config.model_name} model...") + assert config.get("model.reinit") is None, ( + "I don't think you want any part of the model to be re-initialized.") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model_kw = dict(config.get("model", {})) + if "num_classes" in config: # Make it work for regular + image_text.
+ model_kw["num_classes"] = config.num_classes + model = model_mod.Model(**model_kw) + + # We want all parameters to be created in host RAM, not on any device, they'll + # be sent there later as needed, otherwise we already encountered two + # situations where we allocate them twice. + @partial(jax.jit, backend="cpu") + def init(rng): + input_shapes = config.get("init_shapes", [(1, 224, 224, 3)]) + input_types = config.get("init_types", [jnp.float32] * len(input_shapes)) + dummy_inputs = [jnp.zeros(s, t) for s, t in zip(input_shapes, input_types)] + things = flax.core.unfreeze(model.init(rng, *dummy_inputs)) + return things.get("params", {}) + + with u.chrono.log_timing("z/secs/init"): + params_cpu = init(jax.random.PRNGKey(42)) + if jax.process_index() == 0: + parameter_overview.log_parameter_overview(params_cpu, msg="init params") + num_params = sum(p.size for p in jax.tree.leaves(params_cpu)) + mw.measure("num_params", num_params) + + # The use-case for not loading an init is testing and debugging. + if config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + params_cpu = model_mod.load( + params_cpu, config.model_init, config.get("model"), + **config.get("model_load", {})) + if jax.process_index() == 0: + parameter_overview.log_parameter_overview(params_cpu, msg="loaded params") + + write_note("Replicating...") + params_repl = flax_utils.replicate(params_cpu) + + def predict_fn(params, *a, **kw): + return model.apply({"params": params}, *a, **kw) + + evaluators = eval_common.from_config( + config, {"predict": predict_fn, "model": model}, + lambda s: write_note(f"Initializing evaluator: {s}..."), + lambda key, cfg: 1, # Ignore log_steps, always run. + ) + + # Allow running for multiple steps can be useful for couple cases: + # 1. non-deterministic evaluators + # 2. warmup when timing evaluators (eg compile cache etc). 
+ for s in range(config.get("eval_repeats", 1)): + mw.step_start(s) + for (name, evaluator, _, prefix) in evaluators: + write_note(f"{name} evaluation step {s}...") + with u.profile(name, noop=name in config.get("no_profile", [])): + with u.chrono.log_timing(f"z/secs/eval/{name}"): + for key, value in evaluator.run(params_repl): + mw.measure(f"{prefix}{key}", value) + u.sync() # sync barrier to get correct measurements + u.chrono.flush_timings() + mw.step_end() + + write_note("Done!") + mw.close() + + # Make sure all hosts stay up until the end of main. + u.sync() + + if workdir and flags.FLAGS.cleanup and jax.process_index() == 0: + # Local import: this module only imports `gfile`, so without it `tf` is + # undefined and the `except` clause below would raise NameError. + import tensorflow as tf # pylint: disable=g-import-not-at-top + gfile.rmtree(workdir) + try: # Only need this on the last work-unit, if already empty. + gfile.remove(os.path.join(workdir, "..")) + except tf.errors.OpError: + pass + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/tools/lit_demo/README.md b/big_vision/tools/lit_demo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c7c1825219944bbf8141eeea4b8bbca0bc1f6738 --- /dev/null +++ b/big_vision/tools/lit_demo/README.md @@ -0,0 +1,26 @@ +# LiT-Demo + +See https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html + +Demo originally appeared on Twitter +https://twitter.com/AndreasPSteiner/status/1514722383818543106 + +App published at +https://google-research.github.io/vision_transformer/lit + +## Build + +Install packages (tested with node v16.17.0 and yarn 1.22.19) + +```bash +yarn +``` + + +## Run + +The web app will appear on http://localhost:8000 + +``` +node build.js +``` diff --git a/big_vision/tools/lit_demo/build.js b/big_vision/tools/lit_demo/build.js new file mode 100644 index 0000000000000000000000000000000000000000..a44aa8aa7b05c7ecb5acc4f7478a175897f15b9d --- /dev/null +++ b/big_vision/tools/lit_demo/build.js @@ -0,0 +1,39 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +const sassPlugin = require('esbuild-sass-plugin').sassPlugin; + +require('esbuild').serve({ + servedir: 'src', + port: 8000, +}, { + entryPoints: ['src/app.ts'], + bundle: true, + outfile: 'src/index.js', + plugins: [ + sassPlugin({ + filter: /style.scss$/, + type: 'style' + }), + sassPlugin({ + type: 'lit-css', + }), + ], + sourcemap: true, +}).then(() => { + console.log('Serving on port 8000'); +}).catch(() => process.exit(1)); diff --git a/big_vision/tools/lit_demo/package.json b/big_vision/tools/lit_demo/package.json new file mode 100644 index 0000000000000000000000000000000000000000..838cf1ff349f7496fc8edd43aa207b2c00174d20 --- /dev/null +++ b/big_vision/tools/lit_demo/package.json @@ -0,0 +1,54 @@ +{ + "name": "lit-demo", + "version": "0.0.2", + "description": "", + "main": "src/app.ts", + "license": "Apache-2.0", + "private": true, + "engines": { + "node": ">=8.9.0" + }, + "scripts": { + "serve": "node build.js", + "test": "ts-node --skip-ignore --project tsconfig.test.json run_tests.ts" + }, + "devDependencies": { + "@babel/core": "^7.7.5", + "@babel/plugin-transform-runtime": "^7.7.6", + "@babel/polyfill": "^7.10.4", + "@babel/preset-env": "^7.7.6", + "@tensorflow/tfjs-backend-cpu": "^3.15.0", + "@tensorflow/tfjs-backend-webgl": "^3.15.0", + "@tensorflow/tfjs-converter": "3.20.0", + "@tensorflow/tfjs-core": "3.20.0", + "babel-preset-env": "^1.7.0", + "esbuild": "^0.15.5", + "esbuild-sass-plugin": "^2.3.2", + "jasmine": "^3.3.1", + "lit": "^2.3.1", + "naughty-words": "^1.2.0", + 
"sass": "^1.50.0", + "ts-node": "~5.0.0", + "typescript": "4.1.3" + }, + "resolutions": { + "is-svg": "4.3.1" + }, + "eslintConfig": { + "extends": "google", + "rules": { + "require-jsdoc": 0, + "valid-jsdoc": 0 + }, + "env": { + "es6": true + }, + "parserOptions": { + "ecmaVersion": 8, + "sourceType": "module" + } + }, + "eslintIgnore": [ + "dist/" + ] +} diff --git a/big_vision/tools/lit_demo/src/app.ts b/big_vision/tools/lit_demo/src/app.ts new file mode 100644 index 0000000000000000000000000000000000000000..3fccbc940cd173826dad6c6d25fd11a2c177ce16 --- /dev/null +++ b/big_vision/tools/lit_demo/src/app.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {LitDemoApp} from './components/lit-demo-app'; +import './style.scss'; + +// tslint:disable-next-line:no-any +(window as any).LitDemoApp = LitDemoApp; diff --git a/big_vision/tools/lit_demo/src/components/image-carousel.scss b/big_vision/tools/lit_demo/src/components/image-carousel.scss new file mode 100644 index 0000000000000000000000000000000000000000..2da94515d9d3627611c4f3d2b6daf6762916e42d --- /dev/null +++ b/big_vision/tools/lit_demo/src/components/image-carousel.scss @@ -0,0 +1,32 @@ +@import '../style/mixins'; + +.selector { + overflow: scroll; + padding-bottom: 10px; // OS X scroll bar + + .inner { + white-space: nowrap; + + .thumb { + display: inline-block; + + img { + cursor: pointer; + + width: 20vmin; + height: 20vmin; + max-width: 200px; + max-height: 200px; + + @include phone-portrait { + width: 33vmin; + height: 33vmin; + } + + margin: 10px; + + box-shadow: 0 0 10px #888; + } + } + } +} diff --git a/big_vision/tools/lit_demo/src/components/image-carousel.ts b/big_vision/tools/lit_demo/src/components/image-carousel.ts new file mode 100644 index 0000000000000000000000000000000000000000..b392d5b95a3a048deb370fa68cac460bf62f9be2 --- /dev/null +++ b/big_vision/tools/lit_demo/src/components/image-carousel.ts @@ -0,0 +1,70 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview Carousel of images. 
+ */ + +import {html, LitElement} from 'lit'; + +import {app} from '../lit_demo/app'; +import {getImageUrl} from '../lit_demo/constants'; +import {ImageRow} from '../lit_demo/data'; + +import {customElement} from 'lit/decorators.js'; +import styles from './image-carousel.scss'; + +/** + * Shows multiple images in a horizontal carousel. + * + * Dispatches `'image-select'` event when an image is clicked/tapped. + */ +@customElement('image-carousel') +export class ImageCarousel extends LitElement { + static override styles = [styles]; + + onClick(id: string) { + const event = + new CustomEvent('image-select', {composed: true, detail: {id}}); + this.dispatchEvent(event); + } + + override render() { + const images = app.imageData.rows.map( + (row: ImageRow) => html` +
+ { + this.onClick(row.id); + }} data-id=${row.id} src="${getImageUrl(row.id)}"> +
+ `); + return html` +
+
+ ${images} +
+
+

Select an image 👆 to get started.

+ `; + } +} + +declare global { + interface HTMLElementTagNameMap { + 'image-carousel': ImageCarousel; + } +} diff --git a/big_vision/tools/lit_demo/src/components/image-prompts.scss b/big_vision/tools/lit_demo/src/components/image-prompts.scss new file mode 100644 index 0000000000000000000000000000000000000000..66cd06817ee5d2f5e4af4c392052d357bc4e3456 --- /dev/null +++ b/big_vision/tools/lit_demo/src/components/image-prompts.scss @@ -0,0 +1,124 @@ +@import '../style/mixins'; + +.image-prompt { + display: flex; + gap: 1.5em; + align-items: flex-start; + margin-top: 2rem; + + @include phone-portrait { + align-items: center; + flex-direction: column; + gap: 0; + margin-bottom: 5rem; + } + + .left { + display: flex; + flex-direction: column; + + .wrapper { + position: relative; + + .src { + position: absolute; + right: 2rem; + bottom: 2rem; + color: white; + font-size: 1.5rem; + text-shadow: 2px 2px black; + text-decoration: none; + } + } + + .animation { + position: relative; + width: 224px; + height: 15px; + opacity: 0; + + .computing { + text-align: center; + } + } + } + + .right { + display: flex; + flex-grow: 1; + flex-direction: column; + gap: 0.5em; + + .top { + text-align: right; + height: 30px; + } + + .buttons { + display: flex; + flex-wrap: wrap; + justify-content: flex-end; + gap: 1em; + align-items: center; + } + + .item { + position: relative; + display: flex; + + .pct { + display: inline-block; + margin-right: 1em; + width: 3.5em; + text-align: right; + opacity: 0; + transition: opacity 0.5s; + } + + input { + flex-grow: 1; + max-width: 70vw; + border-radius: 0; + background: transparent; + border: 0; + border-bottom: 1px solid var(--text-fg); + color: var(--text-fg); + outline: none; + + &.toolong { + border-bottom: 1px solid var(--text-red); + color: var(--text-red); + } + } + + .bar { + position: absolute; + display: inline-block; + top: 5%; + left: 0; + z-index: -1; + background: var(--bar-col); + height: 90%; + width: 0; + transition: width 0.5s; + 
} + } + + .bottom { + display: flex; + flex-wrap: wrap; + justify-content: flex-end; + gap: 1em; + align-items: center; + opacity: 0; + + .tweet { + background: rgb(18, 150, 223); + color: white; + text-decoration: none; + padding: 0px 15px; + border-radius: 16px; + } + } + } +} diff --git a/big_vision/tools/lit_demo/src/components/image-prompts.ts b/big_vision/tools/lit_demo/src/components/image-prompts.ts new file mode 100644 index 0000000000000000000000000000000000000000..823f166518f5011d17d560b6b1e246060118863f --- /dev/null +++ b/big_vision/tools/lit_demo/src/components/image-prompts.ts @@ -0,0 +1,250 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview Image and text prompts. + */ + +import {html, LitElement} from 'lit'; +import * as naughtyWords from 'naughty-words'; + +import {app} from '../lit_demo/app'; +import {getBackend} from '../lit_demo/compute'; +import {getImageUrl} from '../lit_demo/constants'; +import {getUrl} from '../lit_demo/url_utils'; + +import {MessageList} from './message-list'; + +import {customElement, query} from 'lit/decorators.js'; +import styles from './image-prompts.scss'; + +const setHref = (anchorEl: HTMLAnchorElement, href:string) => { + anchorEl. href = href; +}; + +const HTML_TEMPLATE = ` + We cannot include the word "{word}" as it is found on the list + naughty-words/{lang}. 
+ We understand blocklists are an imperfect solution but we believe it's + important to ensure these models are not misused, and hope that in this + instance it does not serve to marginalise anybody. If you don't agree, + please reach out via + form link. +`; + + +/** + * Shows image and text prompts, and computes similarities. + * + * Also dispatches some events like `'duplicate'` to the parent. + */ +@customElement('image-prompts') +export class ImagePrompts extends LitElement { + + static override styles = [styles]; + + @query('message-list') + messageList!: MessageList; + @query('.animation') + animation!: HTMLElement; + @query('.bottom') + bottom!: HTMLElement; + + lastPrompts?: string[]; + + constructor(private readonly imageId: string) { + super(); + } + + override firstUpdated() { + if (getBackend() !== 'webgl') { + this.messageList.warning( + 'Please activate WebGL. Running ML demos on ' + + 'CPU will drain your battery in no time...'); + } + } + + onDuplicate() { + this.dispatchEvent(new Event('duplicate')); + } + + onRemove() { + this.remove(); + } + + onClear() { + this.shadowRoot!.querySelectorAll('.prompt').forEach((input: Element) => { + (input as HTMLInputElement).value = ''; + }); + (this.shadowRoot!.querySelector('.prompt') as HTMLInputElement).focus(); + } + + onKeyup(event: KeyboardEvent) { + if (event.key === 'Enter') { + this.onCompute(); + } + } + + async setPrompts(prompts: string[]) { + await this.updateComplete; + this.shadowRoot!.querySelectorAll('.prompt').forEach((input: Element, idx: number) => { + (input as HTMLInputElement).value = prompts[idx] || ''; + }); + } + + getPrompts(): string[] { + return [...this.shadowRoot!.querySelectorAll('.prompt')].map((input: Element) => + (input as HTMLInputElement).value + ); + } + + override render() { + const row = app.imageData.get(this.imageId); + const inputs = row.prompts.split(',').map((prompt: string, idx: number) => { + return html` +
+
+ +
+
+ `; + }); + return html` +
+
+
+ + source +
+
+
✨✨Computing✨✨
+
+
+
+ +
+ + + + +
+ ${inputs} +
+ Model: ? + tweet +
+
+
+ `; + } + + onCompute() { + if (!app.models.ready) { + this.messageList.warning('Model not ready yet.'); + return; + } + + const model = app.models.model!; + const zimgIdx = model.zimgIds!.indexOf(this.imageId); + if (zimgIdx === -1) { + this.messageList.warning('Model is missing this image embedding'); + return; + } + + const texts = this.getPrompts(); + for (const text of texts) { + for (const word of text.toLocaleLowerCase().split(/\s+/g)) { + // tslint:disable-next-line:ban-module-namespace-object-escape + for (const lang of Object.keys(naughtyWords)) { + if (lang === 'default') { + continue; + } + // tslint:disable-next-line:ban-module-namespace-object-escape + const words = (naughtyWords as {[key: string]: string[]})[lang]; + if (words.indexOf(word) !== -1) { + const msg = HTML_TEMPLATE.replace(/\{word\}/g, word).replace(/\{lang\}/g, lang); + this.messageList.warning(msg, {rawHtml: true}); + return; + } + } + } + } + + const compute = () => { + let probs: number[]|undefined; + try { + // ??? how to move into webworker (to avoid freezing UI) ? + // https://github.com/tensorflow/tfjs/issues/102 + probs = model.computeProbabilities(texts, zimgIdx); + } catch (error) { + if ((error as Error).message.toLocaleLowerCase().match(/greater than .* maximum/)) { + this.messageList.warning('Model too large for Browser!'); + return; + } + throw error; + } + this.setProbabilities(probs); + this.lastPrompts = this.getPrompts(); + this.animation.style.opacity = '0'; + }; + + this.animation.style.opacity = '1'; + this.messageList.clear(); + setTimeout(compute, 10); // Give UI some time to update. 
+ } + + setProbabilities(probs: number[]) { + const pcts = [...this.shadowRoot!.querySelectorAll('.pct')] as HTMLElement[]; + const bars = [...this.shadowRoot!.querySelectorAll('.bar')] as HTMLElement[]; + this.hideBottom(); + for(let i = 0; i < Math.max(probs.length, pcts.length, bars.length); i++) { + const prob = probs[i] || 0; + const pct = `${Math.round(prob * 1e3) / 1e1}%`; + bars[i].style.width = pct; + if (prob) { + pcts[i].innerText = pct; + pcts[i].style.opacity = '1'; + } else { + pcts[i].style.opacity = '0'; + } + } + this.updateBottom(); + } + + updateBottom() { + const tweet = this.shadowRoot!.querySelector('.tweet') as HTMLAnchorElement; + const url = getUrl(app.models.model!.name, this.imageId, this.getPrompts()); + const description = app.imageData.get(this.imageId).description; + const text = `LiT matching prompts to an image of "${description}"\n\n#lit_demo\n`; + setHref(tweet, 'https://twitter.com/intent/tweet' + + '?url=' + encodeURIComponent(url) + + '&text=' + encodeURIComponent(text)); + this.bottom.style.opacity = '1'; + const model = this.shadowRoot!.querySelector('.model') as HTMLAnchorElement; + model.innerText = app.models.model!.name; + } + + hideBottom() { + this.bottom.style.opacity = '0'; + } +} + +declare global { + interface HTMLElementTagNameMap { + 'image-prompts': ImagePrompts; + } +} diff --git a/big_vision/tools/lit_demo/src/components/lit-demo-app.scss b/big_vision/tools/lit_demo/src/components/lit-demo-app.scss new file mode 100644 index 0000000000000000000000000000000000000000..884aeb635db195350a8316c95648942be939d4ee --- /dev/null +++ b/big_vision/tools/lit_demo/src/components/lit-demo-app.scss @@ -0,0 +1,3 @@ +.loading-container { + text-align: center; +} diff --git a/big_vision/tools/lit_demo/src/components/lit-demo-app.ts b/big_vision/tools/lit_demo/src/components/lit-demo-app.ts new file mode 100644 index 0000000000000000000000000000000000000000..41f42f484c2ccb79b42858af45d1ddfd3ee304f6 --- /dev/null +++ 
b/big_vision/tools/lit_demo/src/components/lit-demo-app.ts @@ -0,0 +1,127 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview Main application. + */ + +import {html, LitElement} from 'lit'; + +import {app} from '../lit_demo/app'; +import {parseUrl, State} from '../lit_demo/url_utils'; + +import './image-carousel'; +import {ImagePrompts} from './image-prompts'; +import './loading-animation'; +import {MessageList} from './message-list'; +import {ModelControls} from './model-controls'; + +import {customElement, property, query} from 'lit/decorators.js'; +import styles from './lit-demo-app.scss'; + +/** + * Main application container. + */ +@customElement('lit-demo-app') +export class LitDemoApp extends LitElement { + + static override styles = [styles]; + + @property({type: Boolean}) + loading: boolean = true; + + @query('message-list') + messageList!: MessageList; + @query('model-controls') + modelControls!: ModelControls; + @query('#examples') + examples!: HTMLElement; + + state?: State; + lingeringWarning?: string; + + constructor() { + super(); + window.onerror = this.onglobalerror.bind(this); + this.load(); + } + + onglobalerror(message: string|Event, source: string|undefined, lineno: number|undefined) { + source = source || ''; + source = source.substring(source.lastIndexOf('/') + 1); + this.messageList.error( + `Javascript error at ${source}:${lineno}
` + + `${message}`, + {rawHtml: true}); + } + + async load() { + await app.load(); + this.loading = false; + try { + this.state = parseUrl(); + } catch (error) { + this.messageList.warning(`Could not parse URL: ${error}`); + } + } + + override updated() { + if (this.state && this.examples) { + this.modelControls.setModel(this.state.modelName); + this.addFromState(this.state); + this.state = undefined; + } + } + + override render() { + return html` + ${this.loading ? html`` : html` + + + `} + + ${this.loading ? html` +
+ +
+ ` : html` +
+
+ `} + `; + } + + onImageSelect(event: CustomEvent) { + this.addImagePrompts(event.detail.id); + } + + addFromState(state: State) { + const imagePrompts = new ImagePrompts(state.imageId); + imagePrompts.setPrompts(state.prompts); + this.examples.insertBefore(imagePrompts, this.examples.childNodes[0]); + } + + addImagePrompts(id: string): ImagePrompts { + const imagePrompts = new ImagePrompts(id); + imagePrompts.addEventListener('duplicate', (event: Event) => { + const duplicated = this.addImagePrompts(id); + duplicated.setPrompts(imagePrompts.getPrompts()); + }); + this.examples.insertBefore(imagePrompts, this.examples.childNodes[0]); + return imagePrompts; + } +} diff --git a/big_vision/tools/lit_demo/src/components/loading-animation.scss b/big_vision/tools/lit_demo/src/components/loading-animation.scss new file mode 100644 index 0000000000000000000000000000000000000000..032be7a31f2a07138f509c7c6a85063c0d25864a --- /dev/null +++ b/big_vision/tools/lit_demo/src/components/loading-animation.scss @@ -0,0 +1,65 @@ +// CC0 from https://loading.io/css/ + +@import '../style/colors'; + +.lds-ellipsis { + display: inline-block; + position: relative; + width: 80px; + height: 80px; + + div { + position: absolute; + top: 33px; + width: 13px; + height: 13px; + border-radius: 50%; + background: var(--text-fg); + animation-timing-function: cubic-bezier(0, 1, 1, 0); + } + + div:nth-child(1) { + left: 8px; + animation: lds-ellipsis1 0.6s infinite; + } + + div:nth-child(2) { + left: 8px; + animation: lds-ellipsis2 0.6s infinite; + } + + div:nth-child(3) { + left: 32px; + animation: lds-ellipsis2 0.6s infinite; + } + + div:nth-child(4) { + left: 56px; + animation: lds-ellipsis3 0.6s infinite; + } +} + +@keyframes lds-ellipsis1 { + 0% { + transform: scale(0); + } + 100% { + transform: scale(1); + } +} +@keyframes lds-ellipsis3 { + 0% { + transform: scale(1); + } + 100% { + transform: scale(0); + } +} +@keyframes lds-ellipsis2 { + 0% { + transform: translate(0, 0); + } + 100% { + 
transform: translate(24px, 0); + } +} diff --git a/big_vision/tools/lit_demo/src/components/loading-animation.ts b/big_vision/tools/lit_demo/src/components/loading-animation.ts new file mode 100644 index 0000000000000000000000000000000000000000..8af9de23b54d3b6b46ad5d9c8a7ac75bd344ac22 --- /dev/null +++ b/big_vision/tools/lit_demo/src/components/loading-animation.ts @@ -0,0 +1,51 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview Carousel of images. + */ + +import {html, LitElement} from 'lit'; + +import {customElement} from 'lit/decorators.js'; +import styles from './loading-animation.scss'; + +/** + * Shows an animated loading animation. + */ +@customElement('loading-animation') +export class LoadingAnimation extends LitElement { + + static override styles = [styles]; + + override render() { + return html` +
+
+
+
+
+
+ `; + } +} + +declare global { + interface HTMLElementTagNameMap { + 'loading-animation': LoadingAnimation; + } +} diff --git a/big_vision/tools/lit_demo/src/components/message-list.scss b/big_vision/tools/lit_demo/src/components/message-list.scss new file mode 100644 index 0000000000000000000000000000000000000000..79aa76465b2120d256c14c3a4583709a050e2be4 --- /dev/null +++ b/big_vision/tools/lit_demo/src/components/message-list.scss @@ -0,0 +1,26 @@ +@import '../style/colors'; + +.message { + padding: 0.1rem 0.5rem; + margin-bottom: 1rem; +} + +.warning { + background: var(--warn-bg); + color: var(--warn-fg); +} + +.error { + background: var(--error-bg); + color: var(--error-fg); +} + +.info { + background: var(--note-bg); + color: var(--note-fg); +} + +.close { + float: right; + cursor: pointer; +} diff --git a/big_vision/tools/lit_demo/src/components/message-list.ts b/big_vision/tools/lit_demo/src/components/message-list.ts new file mode 100644 index 0000000000000000000000000000000000000000..542bfb0f9502d55c65ff131c43684246192cb6f7 --- /dev/null +++ b/big_vision/tools/lit_demo/src/components/message-list.ts @@ -0,0 +1,97 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview A list of dismissable info/warning/error messages. 
+ */ + +import {html, LitElement} from 'lit'; + +import {unsafeHTML} from 'lit/directives/unsafe-html.js'; + +import {customElement} from 'lit/decorators.js'; +import styles from './message-list.scss'; + +enum MessageType { + INFO = 'info', + WARNING = 'warning', + ERROR = 'error', +} + +interface Message { + message: string; + type: MessageType; + rawHtml: boolean; +} + + +/** + * Shows info/warning/error messages that remain until closed by user. + */ +@customElement('message-list') +export class MessageList extends LitElement { + static override styles = [styles]; + + messages: Message[] = []; + + addMessage(message: Message) { + this.messages.push(message); + this.requestUpdate(); + } + + info(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) { + this.addMessage({message, type: MessageType.INFO, rawHtml}); + } + + warning(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) { + this.addMessage({message, type: MessageType.WARNING, rawHtml}); + } + + error(message: string, {rawHtml = false}: {rawHtml?: boolean} = {}) { + this.addMessage({message, type: MessageType.ERROR, rawHtml}); + } + + removeMessage(event: Event, idx: number) { + this.messages.splice(idx, 1); + (event.target! as HTMLElement).closest('.message')!.remove(); + } + + clear() { + this.messages = []; + while (this.firstChild) this.firstChild.remove(); + } + + override render() { + return this.messages.map( + (message: Message, idx: number) => html` +
+ ${ + message.rawHtml ? unsafeHTML(message.message) : + message.message} + { + this.removeMessage(e, idx); + }} class="close">✖ +
+ `); + } +} + +declare global { + interface HTMLElementTagNameMap { + 'message-list': MessageList; + } +} diff --git a/big_vision/tools/lit_demo/src/components/model-controls.scss b/big_vision/tools/lit_demo/src/components/model-controls.scss new file mode 100644 index 0000000000000000000000000000000000000000..a7627c8cb9d15b825cd3a84cdc217eb2b39cfc05 --- /dev/null +++ b/big_vision/tools/lit_demo/src/components/model-controls.scss @@ -0,0 +1,12 @@ +.controls { + margin: 1em 0; + display: flex; + + select { + margin-left: 0.5em; + } + + progress { + margin: 0 1em; + } +} diff --git a/big_vision/tools/lit_demo/src/components/model-controls.ts b/big_vision/tools/lit_demo/src/components/model-controls.ts new file mode 100644 index 0000000000000000000000000000000000000000..40fdb646c436ef7dfd9332835819f0e9534c067b --- /dev/null +++ b/big_vision/tools/lit_demo/src/components/model-controls.ts @@ -0,0 +1,93 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview Controls to choose model. + */ + +import {html, LitElement} from 'lit'; + +import {getModels} from '../lit_demo/constants'; +import {app} from '../lit_demo/app'; + +import {customElement, property} from 'lit/decorators.js'; +import styles from './model-controls.scss'; + +/** + * Shows controls for model selection, progress bar, and status text. 
+ */ +@customElement('model-controls') +export class ModelControls extends LitElement { + + static override styles = [styles]; + + @property({attribute: false}) + progress: number = 0; + + @property({attribute: false}) + status: string = 'Initializing...'; + + constructor() { + super(); + app.models.addListener(this.onModelUpdate.bind(this)); + app.models.load(getModels()[0]); + } + + onModelUpdate(progress: number, message?: string) { + this.progress = progress; + if (message) this.status = message; + } + + onModelChange(event: Event) { + const target = event.target as HTMLSelectElement; + const name = target.value; + app.models.load(name).catch((error) => { + this.status = `ERROR loading model "${name}": ${error}`; + }); + } + + async setModel(model: string) { + if (getModels().indexOf(model) === -1) { + throw new Error(`Model "${model}" not found!`); + } + await this.updateComplete; + const dropdown = this.shadowRoot!.querySelector('#model_dropdown') as HTMLSelectElement; + dropdown.value = model; + dropdown.dispatchEvent(new Event('change')); + } + + override render() { + const options = getModels().map((model: string) => + html``); + return html` +
+ + + +
${this.status}
+
+ `; + } +} + +declare global { + interface HTMLElementTagNameMap { + 'model-controls': ModelControls; + } +} diff --git a/big_vision/tools/lit_demo/src/exports.ts b/big_vision/tools/lit_demo/src/exports.ts new file mode 100644 index 0000000000000000000000000000000000000000..b756e172aa6f89692faf1ba295e38ea14c59ccb0 --- /dev/null +++ b/big_vision/tools/lit_demo/src/exports.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview some useful exports to play around with the models & + * tokenizers. + * + * Simple usage (see ./playground.html for more complete usage example): + * + * model = lit.Model('tiny'); + * model.load(progress => console.log('loading...', progress)); + * console.log(model.computeProbabilities(['a dog', 'a cat'], '0')); + */ + +import {Model} from './lit_demo/compute'; +import {getImageUrl, setBaseUrl} from './lit_demo/constants'; +import {ImageData} from './lit_demo/data'; +import * as tf from '@tensorflow/tfjs-core'; + +// tslint:disable-next-line:no-any Export symbols into global namespace. +(window as any).lit = { Model, getImageUrl, ImageData, setBaseUrl }; +// tslint:disable-next-line:no-any Export symbols into global namespace. +// tslint:disable-next-line:ban-module-namespace-object-escape Export all of TF. 
+(window as any).tf = tf; diff --git a/big_vision/tools/lit_demo/src/index.html b/big_vision/tools/lit_demo/src/index.html new file mode 100644 index 0000000000000000000000000000000000000000..580e328f1d0efa0b8d682d7d59e20d58b75e2367 --- /dev/null +++ b/big_vision/tools/lit_demo/src/index.html @@ -0,0 +1,80 @@ + + + + + + + Lit Demo App + + + + + + + +

LiT: Zero-Shot Transfer with Locked-image Tuning

+ +

+ This page is an interactive demo of the Google AI blog post + LiT: adding language understanding to image models + – please refer to that page for a detailed explanation of how a LiT model works. + If you're interested in how this demo makes a JAX model run on device in your + browser, check out our other blog post + JAX on the Web with TensorFlow.js. +

+ +

+ Below you can choose an image from a selection and then write free-form + text prompts that are matched to the image. Once you hit return on your + keyboard or press the "compute" button, a text encoder implemented in + TensorFlow.js + will compute embeddings for the provided text on your local device, and the + similarity of these text embeddings to the image embedding will be displayed. +

+ +

+ The prompts can be used to classify an image into multiple categories, listing + each category individually with a prompt "an image of a X". But you can also + probe the model interactively with more detailed prompts, comparing the + different results when small details change in the text. +

+ +

+ Please use this demo responsibly. The models will always compare the image to + the prompts you provide, and it is therefore trivial to construct situations + where the model picks from a bunch of bad options. +

+ +

+ Note: + The models available in this interactive demo are not those from the + paper. + We had to train much smaller text towers and tokenizers to avoid + overloading your browser. Please see + our GitHub repository + for the models from the paper pre-trained on public datasets. + Multilingual models coming soon. +

+ + + + \ No newline at end of file diff --git a/big_vision/tools/lit_demo/src/lit_demo/app.ts b/big_vision/tools/lit_demo/src/lit_demo/app.ts new file mode 100644 index 0000000000000000000000000000000000000000..0e104f298033933514a575ef8ba796c70b11c4c4 --- /dev/null +++ b/big_vision/tools/lit_demo/src/lit_demo/app.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview Global app state. + */ + +import {ImageData} from './data'; +import {Models} from './compute'; + +/** + * Container class holding image data and models. + * + * The main application component would typically call `load()` and then show + * the components depending on this class asynchronously. + */ +export class App { + + imageData = new ImageData(); + models = new Models(); + + ready: boolean = false; + + async load() { + await this.imageData.load(); + this.ready = true; + } +} + +/** + * Global app state. 
+ */ +export const app = new App(); diff --git a/big_vision/tools/lit_demo/src/lit_demo/compute.ts b/big_vision/tools/lit_demo/src/lit_demo/compute.ts new file mode 100644 index 0000000000000000000000000000000000000000..06b9c7eb72a4e565047fc6966e56cc67465c6145 --- /dev/null +++ b/big_vision/tools/lit_demo/src/lit_demo/compute.ts @@ -0,0 +1,293 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview Model code. + */ + +import '@tensorflow/tfjs-backend-webgl'; + +import * as tfconv from '@tensorflow/tfjs-converter'; +import * as tf from '@tensorflow/tfjs-core'; +import {MathBackendWebGL} from '@tensorflow/tfjs-backend-webgl'; + +import {getTokenizer, Tokenizer} from '../tokenizers/index'; + +import {getModelFileUrl} from './constants'; + +/** + * Callback to be updated with model load status. + * + * @param progress: the callback function is repeatedly called with values from + * 0 to 1 (both endpoints are guaranteed to be included) + * @param message: optional message to be displayed to user + */ +export type StatusCallback = (progress: number, message?: string) => void; + +const canonicalize = (s: string) => { + s = s.toLocaleLowerCase(); + s = s.replace(/[^\w ]/g, ''); + s = s.replace(/\s+/g, ' '); + return s.trim(); +}; + +/** + * The model definition is read from a JSON and specifies model details. 
+ */ +// tslint:disable:enforce-name-casing +export interface ModelDefinition { + /** Human-readable description of the model. */ + description: string; + /** Tokenizer type. See ./tokenizers/index */ + tokenizer_type: string; + /** Temperature for computing softmax. */ + temperature: number; + /** Token used for padding. */ + pad_value: number; + /** Maximum token length. */ + max_len: number; + /** Dimensionality of image/text embeddings. */ + embedding_size: number; +} +// tslint:enable:enforce-name-casing + +/** + * TFJS model to compute text embeddings and similarities. + */ +export class Model { + def?: ModelDefinition; + tokenizer?: Tokenizer; + model?: tfconv.GraphModel; + /** Pre-computed image embeddings. */ + zimgs?: tf.Tensor; + /** Pre-computed text embeddings. */ + ztxts?: tf.Tensor; + /** IDs for pre-computed image embeddings. */ + zimgIds?: string[]; + /** Prompts for pre-computed text embeddings. */ + ztxtPrompts?: string[]; + /** Will be set to `true` when `load()` has completed successfully. */ + ready: boolean = false; + + /** + * @param name: Name of the model to be loaded. Will be used to construct the + * model URL. Note that the model must be loaded via calling `load()` before + * it can be used. + */ + constructor(public name: string) { + } + + /** + * Loads model, tokenizer, and pre-computed embeddings. 
+ */ + async load(callback?: StatusCallback) { + this.def = + await fetch(getModelFileUrl(this.name, 'def.json')).then(resp => { + if (resp.ok) return resp.json(); + throw new Error(`Could not load model def: ${resp.status}`); + }); + console.log('def', this.def); + + const tokenizer = fetch(getModelFileUrl(this.name, 'vocabulary.json')) + .then(resp => resp.json()) + .then( + vocabulary => getTokenizer( + this.def!.tokenizer_type, vocabulary)); + + const model = + tfconv.loadGraphModel(getModelFileUrl(this.name, 'tfjs/model.json'), { + onProgress: (progress: number) => { + callback && callback(progress); + } + }); + + const fetchBin = async (name: string) => { + const response = await fetch(getModelFileUrl(this.name, `${name}.bin`)); + const blob = await response.blob(); + const data = await new Promise(resolve => { + const reader = new FileReader(); + reader.addEventListener('loadend', () => { + resolve(reader.result); + }); + reader.readAsArrayBuffer(blob); + }); + const arr = new Float32Array(data as Iterable); + const n = arr.length / this.def!.embedding_size; + return tf.tensor(arr, [n, this.def!.embedding_size]); + }; + const fetchTxt = (name: string) => + fetch(getModelFileUrl(this.name, `${name}.txt`)) + .then(response => response.text()) + .then(text => text.split(/\n/g)); + + [this.tokenizer, + this.model, + this.zimgs, + this.ztxts, + this.zimgIds, + this.ztxtPrompts, + ] = + [ + await tokenizer, + await model, + await fetchBin('zimgs'), + await fetchBin('ztxts'), + await fetchTxt('zimgs'), + await fetchTxt('ztxts'), + ]; + this.ready = true; + await this.warmup(); + if (callback) callback(1, 'Done.'); + } + + private async warmup() { + if (getBackend() !== 'webgl') return; + + const webGLBackend = tf.backend() as MathBackendWebGL; + tf.env().set('ENGINE_COMPILE_ONLY', true); + const tokens = tf.zeros([5, this.def!.max_len], 'int32'); + const preCompileResults = + this.model!.predict({inputs: tokens}) as tf.Tensor; + 
webGLBackend.checkCompileCompletion(); + webGLBackend.getUniformLocations(); + + tf.env().set('ENGINE_COMPILE_ONLY', false); + const warmUpResults = this.model!.predict({inputs: tokens}) as tf.Tensor; + await warmUpResults.data(); + + preCompileResults.dispose(); + warmUpResults.dispose(); + } + + /** + * Tokenizes strings with the model's tokenizer; each prompt is truncated/padded to `max_len` (default 16). + */ + tokenize(texts: string[]): tf.Tensor { + if (!this.ready) throw new Error('Cannot tokenize: not ready'); + const tokenize = (text: string) => { + const maxLen = this.def!.max_len || 16; + const tokens = this.tokenizer!.encode(text).slice(0, maxLen); + // eos="sticky"; pad with `pad_value` up to maxLen (not a fixed 16) so every row has the same length + const tokenEos = tf.tensor( + [ + ...tokens, + ...new Array(maxLen - tokens.length).fill(this.def!.pad_value), + ], + undefined, 'int32'); + return tokenEos; + }; + return tf.stack(texts.map(tokenize)); + } + + /** + * Computes embeddings for text tokenized via `tokenize()`. + */ + embed(tokens: tf.Tensor): tf.Tensor { + if (!this.ready) throw new Error('Cannot embed: not ready'); + return this.model!.execute({inputs: tokens}) as tf.Tensor; + } + + /** + * Computes similarities between specified prompts and images. Images are + * referenced by their row index in `zimgs`. + */ + computeSimilarities(texts: string[], imgidxs: number[]) { + if (!this.ready) throw new Error('Cannot compute similarities: not ready'); + texts = texts.map(canonicalize); + const precomputed = + texts + .map(text => { + const idx = this.ztxtPrompts!.indexOf(text); + return idx === -1 ? null : tf.slice(this.ztxts!, idx, 1); + }) + .filter((x: tf.Tensor|null) => !!x) as tf.Tensor[]; + console.log(texts.length, 'texts, ', precomputed.length, 'precomputed'); + const textEmbeddings = texts.length === precomputed.length ? 
+ tf.concat(precomputed) : + this.embed(this.tokenize(texts)); + const imageEmbeddingsTransposed = tf.transpose( + tf.concat(imgidxs.map(idx => tf.slice(this.zimgs!, idx, 1)))); + const sims = tf.matMul(textEmbeddings, imageEmbeddingsTransposed); + sims.print(); + return sims; + } + + /** + * Computes probabilities between a set of prompts and a single image + * (identified by its row index). + */ + computeProbabilities(texts: string[], imgidx: number): number[] { + const sims = this.computeSimilarities(texts, [imgidx]); + const row = tf.squeeze(tf.slice(tf.transpose(sims), 0, 1)); + return [...tf.softmax(tf.mul(this.def!.temperature, row)).dataSync()]; + } +} + +/** + * Container that holds a set of models. + */ +export class Models { + private readonly map = new Map(); + private readonly listeners = new Set(); + model?: Model; + + /** + * Adds a listener to be updated about individual models' loading progress. + */ + addListener(callback: StatusCallback) { + this.listeners.add(callback); + } + + /** + * Updates all listeners with `progress` and `message`. + */ + onUpdate(progress: number, message?: string) { + if (progress === 1) { + message = `Loaded model "${this.model?.name}".`; + } + for (const callback of this.listeners) { + callback(progress, message); + } + } + + /** + * Loads model and sets `model` attribute when ready. + */ + async load(name: string) { + if (this.map.has(name)) { + this.model = this.map.get(name); + this.onUpdate(1, `Loaded "${name}".`); + return; + } + this.onUpdate(0, 'Loading...'); + this.model = new Model(name); + await this.model.load(this.onUpdate.bind(this)); + this.map.set(name, this.model); + } + + /** + * Whether model referenced by `model` attribute is ready. + */ + get ready(): boolean { + return !!this.model?.ready; + } +} + +/** Returns backend, such as "cpu" or "webgl". 
*/ +export function getBackend(): string { + return tf.getBackend(); +} diff --git a/big_vision/tools/lit_demo/src/lit_demo/constants.ts b/big_vision/tools/lit_demo/src/lit_demo/constants.ts new file mode 100644 index 0000000000000000000000000000000000000000..45dab49ca5699a2897b98c0a412c838d8d8ba0e7 --- /dev/null +++ b/big_vision/tools/lit_demo/src/lit_demo/constants.ts @@ -0,0 +1,50 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview Project-wide constants. + */ + +// Can be overwritten with setBaseUrl() below. +// let baseUrl = 'https://google-research.github.io/vision_transformer/lit'; +let baseUrl = 'https://figur.li/jax2tfjs'; +// Can be overwritten with setModels() below. +let models = ['tiny', 'small']; + +/** Sets a new base URL, on which all other URLs are built. */ +export const setBaseUrl = (newBaseUrl: string) => { + baseUrl = newBaseUrl; +}; + +/** Retrieves URL for a model-specific file (vocabulary, embeddings, ...). */ +export const getModelFileUrl = (name: string, relativePath: string) => ( + `${baseUrl}/data/models/${name}/${relativePath}` +); + +/** Retrieves the URL for images information JSON file. */ +export const getImagesInfoUrl = () => `${baseUrl}/data/images/info.json`; + +/** Retrieves the URL for an image. */ +export const getImageUrl = (id: string) => `${baseUrl}/data/images/${id}.jpg`; + +/** Returns names of available models. 
*/ +export const getModels = () => models; + +/** Sets names of available models. */ +export const setModels = (newModels: string[]) => { + models = newModels; +}; diff --git a/big_vision/tools/lit_demo/src/lit_demo/data.ts b/big_vision/tools/lit_demo/src/lit_demo/data.ts new file mode 100644 index 0000000000000000000000000000000000000000..6322d7dcefe22147d5715ee4710d2d7f2daf951a --- /dev/null +++ b/big_vision/tools/lit_demo/src/lit_demo/data.ts @@ -0,0 +1,76 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview Accessing additional data. + */ + +import {getImagesInfoUrl} from './constants'; + +/** + * Information about a single image. + */ +export interface ImageRow { + /** Stable ID of the image. */ + id: string; + /** Set of example prompts for this image. */ + prompts: string; + /** License of the image. */ + license: string; + /** Where the image was originally downloaded from. */ + source: string; + /** Short description of image. */ + description: string; +} +/** + * Contains information about all images. + */ +export class ImageData { + + rows: ImageRow[] = []; + /** Will be set to `true` when `load()` finishes. */ + ready = false; + + /** + * Gets an image by ID. Throws an error if image is not found, data is not + * loaded, or ID is not unique. 
+ */ + get(id: string): ImageRow { + if (!this.ready) { + throw new Error('ImageData not loaded!'); + } + const matching = this.rows.filter(row => row.id === id); + if (matching.length !== 1) { + throw new Error(`Got unexpected ${matching.length} matches for id="${id}"`); + } + return matching[0]; + } + + /** + * Loads image data asynchronously. + */ + async load() { + this.rows = ( + await fetch(getImagesInfoUrl()) + .then(response => { + console.log('response', response); + return response.json(); + }) + ); + this.ready = true; + } +} diff --git a/big_vision/tools/lit_demo/src/lit_demo/url_utils.ts b/big_vision/tools/lit_demo/src/lit_demo/url_utils.ts new file mode 100644 index 0000000000000000000000000000000000000000..d169bc4711d44fbeb146a85de55d47a882fe695b --- /dev/null +++ b/big_vision/tools/lit_demo/src/lit_demo/url_utils.ts @@ -0,0 +1,92 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @fileoverview (De)serialize state from/to URL. + */ + +// Should be updated whenever URLs are not compatible anymore +// (e.g. 
adding new images) +export const VERSION = 'v2'; +// version history: +// v1 used row number instead of image id + +const V1_IMAGE_IDS = [ + '1', '48', '43', '22', '2', '3', '4', '5', '6', '7', '8', '9', + '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', + '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', + '35', '36', '37', '38', '39', '40', '41', '42', '44', '45', '46', '47', + '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60' +]; + +/** + * State that can be stored in the URL. + */ +export interface State { + /** Name of the model. */ + modelName: string; + /** ID of the image. */ + imageId: string; + /** List of text prompts. */ + prompts: string[]; +} + +/** + * Returns a URL for provided model/image/prompts. + */ +export const getUrl = + (modelName: string, imageId: string, prompts: string[]): string => { + let href = window.location.href; + if (href.indexOf('#') !== -1) { + href = href.substring(0, href.indexOf('#')); + } + const parts = [ + VERSION, + modelName, + imageId, + ...prompts, + ]; + return href + '#' + parts.map(encodeURIComponent).join('|'); + }; + +/** + * Parses a URL and returns a `State`, or undefined if no state is specified. + * + * Raises an exception if there was a problem with the parsing of the URL. 
+ */ +export const parseUrl = (): State|undefined => { + const hash = window.location.hash.substring(1); + if (!hash) return; + const parts = hash.split(/\|/g); + if (parts.length < 4) { + throw new Error(`Invalid URL: "${hash}"`); + } + let [version, modelName, imageId, ...texts] = parts; + if (version === VERSION) { + } else if (version === 'v1') { + const idx = Number(imageId); + if (isNaN(idx)) throw new Error(`Expected idx="${idx}" to be numerical!`); + imageId = V1_IMAGE_IDS[idx]; + } else { + throw new Error(`Incompatible version: ${version} (supported: ${VERSION})`); + } + return { + modelName, + imageId, + prompts: texts.map(decodeURIComponent), + }; +}; diff --git a/big_vision/tools/lit_demo/src/playground.html b/big_vision/tools/lit_demo/src/playground.html new file mode 100644 index 0000000000000000000000000000000000000000..112d69eb294ea364ea6b843a491ddb78fb1ff343 --- /dev/null +++ b/big_vision/tools/lit_demo/src/playground.html @@ -0,0 +1,92 @@ + + + + + + +

+ A simple demonstration how to use LiT models in a JS application using global exports. + See source code of this file for API usage. +

+ +

+    
+
+
+
+ + + + diff --git a/big_vision/tools/lit_demo/src/style.scss b/big_vision/tools/lit_demo/src/style.scss new file mode 100644 index 0000000000000000000000000000000000000000..7fab695cb992dd0552068f92cad28a7c6546eb07 --- /dev/null +++ b/big_vision/tools/lit_demo/src/style.scss @@ -0,0 +1,80 @@ +// General styles for the page. + +@import './style/colors'; +@import './style/mixins'; + +html { + font-size: 14px; + line-height: 1.6em; + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, + Ubuntu, Cantarell, 'Fira Sans', 'Droid Sans', 'Helvetica Neue', Arial, + sans-serif; + text-size-adjust: 100%; + -ms-text-size-adjust: 100%; + -webkit-text-size-adjust: 100%; + + @media (min-width: 1200px) { + width: 1024px; + margin: 0 auto; + } + @media (min-width: 768px) { + font-size: 16px; + } + + color: var(--text-fg); + background: var(--text-bg); + + body { + margin: 0; + padding: 0rem 1rem 10rem; + } +} + +a, +a:visited { + color: var(--link-col); +} + +h1 { + font-weight: 700; + font-size: 2rem; + line-height: 1.3em; +} + +p { + font-size: 1.06rem; + line-height: 1.3em; +} + +input { + font-size: 1rem; + + &::placeholder { + color: var(--placeholder-col); + } +} + +.note { + font-style: normal; + border: none; + border-radius: 2px; + margin-left: auto; + margin-right: auto; + + padding: 0.5rem 0.5rem 0.5rem 2rem; + width: 90%; + + @include phone-portrait { + width: 100%; + padding: 0.5rem; + box-sizing: border-box; + } + + background-color: var(--note-bg); + color: var(--note-fg); + + &.warning { + background-color: var(--warn-bg); + color: var(--warn-fg); + } +} diff --git a/big_vision/tools/lit_demo/src/style/colors.scss b/big_vision/tools/lit_demo/src/style/colors.scss new file mode 100644 index 0000000000000000000000000000000000000000..8c03eefa93fed513fee7cf58660f7ee7fb8bf6a4 --- /dev/null +++ b/big_vision/tools/lit_demo/src/style/colors.scss @@ -0,0 +1,35 @@ +// Dark and light mode colors. 
+ +:root { + --text-bg: hsl(0, 0%, 97%); + --gray-border: hsla(0, 0%, 0%, 0.1); + --gray: rgba(0, 0, 0, 0.6); + --border-radius: 5px; + --orange: hsl(24, 100%, 50%); + --distill-blue: hsl(200, 50%, 25%); + --blue: #337699; + --green: #3db867; + --text-fg: rgb(15, 15, 15); + --text-red: rgb(220, 0, 0); + --bar-col: rgb(171, 199, 227); + --link-col: rgb(0, 0, 238); + --placeholder-col: rgb(166, 166, 166); + --note-bg: #e1f5fe; + --note-fg: #1a6ebb; + --warn-bg: #ffe1aa; + --warn-fg: #a16800; + --error-bg: #850000; + --error-fg: white; + + @media (prefers-color-scheme: dark) { + --text-bg: rgb(56, 56, 56); + --text-fg: rgb(213, 213, 213); + --bar-col: rgb(20, 109, 163); + --link-col: rgb(66, 165, 245); + + --note-fg: rgb(121 157 190); + --note-bg: rgb(2 59 85); + --warn-bg: #784e00; + --warn-fg: #edbe68; + } +} diff --git a/big_vision/tools/lit_demo/src/style/mixins.scss b/big_vision/tools/lit_demo/src/style/mixins.scss new file mode 100644 index 0000000000000000000000000000000000000000..8e4ad3fd9e0bdd24b1ceb08dc55551a05e37f343 --- /dev/null +++ b/big_vision/tools/lit_demo/src/style/mixins.scss @@ -0,0 +1,8 @@ +// Useful mixins. + +// To wrap styles that should only trigger for phones in portrait mode. +@mixin phone-portrait { + @media only screen and (max-device-width: 800px) and (orientation: portrait) { + @content; + } +} diff --git a/big_vision/tools/lit_demo/src/tokenizers/common.ts b/big_vision/tools/lit_demo/src/tokenizers/common.ts new file mode 100644 index 0000000000000000000000000000000000000000..dd601bf48f9cc66327d574efc0e189b2228c8330 --- /dev/null +++ b/big_vision/tools/lit_demo/src/tokenizers/common.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright Big Vision Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
/**
 * @fileoverview Utility code shared between tokenizers.
 */

/**
 * A vocabulary is an ordered list of `[token, value]` pairs; a token's
 * position in the list is its id. The tokenizers in this directory use the
 * numerical value as the score of the piece when choosing a tokenization.
 */
export type Vocabulary = Array<[string, number]>;

/**
 * Converts a string to a sequence of token ids.
 */
export interface Tokenizer {
  encode(input: string): number[];
}

/**
 * Constructor signature shared by all `Tokenizer` implementations.
 */
export interface TokenizerConstructor {
  new (vocabulary: Vocabulary): Tokenizer;
}

/**
 * Splits `input` into Unicode code points, so that a character outside the
 * Basic Multilingual Plane stays a single element instead of being split
 * into its two surrogate halves.
 */
export const stringToChars = (input: string): string[] => Array.from(input);

/**
 * Special separator character used to delimit sub-word tokens.
 */
export const TOKEN_SEPARATOR =
    '\u2581';  // This is the unicode character 'lower one eighth block'.
/**
 * @fileoverview Tokenizers and tokenizer mappings.
 */

import {Tokenizer, TokenizerConstructor, Vocabulary} from './common';
import * as sentencepieceBpe from './sentencepiece_bpe';
import * as sentencepieceUnigram from './sentencepiece_unigram';

export {Tokenizer, Vocabulary} from './common';

/**
 * Registry of tokenizer implementations, keyed by algorithm name.
 *
 * The type arguments are spelled out on purpose: the two classes are
 * unrelated constructor types, so the value type has to be named as their
 * shared `TokenizerConstructor` interface rather than inferred.
 */
const TOKENIZERS = new Map<string, TokenizerConstructor>([
  ['BPE', sentencepieceBpe.Tokenizer],
  ['UNIGRAM', sentencepieceUnigram.Tokenizer],
]);

/**
 * Returns a tokenizer of type `name` using `vocabulary`.
 *
 * @param name Key into the registry above (`'BPE'` or `'UNIGRAM'`).
 * @param vocabulary Vocabulary handed to the tokenizer's constructor.
 * @throws {Error} If no tokenizer is registered under `name`.
 */
export const getTokenizer = (name: string, vocabulary: Vocabulary): Tokenizer => {
  const ctor = TOKENIZERS.get(name);
  if (!ctor) throw new Error(`Unknown tokenizer: ${name}`);
  return new ctor(vocabulary);
};
import {stringToChars, TOKEN_SEPARATOR, Vocabulary, Tokenizer as TokenizerInterface} from './common';

/** A vocabulary piece obtained by fusing `pieces[pos]` and `pieces[pos + 1]`. */
interface Candidate {
  piece: string;
  pos: number;
  score: number;
}

/**
 * NFKC-normalizes `str`, replaces every space with `TOKEN_SEPARATOR` and
 * prepends one more separator. An empty (normalized) string is returned
 * unchanged.
 */
function processInput(str: string): string {
  const normalized = str.normalize('NFKC');
  return normalized.length > 0 ?
      TOKEN_SEPARATOR + normalized.replace(/ /g, TOKEN_SEPARATOR) :
      normalized;
}

/**
 * Sentencepiece tokenizer implementing the BPE algorithm.
 */
export class Tokenizer implements TokenizerInterface {

  // piece -> [score, index]
  private readonly map: Map<string, [number, number]>;

  /**
   * @param vocabulary List of `[piece, score]`; the position of a piece in
   *     the list is the id returned by `encode()`.
   * @throws {Error} If a piece occurs more than once.
   */
  constructor(vocabulary: Vocabulary) {
    this.map = new Map<string, [number, number]>();
    vocabulary.forEach(([piece, score], idx) => {
      if (this.map.has(piece)) {
        throw new Error(`Piece "${piece}" occurs multiple times in vocabulary`);
      }
      this.map.set(piece, [score, idx]);
    });
  }

  /**
   * Encodes `input` as vocabulary ids.
   *
   * Starts from single code points and repeatedly fuses the adjacent pair
   * whose concatenation is the highest-scoring vocabulary piece, until no
   * adjacent pair forms a known piece.
   */
  encode(input: string): number[] {
    let pieces: string[] = stringToChars(processInput(input));

    while (true) {
      // Single pass for the best merge. The strict `>` keeps the leftmost
      // candidate among equal scores, which is what sorting the candidates
      // by descending score with a stable sort would select.
      let best: Candidate|null = null;
      for (let i = 0; i < pieces.length - 1; i++) {
        const fused = pieces[i] + pieces[i + 1];
        const el = this.map.get(fused);
        if (el && (best === null || el[0] > best.score)) {
          best = {piece: fused, pos: i, score: el[0]};
        }
      }
      if (best === null) {
        break;
      }
      pieces = [
        ...pieces.slice(0, best.pos),
        best.piece,
        ...pieces.slice(best.pos + 2)
      ];
    }

    // NOTE(review): a remaining piece that is not in the vocabulary makes
    // this lookup fail with a TypeError; there is no unknown-token fallback.
    return pieces.map(piece => this.map.get(piece)![1]);
  }
}
import 'jasmine';

import * as bpe from './sentencepiece_bpe';
import {TOKEN_SEPARATOR, Vocabulary} from './common';

// Minimal vocabulary; the trailing comment on each entry is the id that
// `encode()` returns for the piece.
const vocab: Vocabulary = [
  [TOKEN_SEPARATOR, 0],  // 0
  ['a', 0],              // 1
  ['e', 0],              // 2
  ['s', 0],              // 3
  ['t', 0],              // 4
  ['te', -1],            // 5
  ['st', -2],            // 6
  ['test', -3],          // 7
  ['tes', -4],           // 8
];

describe('BPE Tokenizer', () => {
  let tokenizer: bpe.Tokenizer;
  beforeAll(() => {
    tokenizer = new bpe.Tokenizer(vocab);
  });

  it('should tokenize correctly', () => {
    expect(tokenizer.encode('a test')).toEqual([0, 1, 0, 7]);
  });

  it('should return no tokens for empty input', () => {
    expect(tokenizer.encode('')).toEqual([]);
  });

  it('should apply the highest-scoring merge first', () => {
    // 'te' (-1) is merged before 'st' (-2), then 'te' + 'st' gives 'test';
    // the lower-scoring 'tes' (-4) is never picked.
    expect(tokenizer.encode('test')).toEqual([0, 7]);
  });

  it('should reject a vocabulary with duplicate pieces', () => {
    expect(() => new bpe.Tokenizer([['a', 0], ['a', 1]]))
        .toThrowError(/occurs multiple times/);
  });
});
// Copied & adapted from
// https://github.com/tensorflow/tfjs-models/blob/master/universal-sentence-encoder/src/tokenizer/index.ts

import {TOKEN_SEPARATOR, stringToChars, Tokenizer as TokenizerInterface, Vocabulary} from './common';
import {Trie} from './trie';

/**
 * NFKC-normalizes `str`, replaces every space with `TOKEN_SEPARATOR` and
 * prepends one more separator. An empty (normalized) string is returned
 * unchanged.
 */
function processInput(str: string): string {
  const normalized = str.normalize('NFKC');
  return normalized.length > 0 ?
      TOKEN_SEPARATOR + normalized.replace(/ /g, TOKEN_SEPARATOR) :
      normalized;
}

// The first tokens are reserved for unk, control symbols, and user-defined
// symbols.
const RESERVED_SYMBOLS_COUNT = 6;

/** A lattice entry: one candidate piece ending at some position. */
interface Score {
  key: string[];  // The piece, as a list of code points.
  score: number;
  index: number;  // Vocabulary id of the piece (0 for unknown symbols).
}

/**
 * Sentencepiece tokenizer implementing the UNIGRAM algorithm.
 *
 * `Tokenizer.encode()` is a port of `EncodeAsIds` from the SentencePiece
 * library (https://github.com/google/sentencepiece). Encode uses the Viterbi
 * algorithm to find the most likely sequence of tokens that comprise the input.
 * For more details, refer to https://arxiv.org/pdf/1804.10959.pdf.
 */
export class Tokenizer implements TokenizerInterface {
  trie: Trie;

  /**
   * @param vocabulary List of `[piece, score]`; the position of a piece is
   *     its id.
   * @param reservedSymbolsCount Number of leading vocabulary entries that
   *     are not inserted into the trie and therefore never matched as text.
   */
  constructor(
      private readonly vocabulary: Vocabulary,
      private readonly reservedSymbolsCount = RESERVED_SYMBOLS_COUNT) {
    this.trie = new Trie();

    for (let i = this.reservedSymbolsCount; i < this.vocabulary.length; i++) {
      this.trie.insert(this.vocabulary[i][0], this.vocabulary[i][1], i);
    }
  }

  /**
   * Encodes `input` as vocabulary ids. A symbol that starts no known piece
   * is emitted as id 0, and runs of consecutive id 0 are collapsed into one.
   */
  encode(input: string): number[] {
    const symbols = stringToChars(processInput(input));

    // All arrays are indexed by end position (0..symbols.length):
    //   nodes[end][start]: candidate pieces covering symbols[start, end)
    //   best[end]:         best score of a segmentation of symbols[0, end)
    //   words[end]:        id of the last piece of that segmentation
    //   lengths[end]:      length of that last piece, in symbols
    const nodes: Array<{[index: number]: Score[]}> = [];
    const words: number[] = [];
    const best: number[] = [];
    const lengths: number[] = [];

    for (let i = 0; i <= symbols.length; i++) {
      nodes.push({});
      words.push(0);
      best.push(0);
      // A position that no piece ends at keeps id 0 and is walked over as a
      // single symbol; a step of at least 1 also guarantees that the
      // backward pass terminates.
      lengths.push(1);
    }

    // Construct the lattice.
    for (let i = 0; i < symbols.length; i++) {
      const matches = this.trie.commonPrefixSearch(symbols.slice(i));

      for (let j = 0; j < matches.length; j++) {
        const piece = matches[j];
        const obj = {key: piece[0], score: piece[1], index: piece[2]};

        const endPos = piece[0].length;
        if (nodes[i + endPos][i] == null) {
          nodes[i + endPos][i] = [];
        }

        nodes[i + endPos][i].push(obj);
      }
    }

    // Forward pass: keep the best-scoring piece ending at each position.
    for (let endPos = 0; endPos <= symbols.length; endPos++) {
      for (const startPos in nodes[endPos]) {
        if (!nodes[endPos].hasOwnProperty(startPos)) continue;
        const arr = nodes[endPos][startPos];

        for (let j = 0; j < arr.length; j++) {
          const word = arr[j];
          const score = word.score + best[endPos - word.key.length];

          // `best[endPos] === 0` doubles as the "nothing recorded yet" test.
          if (best[endPos] === 0 || score >= best[endPos]) {
            best[endPos] = score;
            words[endPos] = word.index;
            lengths[endPos] = word.key.length;
          }
        }
      }
    }

    const results: number[] = [];

    // Backward pass. Step by the piece length in symbols (the unit the
    // lattice was built in), not by the UTF-16 length of the vocabulary
    // string: the two differ for characters outside the BMP and for the
    // unknown token, which always covers exactly one symbol whatever the
    // spelling of vocabulary entry 0.
    let iter = words.length - 1;
    while (iter > 0) {
      results.push(words[iter]);
      iter -= lengths[iter];
    }

    // Merge consecutive unks.
    const merged: number[] = [];
    let isPreviousUnk = false;
    for (let i = 0; i < results.length; i++) {
      const id = results[i];
      if (!(isPreviousUnk && id === 0)) {
        merged.push(id);
      }

      isPreviousUnk = id === 0;
    }

    return merged.reverse();
  }
}
import {Tokenizer} from './sentencepiece_unigram';

// Stub vocabulary; a piece's position in the list is the id checked by the
// expectations below. The first six entries are reserved: the tokenizer does
// not insert them into its trie, so their spelling is never matched as text.
const stubbedTokenizerVocab = [
  ['�', 0],
  // NOTE(review): these two reserved pieces were empty strings; presumably
  // the begin/end-of-sentence control symbols -- TODO confirm upstream.
  ['<s>', 0],
  ['</s>', 0],
  ['extra_token_id_1', 0],
  ['extra_token_id_2', 0],
  ['extra_token_id_3', 0],
  ['▁', -2],
  ['▁a', -1],
  ['▁ç', -2],
  ['a', -3],
  ['.', -1],
  ['▁I', -1],
  ['▁like', -1],
  ['▁it', -1],
  ['I', -2],
  ['like', -2],
  ['it', -2],
  ['l', -3],
  ['i', -3],
  ['k', -3],
  ['e', -3],
  // A second, identical ['i', -3] entry used to follow here; it was dropped
  // because a piece should occur only once. No expectation below depends on
  // the ids of the single-letter pieces.
  ['t', -3]
];

describe('Universal Sentence Encoder tokenizer', () => {
  let tokenizer: Tokenizer;
  beforeAll(() => {
    tokenizer = new Tokenizer(stubbedTokenizerVocab as Array<[string, number]>);
  });

  it('basic usage', () => {
    expect(tokenizer.encode('Ilikeit.')).toEqual([11, 15, 16, 10]);
  });

  it('handles whitespace', () => {
    expect(tokenizer.encode('I like it.')).toEqual([11, 12, 13, 10]);
  });

  it('should return no tokens for empty input', () => {
    expect(tokenizer.encode('')).toEqual([]);
  });

  it('should normalize inputs', () => {
    expect(tokenizer.encode('ça')).toEqual(tokenizer.encode('c\u0327a'));
  });

  it('should handle unknown inputs', () => {
    expect(() => tokenizer.encode('😹')).not.toThrow();
  });

  it('should treat consecutive unknown inputs as a single word', () => {
    expect(tokenizer.encode('a😹😹')).toEqual([7, 0]);
  });
});
// Copied from
// https://github.com/tensorflow/tfjs-models/blob/master/universal-sentence-encoder/src/tokenizer/trie.ts

import {stringToChars} from './common';

// [token, score, index]
type OutputNode = [string[], number, number];

/** One node of the trie; `word[0]` is the symbol path from the root. */
class TrieNode {
  parent: TrieNode|null = null;
  end = false;
  children: {[firstSymbol: string]: TrieNode} = {};
  word: OutputNode = [[], 0, 0];
}

/**
 * Simple Trie datastructure.
 */
export class Trie {
  root: TrieNode = new TrieNode();

  /**
   * Inserts a token into the trie, creating the nodes along its path as
   * needed and recording `score` and `index` on the final node. A token
   * without symbols leaves the trie untouched.
   */
  insert(word: string, score: number, index: number) {
    const symbols = stringToChars(word);
    let current = this.root;

    for (const symbol of symbols) {
      let child = current.children[symbol];
      if (!child) {
        child = new TrieNode();
        child.parent = current;
        child.word[0] = current.word[0].concat(symbol);
        current.children[symbol] = child;
      }
      current = child;
    }

    if (symbols.length > 0) {
      current.end = true;
      current.word[1] = score;
      current.word[2] = index;
    }
  }

  /**
   * Returns every inserted token that is a prefix of `ss`, shortest first.
   * When there is none, a single placeholder `[[ss[0]], 0, 0]` is returned
   * instead, i.e. the first symbol with score 0 and index 0.
   *
   * @param ss The symbols to match against.
   */
  commonPrefixSearch(ss: string[]): OutputNode[] {
    const matches: OutputNode[] = [];
    let current = this.root;

    for (const symbol of ss) {
      current = current.children[symbol];
      if (!current) {
        break;
      }
      if (current.end) {
        matches.push(current.word);
      }
    }

    if (matches.length === 0) {
      matches.push([[ss[0]], 0, 0]);
    }

    return matches;
  }
}
@@ -0,0 +1,522 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training loop example. + +This is a basic variant of a training loop, good starting point for fancy ones. +""" +# pylint: disable=consider-using-from-import +# pylint: disable=logging-fstring-interpolation + +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.evaluators.common as eval_common +import big_vision.input_pipeline as input_pipeline +import big_vision.optax as bv_optax +import big_vision.sharding as bv_sharding +import big_vision.utils as u +from clu import parameter_overview +import flax.linen as nn +import jax +from jax.experimental import mesh_utils +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import serialization as array_serial +from jax.experimental.shard_map import shard_map +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. 
+jax.config.parse_flags_with_absl() +# Transfer guard will fail the program whenever that data between a host and +# a device is transferred implicitly. This often catches subtle bugs that +# cause slowdowns and memory fragmentation. Explicit transfers are done +# with jax.device_put and jax.device_get. +jax.config.update("jax_transfer_guard", "disallow") +# Fixes design flaw in jax.random that may cause unnecessary d2d comms. +jax.config.update("jax_threefry_partitionable", True) + + +NamedSharding = jax.sharding.NamedSharding +P = jax.sharding.PartitionSpec + + +def main(argv): + del argv + + # This is needed on multihost systems, but crashes on non-TPU single-host. + if os.environ.get("BV_JAX_INIT"): + jax.distributed.initialize() + + # Make sure TF does not touch GPUs. + tf.config.set_visible_devices([], "GPU") + + config = flags.FLAGS.config + +################################################################################ +# # +# Set up logging # +# # +################################################################################ + + # Set up work directory and print welcome message. + workdir = flags.FLAGS.workdir + logging.info( + f"\u001b[33mHello from process {jax.process_index()} holding " + f"{jax.local_device_count()}/{jax.device_count()} devices and " + f"writing to workdir {workdir}.\u001b[0m") + + save_ckpt_path = None + if workdir: # Always create if requested, even if we may not write into it. + gfile.makedirs(workdir) + save_ckpt_path = os.path.join(workdir, "checkpoint.bv") + + # The pool is used to perform misc operations such as logging in async way. + pool = multiprocessing.pool.ThreadPool() + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]): + importlib.import_module(f"big_vision.pp.{m}") + + # Setup up logging and experiment manager. 
+ xid, wid = -1, -1 + fillin = lambda s: s + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + + # Allow for things like timings as early as possible! + u.chrono.inform(measure=mw.measure, write_note=write_note) + +################################################################################ +# # +# Set up Mesh # +# # +################################################################################ + + # We rely on jax mesh_utils to organize devices, such that communication + # speed is the fastest for the last dimension, second fastest for the + # penultimate dimension, etc. + config_mesh = config.get("mesh", [("data", jax.device_count())]) + + # Sharding rules with default + sharding_rules = config.get("sharding_rules", [("act_batch", "data")]) + + mesh_axes, mesh_size = tuple(zip(*config_mesh)) + + # Because jax.utils do not support `-1` shape size. + mesh_size = np.array(jax.devices()).reshape(mesh_size).shape + + device_mesh = mesh_utils.create_device_mesh( + mesh_size, allow_split_physical_axes=config.get( + "mesh_allow_split_physical_axes", False)) + + # Consistent device order is important to ensure correctness of various train + # loop components, such as input pipeline, update step, evaluators. The + # order presribed by the `devices_flat` variable should be used throughout + # the program. 
+ devices_flat = device_mesh.flatten() + +################################################################################ +# # +# Input Pipeline # +# # +################################################################################ + + write_note("Initializing train dataset...") + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + train_ds, ntrain_img = input_pipeline.training(config.input) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + # Start input pipeline as early as possible. 
+ n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_global(train_ds, devices_flat, n_prefetch) + +################################################################################ +# # +# Create Model & Optimizer # +# # +################################################################################ + + write_note("Creating model...") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model = model_mod.Model( + num_classes=config.num_classes, **config.get("model", {})) + + def init(rng): + batch = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype), + train_ds.element_spec) + params = model.init(rng, batch["image"])["params"] + + # Set bias in the head to a low value, such that loss is small initially. + if "init_head_bias" in config: + params["head"]["bias"] = jnp.full_like(params["head"]["bias"], + config["init_head_bias"]) + + return params + + # This seed makes the Jax part of things (like model init) deterministic. + # However, full training still won't be deterministic, for example due to the + # tf.data pipeline not being deterministic even if we would set TF seed. + # See (internal link) for a fun read on what it takes. + rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0))) + + write_note("Inferring parameter shapes...") + rng, rng_init = jax.random.split(rng) + params_shape = jax.eval_shape(init, rng_init) + + write_note("Inferring optimizer state shapes...") + tx, sched_fns = bv_optax.make(config, nn.unbox(params_shape), sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + opt_shape = jax.eval_shape(tx.init, params_shape) + # We jit this, such that the arrays are created on the CPU, not device[0]. 
+ sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns] + + if jax.process_index() == 0: + num_params = sum(np.prod(p.shape) for p in jax.tree.leaves(params_shape)) + mw.measure("num_params", num_params) + +################################################################################ +# # +# Shard & Transfer # +# # +################################################################################ + + write_note("Creating device mesh...") + mesh = jax.sharding.Mesh(device_mesh, mesh_axes) + repl_sharding = jax.sharding.NamedSharding(mesh, P()) + + write_note("Inferring shardings...") + train_state_shape = {"params": params_shape, "opt": opt_shape} + + strategy = config.get("sharding_strategy", [(".*", "replicate")]) + with nn.logical_axis_rules(sharding_rules): + train_state_sharding = bv_sharding.infer_sharding( + train_state_shape, strategy=strategy, mesh=mesh) + + write_note("Transferring train_state to devices...") + # RNG is always replicated + rng_init = u.reshard(rng_init, repl_sharding) + + # Parameters and the optimizer are now global (distributed) jax arrays. + params = jax.jit(init, out_shardings=train_state_sharding["params"])(rng_init) + opt = jax.jit(tx.init, out_shardings=train_state_sharding["opt"])(params) + + rng, rng_loop = jax.random.split(rng, 2) + rng_loop = u.reshard(rng_loop, repl_sharding) + del rng # not used anymore, so delete it. + + # At this point we have everything we need to form a train state. It contains + # all the parameters that are passed and updated by the main training step. + # From here on, we have no need for Flax AxisMetadata (such as partitioning). + train_state = nn.unbox({"params": params, "opt": opt}) + del params, opt # Delete to avoid memory leak or accidental reuse. 
+ + write_note("Logging parameter overview...") + parameter_overview.log_parameter_overview( + train_state["params"], msg="Init params", + include_stats="global", jax_logging_process=0) + +################################################################################ +# # +# Update Step # +# # +################################################################################ + + @functools.partial( + jax.jit, + donate_argnums=(0,), + out_shardings=(train_state_sharding, repl_sharding)) + def update_fn(train_state, rng, batch): + """Update step.""" + + images, labels = batch["image"], batch["labels"] + + step_count = bv_optax.get_count(train_state["opt"], jittable=True) + rng = jax.random.fold_in(rng, step_count) + + if config.get("mixup") and config.mixup.p: + # The shard_map below makes mixup run on every device independently and + # thus avoids unnecessary communication. + sharded_mixup_fn = shard_map( + u.get_mixup(rng, config.mixup.p), + mesh=jax.sharding.Mesh(devices_flat, ("data",)), + in_specs=P("data"), out_specs=(P(), P("data"), P("data"))) + rng, (images, labels), _ = sharded_mixup_fn(images, labels) + + # Get device-specific loss rng. 
+ rng, rng_model = jax.random.split(rng, 2) + + def loss_fn(params): + logits, _ = model.apply( + {"params": params}, images, + train=True, rngs={"dropout": rng_model}) + return getattr(u, config.get("loss", "sigmoid_xent"))( + logits=logits, labels=labels) + + params, opt = train_state["params"], train_state["opt"] + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, opt = tx.update(grads, opt, params) + params = optax.apply_updates(params, updates) + + measurements = {"training_loss": loss} + gs = jax.tree.leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.sum(g * g) for g in gs])) + ps = jax.tree.leaves(params) + measurements["l2_params"] = jnp.sqrt(sum([jnp.sum(p * p) for p in ps])) + us = jax.tree.leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.sum(u * u) for u in us])) + + return {"params": params, "opt": opt}, measurements + +################################################################################ +# # +# Load Checkpoint # +# # +################################################################################ + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize model from something, e,g, start a fine-tuning job. + # 4. Train from scratch. 
+ resume_ckpt_path = None + if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + + ckpt_mngr = None + if save_ckpt_path or resume_ckpt_path: + ckpt_mngr = array_serial.GlobalAsyncCheckpointManager() + + if resume_ckpt_path: + write_note(f"Resuming training from checkpoint {resume_ckpt_path}...") + jax.tree.map(lambda x: x.delete(), train_state) + del train_state + shardings = { + **train_state_sharding, + "chrono": jax.tree.map(lambda _: repl_sharding, + u.chrono.save()), + } + loaded = u.load_checkpoint_ts( + resume_ckpt_path, tree=shardings, shardings=shardings) + train_state = {key: loaded[key] for key in train_state_sharding.keys()} + + u.chrono.load(jax.device_get(loaded["chrono"])) + del loaded + elif config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + # TODO: when updating the `load` API soon, do pass and request the + # full `train_state` from it. Examples where useful: VQVAE, BN. + train_state["params"] = model_mod.load( + train_state["params"], config.model_init, config.get("model"), + **config.get("model_load", {})) + + # load has the freedom to return params not correctly sharded. Think of for + # example ViT resampling position embedings on CPU as numpy arrays. + train_state["params"] = u.reshard( + train_state["params"], train_state_sharding["params"]) + + parameter_overview.log_parameter_overview( + train_state["params"], msg="restored params", + include_stats="global", jax_logging_process=0) + + +################################################################################ +# # +# Setup Evals # +# # +################################################################################ + + # We do not jit/pmap this function, because it is passed to evaluator that + # does it later. We output as many intermediate tensors as possible for + # maximal flexibility. 
Later `jit` will prune out things that are not needed. + def eval_logits_fn(train_state, batch): + logits, out = model.apply({"params": train_state["params"]}, batch["image"]) + return logits, out + + def eval_loss_fn(train_state, batch): + logits, _ = model.apply({"params": train_state["params"]}, batch["image"]) + loss_fn = getattr(u, config.get("loss", "sigmoid_xent")) + return { + "loss": loss_fn(logits=logits, labels=batch["labels"], reduction=False) + } + + eval_fns = { + "predict": eval_logits_fn, + "loss": eval_loss_fn, + } + + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, eval_fns, + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + devices_flat, + ) + + # At this point we need to know the current step to see whether to run evals. + write_note("Inferring the first step number...") + first_step_device = bv_optax.get_count(train_state["opt"], jittable=True) + first_step = int(jax.device_get(first_step_device)) + u.chrono.inform(first_step=first_step) + + # Note that training can be pre-empted during the final evaluation (i.e. + # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. 
+ if first_step in (total_steps, 0): + write_note("Running initial or final evals...") + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules(sharding_rules): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", value) + +################################################################################ +# # +# Train Loop # +# # +################################################################################ + + prof = None # Keeps track of start/stop of profiler state. + + write_note("Starting training loop, compiling the first step...") + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + with mesh, nn.logical_axis_rules(sharding_rules): + train_state, measurements = update_fn(train_state, rng_loop, batch) + + # On the first host, let's always profile a handful of early steps. 
+ if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", + sched_fn_cpu(u.put_cpu(step - 1))) + measurements = jax.device_get(measurements) + for name, value in measurements.items(): + mw.measure(name, value) + u.chrono.tick(step) + if not np.isfinite(measurements["training_loss"]): + raise RuntimeError(f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + keep_ckpt_steps = get_steps("keep_ckpt", None) or total_steps + if save_ckpt_path and ( + (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False)) + or u.itstime(step, get_steps("ckpt", None), total_steps, first=True) + ): + u.chrono.pause(wait_for=train_state) + + # Copy because we add extra stuff to the checkpoint. + ckpt = {**train_state} + + # To save chrono state correctly and safely in a multihost setup, we + # broadcast the state to all hosts and convert it to a global array. + with jax.transfer_guard("allow"): + chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save()) + chrono_shardings = jax.tree.map(lambda _: repl_sharding, chrono_ckpt) + ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)} + + u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=train_state) + u.chrono.tick(step) # Record things like epoch number, core hours etc. 
+ write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules(sharding_rules): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", jax.device_get(value)) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + if ckpt_mngr: + ckpt_mngr.wait_until_finished() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/trainers/proj/cappa/generative.py b/big_vision/trainers/proj/cappa/generative.py new file mode 100644 index 0000000000000000000000000000000000000000..a4b8873a45eac7e5c4d5d9dfea8c42abccda8283 --- /dev/null +++ b/big_vision/trainers/proj/cappa/generative.py @@ -0,0 +1,498 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Training loop for CapPa (https://arxiv.org/abs/2306.07915).""" +# pylint: disable=consider-using-from-import +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.datasets.core as ds_core +import big_vision.evaluators.common as eval_common +import big_vision.input_pipeline as input_pipeline +import big_vision.optax as bv_optax +import big_vision.sharding as bv_sharding +import big_vision.trainers.proj.cappa.predict_fns as predict_fns +import big_vision.utils as u +from clu import parameter_overview +import flax +import flax.linen as nn +import jax +from jax.experimental import mesh_utils +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import serialization as array_serial +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + +# pylint: disable=logging-fstring-interpolation + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() +# Transfer guard will fail the program whenever that data between a host and +# a device is transferred implicitly. This often catches subtle bugs that +# cause slowdowns and memory fragmentation. Explicit transfers are done +# with jax.device_put and jax.device_get. +jax.config.update("jax_transfer_guard", "disallow") +# Fixes design flaw in jax.random that may cause unnecessary d2d comms. 
+jax.config.update("jax_threefry_partitionable", True) + + +def main(argv): + del argv + + try: + jax.distributed.initialize() + except ValueError as e: + logging.warning('Could not initialize distributed environment: %s', e) + + # Make sure TF does not touch GPUs. + tf.config.set_visible_devices([], "GPU") + + # Consistent device order is important to ensure correctness of various train + # loop components, such as input pipeline, update step, evaluators. We use + # jax utils to infer device order that will be used throughout the program. + config = flags.FLAGS.config + num_sharded_replicas = config.get("num_sharded_replicas", 1) + assert jax.device_count() % num_sharded_replicas == 0, ( + num_sharded_replicas, jax.device_count()) + devices = mesh_utils.create_device_mesh( + (num_sharded_replicas, jax.device_count() // num_sharded_replicas)) + devices_flat = devices.reshape(-1) + +################################################################################ +# # +# Set up logging # +# # +################################################################################ + + # Set up work directory and print welcome message. + config = flags.FLAGS.config + workdir = flags.FLAGS.workdir + logging.info( + f"\u001b[33mHello from process {jax.process_index()} holding " + f"{jax.local_device_count()}/{jax.device_count()} devices and " + f"writing to workdir {workdir}.\u001b[0m") + + save_ckpt_path = None + if workdir: # Always create if requested, even if we may not write into it. + gfile.makedirs(workdir) + save_ckpt_path = os.path.join(workdir, "checkpoint.bv") + + # The pool is used to perform misc operations such as logging in async way. + pool = multiprocessing.pool.ThreadPool() + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]): + importlib.import_module(f"big_vision.pp.{m}") + + # Setup up logging and experiment manager. 
+ xid, wid = -1, -1 + fillin = lambda s: s + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + +################################################################################ +# # +# Input Pipeline # +# # +################################################################################ + + write_note("Initializing train dataset...") + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + train_ds, ntrain_img = input_pipeline.training(config.input) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size, + measure=mw.measure, write_note=write_note) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + # Start input pipeline as early as possible. + n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_global(train_ds, devices_flat, n_prefetch) + + # For mixed data, add per-dataset epoch and examples seen measurements. 
+ if isinstance(config.input.data.get("name"), str): + measure_per_dataset_times = lambda step: None # No-op + else: + nexamples = { + name: ds_core.get(**config.input[name].data).total_examples + for name in config.input.data + } + def measure_per_dataset_times(step): + total = sum(config.input.data.values()) + for name, w in config.input.data.items(): + w = w / total + mw.measure(f"examples_seen_{name}", u.chrono.accum_examples_seen * w) + mw.measure(f"epoch_{name}", step * batch_size * w / nexamples[name]) + +################################################################################ +# # +# Create Model & Optimizer # +# # +################################################################################ + + write_note(f"Initializing {config.model_name} model...") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model = model_mod.Model(**config.get("model", {})) + + def init(rng): + bs = batch_size // jax.device_count() + img_shape = (bs,) + tuple(train_ds.element_spec["image"].shape[1:]) + txt_shape = (bs,) + tuple(train_ds.element_spec["labels"].shape[1:]) + dummy_img = jnp.zeros(img_shape, jnp.float32) + dummy_txt = jnp.zeros(txt_shape, jnp.int64) + variables = model.init(rng, dummy_img, dummy_txt) + return flax.core.unfreeze(variables["params"]) + + # This seed makes the Jax part of things (like model init) deterministic. + # However, full training still won't be deterministic, for example due to the + # tf.data pipeline not being deterministic even if we would set TF seed. + # See (internal link) for a fun read on what it takes. 
+ rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0))) + + rng, rng_init = jax.random.split(rng) + with u.chrono.log_timing("z/secs/init"): + write_note("Inferring parameter shapes...") + params_shape = jax.eval_shape(init, rng_init) + + write_note("Inferring optimizer state shapes...") + tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + opt_shape = jax.eval_shape(tx.init, params_shape) + # We jit this, such that the arrays are created on the CPU, not device[0]. + sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns] + + num_params = sum(np.prod(p.shape) for p in jax.tree_leaves(params_shape)) + mw.measure("num_params", num_params) + +################################################################################ +# # +# Shard & Transfer # +# # +################################################################################ + + # Using a 2D mesh where the model/optimizer parameters are replicated along + # the "replica" axis and sharded along the "fsdp" axis if the sharding is set + # to "fully_sharded" for the model/optimizer. + write_note("Creating device mesh...") + mesh = jax.sharding.Mesh(devices, ("replica", "fsdp")) + repl_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + write_note("Inferring shardings...") + params_sharding = bv_sharding.infer_sharding( + params_shape, mesh, axis_name="fsdp", + # TODO: implement scan for parameter sharding. 
+ strategy=config.get("param_sharding", "replicated"), + extra_strategy_args=config.get("param_sharding_args", {})) + opt_sharding = bv_sharding.infer_sharding( + opt_shape, mesh, axis_name="fsdp", + strategy=config.get("optim_sharding", "replicated"), + extra_strategy_args=config.get("optim_sharding_args", {})) + + write_note("Transferring train_state to devices...") + # RNG is always replicated + rng_init = u.reshard(rng_init, repl_sharding) + + # Parameters and the optimizer are now global (distributed) jax arrays. + params = jax.jit(init, out_shardings=params_sharding)(rng_init) + opt = jax.jit(tx.init, out_shardings=opt_sharding)(params) + + rng, rng_loop = jax.random.split(rng, 2) + rng_loop = u.reshard(rng_loop, repl_sharding) + del rng # not used anymore, so delete it. + + # At this point we have everything we need to form a train state. It contains + # all the parameters that are passed and updated by the main training step. + train_state_sharding = { + "params": params_sharding, "opt": opt_sharding, "rng": repl_sharding} + train_state = { + "params": params, "opt": opt, "rng": rng_loop} + del params, opt, rng_loop # Delete to avoid memory leak or accidental reuse. + + write_note("Logging parameter overview...") + parameter_overview.log_parameter_overview( + train_state["params"], msg="Init params", + include_stats="global", jax_logging_process=0) + +################################################################################ +# # +# Update Step # +# # +################################################################################ + + @functools.partial( + jax.jit, + donate_argnums=(0,), + out_shardings=(train_state_sharding, repl_sharding)) + def update_fn(train_state, batch): + """Update step.""" + + images, labels, label_masks = ( + batch["image"], batch["labels"], batch.get("label_masks")) + + # Get device-specific loss rng. 
+ rng = train_state["rng"] + rng, rng_model = jax.random.split(rng, 2) + + def loss_fn(params): + logits = model.apply( + {"params": params}, images, labels, + train=True, rngs={"dropout": rng_model}) + + weights = jnp.where(labels != config.get("pad_token", 0), 1.0, 0.0) + if label_masks is not None: + weights = weights * label_masks + + loss = u.weighted_softmax_xent( + logits=logits, labels=labels, + weights=weights, label_smoothing=config.get("label_smoothing", 0.0), + reduction=True, normalize=True) + + return loss + + params, opt = train_state["params"], train_state["opt"] + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, opt = tx.update(grads, opt, params) + params = optax.apply_updates(params, updates) + + measurements = {"training_loss": loss} + gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.sum(g * g) for g in gs])) + ps = jax.tree_leaves(params) + measurements["l2_params"] = jnp.sqrt(sum([jnp.sum(p * p) for p in ps])) + us = jax.tree_leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.sum(u * u) for u in us])) + + return {"params": params, "opt": opt, "rng": rng}, measurements + +################################################################################ +# # +# Load Checkpoint # +# # +################################################################################ + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize model from something, e,g, start a fine-tuning job. + # 4. Train from scratch. 
+ resume_ckpt_path = None + if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + + ckpt_mngr = None + if save_ckpt_path or resume_ckpt_path: + ckpt_mngr = array_serial.GlobalAsyncCheckpointManager() + + if resume_ckpt_path: + write_note(f"Resuming training from checkpoint {resume_ckpt_path}...") + jax.tree_map(lambda x: x.delete(), train_state) + del train_state + shardings = { + **train_state_sharding, + "chrono": jax.tree_map(lambda _: repl_sharding, + u.chrono.save()), + } + loaded = u.load_checkpoint_ts( + resume_ckpt_path, tree=shardings, shardings=shardings) + train_state = {key: loaded[key] for key in train_state_sharding.keys()} + + u.chrono.load(jax.device_get(loaded["chrono"])) + del loaded + elif config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + train_state["params"] = model_mod.load( + train_state["params"], config.model_init, config.get("model"), + **config.get("model_load", {})) + + # load has the freedom to return params not correctly sharded. Think of for + # example ViT resampling position embedings on CPU as numpy arrays. + train_state["params"] = u.reshard( + train_state["params"], train_state_sharding["params"]) + + parameter_overview.log_parameter_overview( + train_state["params"], msg="restored params", + include_stats="global", jax_logging_process=0) + +################################################################################ +# # +# Setup Evals # +# # +################################################################################ + + # Only initialize evaluators when they are first needed. 
+ @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, + predict_fns.get_predict_fns(model), + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + devices_flat, + ) + + # At this point we need to know the current step to see whether to run evals. + write_note("Inferring the first step number...") + first_step_device = bv_optax.get_count(train_state["opt"], jittable=True) + first_step = int(jax.device_get(first_step_device)) + u.chrono.inform(first_step=first_step) + + # Note that training can be pre-empted during the final evaluation (i.e. + # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. + if first_step in (total_steps, 0): + write_note("Running initial or final evals...") + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules([("act_batch", ("replica", "fsdp"))]): # pytype: disable=wrong-arg-types + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", value) + +################################################################################ +# # +# Train Loop # +# # +################################################################################ + + prof = None # Keeps track of start/stop of profiler state. 
+ + write_note("Starting training loop, compiling the first step...") + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + with mesh, nn.logical_axis_rules([("act_batch", ("replica", "fsdp"))]): # pytype: disable=wrong-arg-types + train_state, measurements = update_fn(train_state, batch) + + # On the first host, let's always profile a handful of early steps. + if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", + sched_fn_cpu(u.put_cpu(step - 1))) + measurements = jax.device_get(measurements) + for name, value in measurements.items(): + mw.measure(name, value) + u.chrono.tick(step) + measure_per_dataset_times(step) + + if not np.isfinite(measurements["training_loss"]): + raise RuntimeError(f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + keep_ckpt_steps = get_steps("keep_ckpt", None) or total_steps + if save_ckpt_path and ( + (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False)) + or u.itstime(step, get_steps("ckpt", None), total_steps, first=True) + ): + u.chrono.pause(wait_for=train_state) + + # Copy because we add extra stuff to the checkpoint. + ckpt = {**train_state} + + # To save chrono state correctly and safely in a multihost setup, we + # broadcast the state to all hosts and convert it to a global array. 
+ with jax.transfer_guard("allow"): + chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save()) + chrono_shardings = jax.tree_map(lambda _: repl_sharding, chrono_ckpt) + ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)} + + u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=train_state) + u.chrono.tick(step) # Record things like epoch number, core hours etc. + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules( + [("act_batch", ("replica", "fsdp"))]): # pytype: disable=wrong-arg-types + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", jax.device_get(value)) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + + if ckpt_mngr: + ckpt_mngr.wait_until_finished() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/trainers/proj/cappa/predict_fns.py b/big_vision/trainers/proj/cappa/predict_fns.py new file mode 100644 index 0000000000000000000000000000000000000000..4e369f23cd840370cdd419b5e67f377e2dcf4ad1 --- /dev/null +++ b/big_vision/trainers/proj/cappa/predict_fns.py @@ -0,0 +1,118 @@ +# Copyright 2023 Big Vision Authors. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Prediction functions for the CapPa trainer (cappa/generative.py)."""

import functools

import big_vision.pp.ops_text as pp_ops_text
import big_vision.utils as u
import jax
import jax.numpy as jnp
import numpy as np

# pylint: disable=missing-function-docstring


# The functions below are deliberately left un-jitted: they are handed to the
# evaluators, which take care of jit/pmap themselves. They return as many
# intermediate tensors as possible; whatever an evaluator does not consume is
# pruned away by `jit` later on.
def predict_fn_perplexity(train_state, batch, *, model):
  """Runs the model on image + label tokens and returns the logits.

  Args:
    train_state: Dict holding the model parameters under "params".
    batch: Dict with "image" and "labels" entries, both fed to the model.
    model: Module whose `apply` is called in eval mode (`train=False`).

  Returns:
    A `(logits, {"logits": logits})` pair; both entries are the same array.
  """
  params = train_state["params"]
  token_logits = model.apply(
      {"params": params}, batch["image"], batch["labels"], train=False)
  return token_logits, {"logits": token_logits}


def predict_fn_enc_rep(train_state, batch, *, model):
  """Runs the model on images only, requesting the encoder features.

  Args:
    train_state: Dict holding the model parameters under "params".
    batch: Dict with an "image" entry.
    model: Module whose `apply` is called with `return_enc_features=True`.

  Returns:
    The two-element result of `model.apply`, unchanged.
    NOTE(review): presumably (representation, dict of intermediates); confirm
    against the model implementation.
  """
  params = train_state["params"]
  # No text is provided: `None` takes the place of the label tokens.
  rep, extras = model.apply(
      {"params": params},
      batch["image"],
      None,
      train=False,
      return_enc_features=True,
  )
  return rep, extras


def predict_fn_score(
    train_state, batch, *, model, prompt="", prompt_tokenizer=""):
  """Scores (log-likelihood) every image in the batch against every label.

  Args:
    train_state: Dict holding the model parameters under "params".
    batch: Dict with "image" and "_label_tokens". The latter has to be added
      by the evaluator: it is the pre-tokenized list of all candidate labels,
      e.g. (1000, 13) for ImageNet-1k.
    model: Module providing `encode` and `decode` methods for `apply`.
    prompt: Optional text that is tokenized and prepended to every label.
    prompt_tokenizer: Tokenizer passed to `make_prompt`; the prompt is only
      used when both `prompt` and `prompt_tokenizer` are non-empty.

  Returns:
    A (batch, nlabel) array of scores, where higher means a better match.
  """
  params = train_state["params"]

  # Encode the images once; the result is reused for every candidate label.
  image_enc = model.apply(
      {"params": params},
      batch["image"],
      train=False,
      method=model.encode,
  )

  label_tokens = batch["_label_tokens"]

  if prompt and prompt_tokenizer:
    prefix = make_prompt(prompt, prompt_tokenizer)  # Cached across calls.
    tiled_prefix = jnp.tile(prefix, (label_tokens.shape[0], 1))
    # With ImageNet-1k and a prompt of length 2 this becomes (1000, 15).
    label_tokens = jnp.concatenate([tiled_prefix, label_tokens], axis=-1)

  def score_label(label):
    """Returns the (batch,) log-likelihoods of one `label` for all images."""
    # Pair the same label sequence with every image of the minibatch.
    per_image_label = jnp.tile(label, (image_enc.shape[0], 1))
    decoder_logits = model.apply(
        {"params": params},
        image_enc,
        per_image_label,
        train=False,
        decode=False,
        method=model.decode,
    )
    nll = u.weighted_softmax_xent(
        logits=decoder_logits,
        labels=per_image_label,
        # Positions holding token id 0 (padding) get zero weight.
        weights=(per_image_label > 0).astype(jnp.float32),
        reduction=False,
        normalize=False,
    )
    # Flip the sign: softmax_xent yields an NLL, we want higher = better.
    return -nll

  # lax.map() processes one label at a time, which needs far less memory than
  # vmap() over all labels.
  scores = jax.lax.map(score_label, label_tokens)  # -> (nlabel, batch)
  return scores.T  # -> (batch, nlabel)
+ + +@functools.cache +def make_prompt(prompt, tokenizer_path, seq_len=None): + """Tokenizes `prompt` with specified tokenizer, with optional padding.""" + tokenizer = pp_ops_text.create_tokenizer(tokenizer_path, add_eos=False) + + prompt = tokenizer.tokenize(prompt).numpy() + if seq_len: + prompt = np.pad(prompt, (0, seq_len - len(prompt))).astype(np.int32) + return prompt + + +def get_predict_fns(model): + """Returns `predict_fns` for evaluators.""" + fns = { + "perplexity": predict_fn_perplexity, + "score": predict_fn_score, + "enc_rep": predict_fn_enc_rep, + } + return {name: functools.partial(fn, model=model) for name, fn in fns.items()} diff --git a/big_vision/trainers/proj/distill/distill.py b/big_vision/trainers/proj/distill/distill.py new file mode 100644 index 0000000000000000000000000000000000000000..e42fbc2c7b9c72c1ba78d1dddc949f65fc90be69 --- /dev/null +++ b/big_vision/trainers/proj/distill/distill.py @@ -0,0 +1,473 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training loop for distillation as in https://arxiv.org/abs/2106.05237. + +It works by having a (set of) teacher model(s) defined the same way as student +in the config and, for now, only distilling logits with one of many loss +functions. + +We explored distilling intermediate feature maps, extra data, and other tricks +in depth in two interships in a separate prototype codebase but eventually they +are not necessary, and thus not (yet?) 
implemented in this codebase. + +Thus, for now, there are no extra learnable parameters besides the student. +This keeps code relatively simple. +""" +# pylint: disable=consider-using-from-import +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.evaluators.common as eval_common +import big_vision.evaluators.proj.distill.distance as dd +import big_vision.input_pipeline as input_pipeline +import big_vision.optax as bv_optax +import big_vision.utils as u +from clu import parameter_overview +import flax +import jax +from jax.experimental import mesh_utils +import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + +# pylint: disable=logging-fstring-interpolation + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() +# Fixes design flaw in jax.random that may cause unnecessary d2d comms. 
jax.config.update("jax_threefry_partitionable", True)


def getfirst(d, *keys):
  """Returns the value stored in `d` under the first of `keys` it contains.

  Args:
    d: Mapping to look the keys up in.
    *keys: Candidate keys, highest priority first.

  Returns:
    `d[k]` for the earliest `k` in `keys` for which `k in d` holds.

  Raises:
    KeyError: If none of `keys` is in `d` (also when no keys are given).
  """
  # Scan in priority order and stop at the first hit, so that lower-priority
  # keys are never looked up.
  for k in keys:
    if k in d:
      return d[k]
  raise KeyError(f"None of {keys} is in {d.keys()}")


def main(argv):
  del argv
  tf.config.experimental.set_visible_devices([], "GPU")

  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  logging.info(
      f"\u001b[33mHello from process {jax.process_index()} holding "
      f"{jax.local_device_count()}/{jax.device_count()} devices and "
      f"writing to workdir {workdir}.\u001b[0m")

  save_ckpt_path = None
  if workdir:  # Always create if requested, even if we may not write into it.
    gfile.makedirs(workdir)
    save_ckpt_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # Here we register preprocessing ops from modules listed on `pp_modules`.
  for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]):
    importlib.import_module(f"big_vision.pp.{m}")

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we would set TF seed.
  # See (internal link) for a fun read on what it takes.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  # These functions do more stuff internally, for OSS release we mock them by
  # trivial alternatives in order to minimize disruptions in the code.
+ xid, wid = -1, -1 + fillin = lambda s: s + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + write_note("Initializing...") + + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + # First thing after above sanity checks, so we can log "start" ticks. + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + + devices = mesh_utils.create_device_mesh((jax.device_count(),)) + mesh = jax.sharding.Mesh(devices, ("data",)) + repl_sharding = NamedSharding(mesh, P()) + + write_note("Initializing train dataset...") + train_ds, ntrain_img = input_pipeline.training(config.input) + + # Start prefetching already. + n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_global(train_ds, devices, n_prefetch) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size, + measure=mw.measure, write_note=write_note) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + # Create student and teacher models + def get_model_mod(name): # Used many times. 
+ mod_name = config[f"{name}_name"] + return importlib.import_module(f"big_vision.models.{mod_name}") + + write_note("Initializing models...") + def make_model(name): + return get_model_mod(name).Model( + num_classes=config.num_classes, **config.get(name, {})) + + models = { + "student": make_model("student"), + **{t: make_model(t) for t in config.teachers} + } + + # We want all parameters to be created in host RAM, not on any device, they'll + # be sent there later as needed, otherwise we already encountered two + # situations where we allocate them twice. + def get_init(model, name): + @functools.partial(jax.jit, backend="cpu") + def _init(rng): + bs = batch_size // jax.device_count() + img_size = tuple(getfirst(train_ds.element_spec, name, "image").shape[1:]) + no_image = jnp.zeros((bs,) + img_size, jnp.float32) + params = flax.core.unfreeze(model.init(rng, no_image))["params"] + + # Set bias in the head to a low value, such that loss is small initially. + if "init_head_bias" in config: + params["head"]["bias"] = jnp.full_like(params["head"]["bias"], + config["init_head_bias"]) + return params + return _init + + rng, *rng_inits = jax.random.split(rng, len(models) + 1) + with u.chrono.log_timing("z/secs/init"): + params_cpu = { + name: get_init(models[name], name=name)(r) + for name, r in zip(models, rng_inits)} + + if jax.process_index() == 0: + for name, params in params_cpu.items(): + parameter_overview.log_parameter_overview(params, msg=f"{name} params") + mw.measure(f"num_params_{name}", + sum(p.size for p in jax.tree_leaves(params))) + + write_note(f"Initializing {config.optax_name} optimizer...") + # For now, we explicitly only optimize the student parameters as there's + # nothing else to be optimized. If we ever want to add learnable projections + # or similar for good (we explored but ditched), need to refactor this a bit. 
+ tx, sched_fns = bv_optax.make( + config, params_cpu["student"], sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + + # We jit this, such that the arrays are created on the CPU, not device[0]. + opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu["student"]) + sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns] + + @jax.named_call + def loss_fn(student_params, params, data, rngs): + # Note: need to extract and use `student_params` out of `params` because the + # first argument of `loss_fn` is what's differentiated wrt. + params["student"] = student_params + + def fwd(name, params): + return jax.named_call(models[name].apply, name=name)( + {"params": params}, getfirst(data, name, "image"), + train=name == "student", rngs=rngs.get(name) + )[0] # logits, unused_outputs + logits = {name: fwd(name, w) for name, w in params.items()} + + measurements = {} + for name, lg in logits.items(): + measurements[f"entropy_{name}"] = -jnp.sum( + jax.nn.log_softmax(lg) * jax.nn.softmax(lg), axis=-1) + if "labels" in data: + measurements[f"task_loss_{name}"] = u.softmax_xent( + logits=lg, labels=data["labels"], reduction=False) + + # NOTE: xent is linear in labels, so for KL, this is actually the same as + # using a teacher-ensemble in probs-space! + measurements["distill_loss"] = 0.0 + for name in config.teachers: + l = dd.dist(logits["student"], logits[name], config.get("distance", "kl"), + **config.get("distance_kw", {})) + measurements[f"distill_loss_{name}"] = l + measurements["distill_loss"] += l + + outputs = (measurements["distill_loss"], measurements) + return jax.tree_map(jnp.mean, outputs) + + @functools.partial( + jax.jit, donate_argnums=(0, 1, 2), out_shardings=repl_sharding + ) + def update_fn(params, opt, rng, data): + """Update step.""" + + # Mixup. Note: overwrites the `data` entries (that's intended). 
+ if config.get("mixup") and config.mixup.p: + to_mix = {name: data[name] + for name in ("image", "labels") + tuple(models) if name in data} + rng, _, to_mix = u.mixup(rng, **config.mixup, **to_mix) + data = {**data, **to_mix} + + # Get model-specific loss rng. + rng, *rng_models = jax.random.split(rng, len(models) + 1) + rngs_model_dicts = { + name: {"dropout": rngi} for name, rngi in zip(models, rng_models) + } + + w = params["student"] # Need to explicitly pull out the optimized ones. + (l, measurements), grads = jax.value_and_grad(loss_fn, has_aux=True)( + w, params, data, rngs=rngs_model_dicts + ) + updates, opt = tx.update(grads, opt, w) + w = optax.apply_updates(w, updates) + params["student"] = w + + # Take some logging measurements + gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs])) + ps = jax.tree_leaves(w) + measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps])) + us = jax.tree_leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us])) + + return params, opt, rng, l, measurements + + # We always load the teachers first, because they NEED to be initialized + # and since we don't ever modify them, we don't store them in checkpoints. + for name in config.teachers: + init_def = config[f"{name}_init"] + write_note(f"Initializing {name} from {init_def}…") + params_cpu[name] = get_model_mod(name).load( + params_cpu[name], init_def, config[name], + **config.get(f"{name}_load", {})) + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize student from something, e.g. start a fine-tuning job. + # 4. Train from scratch. 
+ resume_ckpt_path = None + if save_ckpt_path and gfile.exists(save_ckpt_path): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + if resume_ckpt_path: + write_note("Resume training from checkpoint...") + # NOTE: we never change the teachers, so only checkpoint student here. + checkpoint = { + "params": params_cpu["student"], + "opt": opt_cpu, + "chrono": u.chrono.save(), + } + checkpoint_tree = jax.tree_structure(checkpoint) + loaded = u.load_checkpoint_np(resume_ckpt_path, checkpoint_tree) + # bfloat16 type gets lost when data is saved to disk, so we recover it. + checkpoint = jax.tree_map(u.recover_dtype, loaded) + params_cpu["student"], opt_cpu = checkpoint["params"], checkpoint["opt"] + u.chrono.load(checkpoint["chrono"]) + elif config.get("student_init"): + write_note(f"Initialize student from {config.student_init}...") + params_cpu["student"] = get_model_mod("student").load( + params_cpu["student"], config.student_init, config.get("student"), + **config.get("student_load", {})) + if jax.process_index() == 0: + parameter_overview.log_parameter_overview( + params_cpu["student"], msg="restored (student) params") + + write_note("Kicking off misc stuff...") + first_step = bv_optax.get_count(opt_cpu) + u.chrono.inform(first_step=first_step) + prof = None # Keeps track of start/stop of profiler state. + + write_note(f"Replicating...\n{u.chrono.note}") + params_repl = u.reshard(params_cpu, repl_sharding) + opt_repl = u.reshard(opt_cpu, repl_sharding) + + # Define predict functions that the evaluators can use: + # 1. One per model + predict_fns = {} + for name, model in models.items(): + def fwd(train_state, batch, n=name, m=model): + return m.apply({"params": train_state["params"][n]}, batch["image"]) + predict_fns[f"{name}_fwd"] = fwd + # 2. One for the ensemble of all teachers. 
+ def teacher_ensemble_fwd(train_state, batch): + all_teacher_logits = [ + models[name].apply(train_state["params"][name], batch["image"])[0] + for name in config.teachers + ] + return jnp.mean([jax.nn.softmax(l) for l in all_teacher_logits], axis=0), {} # pytype: disable=wrong-arg-types # jnp-type + predict_fns["teacher_ensemble_fwd"] = teacher_ensemble_fwd + # 3.One for each (student, teacher) pair, eg for distance eval. + for name in [*config.teachers, "teacher_ensemble"]: + def fwd(train_state, batch, n=name): # pylint: disable=function-redefined + student_ret = predict_fns["student_fwd"](train_state, batch) + teacher_ret = predict_fns[f"{n}_fwd"](train_state, batch) + return student_ret, teacher_ret + predict_fns[f"student_{name}_fwd"] = fwd + + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, + predict_fns, + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + devices, + ) + + rng, rng_loop = jax.random.split(rng, 2) + rngs_loop = u.reshard(rng_loop, repl_sharding) + ckpt_writer = None + + write_note(f"First step compilations...\n{u.chrono.note}") + + # Note that training can be pre-empted during the final evaluation (i.e. + # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. + if first_step in (total_steps, 0): + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + for key, value in evaluator.run({"params": params_repl}): + mw.measure(f"{prefix}{key}", value) + + # Using a python integer for step here, because opt.state.step is allocated + # on TPU during replication. 
+ for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn( + params_repl, opt_repl, rngs_loop, batch) + + # On the first host, let's always profile a handful of early steps. + if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) + l = mw.measure("training_loss", loss_value) + for name, value in measurements.items(): + mw.measure(name, value) + u.chrono.tick(step) + if not np.isfinite(l): + raise RuntimeError(f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + if (save_ckpt_path and + (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or + u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))): + u.chrono.pause(wait_for=(params_repl["student"], opt_repl)) + u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1)) + # We need to transfer the weights over now or else we risk keeping them + # alive while they'll be updated in a future step, creating hard to debug + # memory errors (see (internal link)). Also, takes device 0's params only. + params_cpu["student"], opt_cpu = jax.tree_map( + np.array, (params_repl["student"], opt_repl) + ) + + # Check whether we want to keep a copy of the current checkpoint. 
+ copy_step = None + if u.itstime(step, get_steps("keep_ckpt", None), total_steps): + copy_step = step + + ckpt = {"params": params_cpu["student"], + "opt": opt_cpu, + "chrono": u.chrono.save()} + ckpt_writer = pool.apply_async( + u.save_checkpoint, (ckpt, save_ckpt_path, copy_step)) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=params_repl) + u.chrono.tick(step) # Record things like epoch number, core hours etc. + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + for key, value in evaluator.run({"params": params_repl}): + mw.measure(f"{prefix}{key}", value) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/trainers/proj/flexi/common.py b/big_vision/trainers/proj/flexi/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd59883fb72b171a902a758f3aa578047dd0c3d --- /dev/null +++ b/big_vision/trainers/proj/flexi/common.py @@ -0,0 +1,47 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def mkrng(xid, wid, step):
  """Returns a numpy Generator seeded by `(xid, wid, step)`.

  Each of the three values is capped at 0 from below before seeding, since
  for example localruns use -1.
  """
  seed = tuple(max(value, 0) for value in (xid, wid, step))
  return np.random.default_rng(seed)


def mkprob(x):
  """Normalizes `x` by its sum; passes `None` through unchanged."""
  if x is not None:
    return np.array(x) / np.sum(x)
  return x


def choice(values, ratios, rng=None):
  """Draws one of `values` using `ratios` as (unnormalized) weights.

  Args:
    values: the candidates to draw from.
    ratios: one weight per value, normalized via `mkprob`; `None` is passed on
      to `rng.choice` as `p=None`.
    rng: numpy Generator to draw with; a fresh default one is made if missing.

  Returns:
    The drawn value, as returned by `rng.choice`.
  """
  if not rng:
    rng = np.random.default_rng()
  probs = mkprob(ratios)
  return rng.choice(values, p=probs)


def mkpredictfns(predict_fn, config, template="predict_{x}"):
  """Creates one predict function per combination of flexi argument values.

  If we have two flexi args a=[1,2], b=[10,20], then a predict function is
  created for all possible combinations, named "predict_a=1_b=10" etc.

  Args:
    predict_fn: function to which the flexi args are bound as keyword args.
    config: mapping from argument name to an object whose `.v` lists the
      values that argument can take.
    template: format string for the names; `{x}` is the "_"-joined `k=v` list.

  Returns:
    A dict mapping each formatted name to `functools.partial(predict_fn, **kw)`.
  """
  argnames = list(config)
  value_lists = [config[arg].v for arg in argnames]

  predict_fns = {}
  for values in itertools.product(*value_lists):
    kw = dict(zip(argnames, values))
    tag = "_".join(f"{k}={v}" for k, v in kw.items())
    predict_fns[template.format(x=tag)] = functools.partial(predict_fn, **kw)
  return predict_fns
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Distill a teacher model into a FlexiViT student. + +Note this file has code that is generic enough to allow using an ensemble +of teachers. This is inherited from `proj/distill/distill.py` and the goal +to only make minimal changes in a fork of that file. However, this feature +does not really make sense for FlexiViT. +""" +# pylint: disable=consider-using-from-import +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.evaluators.common as eval_common +import big_vision.evaluators.proj.distill.distance as dd +import big_vision.input_pipeline as input_pipeline +import big_vision.optax as bv_optax +import big_vision.trainers.proj.flexi.common as flexi +import big_vision.utils as u +from clu import parameter_overview +import flax +import jax +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + +# pylint: disable=logging-fstring-interpolation + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. 
def getfirst(d, *keys):
  """Returns the value of the first of `keys` that's present in mapping `d`.

  Args:
    d: mapping to look into; must support `in`, `[]` and `.keys()`.
    *keys: candidate keys, listed in decreasing order of priority.

  Returns:
    `d[k]` for the earliest `k` in `keys` for which `k in d` holds.

  Raises:
    KeyError: if none of `keys` is in `d` (also when no key is given at all).
  """
  # Walk the candidates in priority order and stop at the first hit, instead
  # of scanning all of them in reverse and remembering the last match.
  for k in keys:
    if k in d:
      return d[k]
  raise KeyError(f"None of {keys} is in {d.keys()}")
+ xid, wid = -1, -1 + fillin = lambda s: s + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + write_note("Initializing...") + + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + # First thing after above sanity checks, so we can log "start" ticks. + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + + write_note("Initializing train dataset...") + train_ds, ntrain_img = input_pipeline.training(config.input) + + # Start prefetching already. + n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size, + measure=mw.measure, write_note=write_note) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + # Create student and teacher models + def get_model_mod(name): # Used many times. 
+ mod_name = config[f"{name}_name"] + return importlib.import_module(f"big_vision.models.{mod_name}") + + write_note("Initializing models...") + def make_model(name): + return get_model_mod(name).Model( + num_classes=config.num_classes, **config.get(name, {})) + + models = { + "student": make_model("student"), + **{t: make_model(t) for t in config.teachers} + } + + # We want all parameters to be created in host RAM, not on any device, they'll + # be sent there later as needed, otherwise we already encountered two + # situations where we allocate them twice. + def get_init(model, name): + @functools.partial(jax.jit, backend="cpu") + def _init(rng): + bs = batch_size // jax.device_count() + img_size = tuple(getfirst(train_ds.element_spec, name, "image").shape[1:]) + no_image = jnp.zeros((bs,) + img_size, jnp.float32) + params = flax.core.unfreeze(model.init(rng, no_image))["params"] + return params + return _init + + rng, *rng_inits = jax.random.split(rng, len(models) + 1) + with u.chrono.log_timing("z/secs/init"): + params_cpu = { + name: get_init(models[name], name=name)(r) + for name, r in zip(models, rng_inits)} + + if jax.process_index() == 0: + for name, params in params_cpu.items(): + parameter_overview.log_parameter_overview(params, msg=f"{name} params") + mw.measure(f"num_params_{name}", + sum(p.size for p in jax.tree_leaves(params))) + + write_note(f"Initializing {config.optax_name} optimizer...") + # For now, we explicitly only optimize the student parameters as there's + # nothing else to be optimized. If we ever want to add learnable projections + # or similar for good (we explored but ditched), need to refactor this a bit. + tx, sched_fns = bv_optax.make( + config, params_cpu["student"], sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + + # We jit this, such that the arrays are created on the CPU, not device[0]. 
+ opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu["student"]) + sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns] + + @jax.named_call + def loss_fn(student_params, params, data, rngs, **flexi_kw): + # Note: need to extract and use `student_params` out of `params` because the + # first argument of `loss_fn` is what's differentiated wrt. + params["student"] = student_params + + def fwd(name, params): + return jax.named_call(models[name].apply, name=name)( + {"params": params}, getfirst(data, name, "image"), + train=name == "student", rngs=rngs.get(name), + **(flexi_kw if name == "student" else {}) + )[0] # logits, unused_outputs + logits = {name: fwd(name, w) for name, w in params.items()} + + measurements = {} + for name, lg in logits.items(): + measurements[f"entropy_{name}"] = -jnp.sum( + jax.nn.log_softmax(lg) * jax.nn.softmax(lg), axis=-1) + if "labels" in data: + measurements[f"task_loss_{name}"] = u.softmax_xent( + logits=lg, labels=data["labels"], reduction=False) + + # NOTE: xent is linear in labels, so for KL, this is actually the same as + # using a teacher-ensemble in probs-space! + measurements["distill_loss"] = 0.0 + for name in config.teachers: + l = dd.dist(logits["student"], logits[name], config.get("distance", "kl"), + **config.get("distance_kw", {})) + measurements[f"distill_loss_{name}"] = l + measurements["distill_loss"] += l + + outputs = (measurements["distill_loss"], measurements) + return jax.tree_map(jnp.mean, outputs) + + flexi_argnames = sorted(config.flexi) + + @functools.partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1), + static_broadcasted_argnums=tuple(range(4, 4 + len(flexi_argnames)))) + def update_fn(params, opt, rng, data, *args): + """Update step.""" + + # Mixup. Note: overwrites the `data` entries (that's intended). 
+ if config.get("mixup") and config.mixup.p: + to_mix = {name: data[name] + for name in ("image", "labels") + tuple(models) if name in data} + rng, _, to_mix = u.mixup(rng, **config.mixup, **to_mix) + data = {**data, **to_mix} + + # Get device-specific loss rng. + rng, *rng_models = jax.random.split(rng, len(models) + 1) + rngs_models_local = { + name: {"dropout": jax.random.fold_in(rngi, jax.lax.axis_index("batch"))} + for name, rngi in zip(models, rng_models) + } + + w = params["student"] # Need to explicitly pull out the optimized ones. + (l, measurements), grads = jax.lax.pmean( + jax.value_and_grad(loss_fn, has_aux=True)( + w, params, data, rngs=rngs_models_local, + **dict(zip(flexi_argnames, args))), + axis_name="batch") + updates, opt = tx.update(grads, opt, w) + w = optax.apply_updates(w, updates) + params["student"] = w + + # Take some logging measurements + gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs])) + ps = jax.tree_leaves(w) + measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps])) + us = jax.tree_leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us])) + + return params, opt, rng, l, measurements + + # We always load the teachers first, because they NEED to be initialized + # and since we don't ever modify them, we don't store them in checkpoints. + for name in config.teachers: + init_def = config[f"{name}_init"] + write_note(f"Initializing {name} from {init_def}…") + params_cpu[name] = get_model_mod(name).load( + params_cpu[name], init_def, config[name], + **config.get(f"{name}_load", {})) + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize student from something, e.g. start a fine-tuning job. + # 4. 
Train from scratch. + resume_ckpt_path = None + if save_ckpt_path and gfile.exists(save_ckpt_path): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + if resume_ckpt_path: + write_note("Resume training from checkpoint...") + # NOTE: we never change the teachers, so only checkpoint student here. + checkpoint = { + "params": params_cpu["student"], + "opt": opt_cpu, + "chrono": u.chrono.save(), + } + checkpoint_tree = jax.tree_structure(checkpoint) + loaded = u.load_checkpoint_np(resume_ckpt_path, checkpoint_tree) + # bfloat16 type gets lost when data is saved to disk, so we recover it. + checkpoint = jax.tree_map(u.recover_dtype, loaded) + params_cpu["student"], opt_cpu = checkpoint["params"], checkpoint["opt"] + u.chrono.load(checkpoint["chrono"]) + elif config.get("student_init"): + write_note(f"Initialize student from {config.student_init}...") + params_cpu["student"] = get_model_mod("student").load( + params_cpu["student"], config.student_init, config.get("student"), + **config.get("student_load", {})) + if jax.process_index() == 0: + parameter_overview.log_parameter_overview( + params_cpu["student"], msg="restored (student) params") + + write_note("Kicking off misc stuff...") + first_step = bv_optax.get_count(opt_cpu) + u.chrono.inform(first_step=first_step) + prof = None # Keeps track of start/stop of profiler state. + + write_note(f"Replicating...\n{u.chrono.note}") + params_repl = flax.jax_utils.replicate(params_cpu) + opt_repl = flax.jax_utils.replicate(opt_cpu) + + # Define predict functions that the evaluators can use: + def predict_fn(params, *, name, **kw): + image = kw.pop(name, kw.pop("image", None)) + # Ugly API compatibility necessity: + for k in ("student", *config.teachers): + kw.pop(k, 0) + return models[name].apply({"params": params[name]}, image, **kw) + + # 1. 
One for each variant of the student + student_pfns = flexi.mkpredictfns( + functools.partial(predict_fn, name="student"), config.flexi, "student_{x}" + ) + # 2. One per teacher model + teacher_pfns = { + name: functools.partial(predict_fn, name=name) + for name in config.teachers + } + # 3. One for each (student-variant, teacher) pair, eg for distance eval. + combined_pfns = { + f"{sn}_{tn}": lambda *a, sfn=sfn, tfn=tfn, **kw: (sfn(*a, **kw), tfn(*a, **kw)) # pylint: disable=line-too-long + for sn, sfn in student_pfns.items() + for tn, tfn in teacher_pfns.items() + } + + predict_fns = {**student_pfns, **teacher_pfns, **combined_pfns} + + @functools.cache + def evaluators(): + return eval_common.from_config( + config, predict_fns, + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + ) + + rng, rng_loop = jax.random.split(rng, 2) + rngs_loop = flax.jax_utils.replicate(rng_loop) + ckpt_writer = None + + write_note(f"First step compilations...\n{u.chrono.note}") + + # Note that training can be pre-empted during the final evaluation (i.e. + # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. + if first_step in (total_steps, 0): + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + for key, value in evaluator.run(params_repl): + mw.measure(f"{prefix}{key}", value) + + # Using a python integer for step here, because opt.state.step is allocated + # on TPU during replication. 
+ for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + np_rng = flexi.mkrng(xid, wid, step) + flexi_args = [ + flexi.choice(config.flexi[n].v, config.flexi[n].p, np_rng) + for n in flexi_argnames + ] + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn( + params_repl, opt_repl, rngs_loop, batch, *flexi_args) + + # On the first host, let's always profile a handful of early steps. + if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) + l = mw.measure("training_loss", loss_value[0]) + for name, value in measurements.items(): + mw.measure(name, value[0]) + u.chrono.tick(step) + if not np.isfinite(l): + raise RuntimeError(f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + if (save_ckpt_path and + (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or + u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))): + u.chrono.pause(wait_for=(params_repl["student"], opt_repl)) + u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1)) + # We need to transfer the weights over now or else we risk keeping them + # alive while they'll be updated in a future step, creating hard to debug + # memory errors (see (internal link)). Also, takes device 0's params only. 
+ params_cpu["student"], opt_cpu = jax.tree_map( + lambda x: np.array(x[0]), (params_repl["student"], opt_repl)) + + # Check whether we want to keep a copy of the current checkpoint. + copy_step = None + if u.itstime(step, get_steps("keep_ckpt", None), total_steps): + copy_step = step + + ckpt = {"params": params_cpu["student"], + "opt": opt_cpu, + "chrono": u.chrono.save()} + ckpt_writer = pool.apply_async( + u.save_checkpoint, (ckpt, save_ckpt_path, copy_step)) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=params_repl) + u.chrono.tick(step) # Record things like epoch number, core hours etc. + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + for key, value in evaluator.run(params_repl): + mw.measure(f"{prefix}{key}", value) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/trainers/proj/flexi/train.py b/big_vision/trainers/proj/flexi/train.py new file mode 100644 index 0000000000000000000000000000000000000000..5e76265271e3cd0db97851e93e3d6bf09a11670f --- /dev/null +++ b/big_vision/trainers/proj/flexi/train.py @@ -0,0 +1,365 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training loop with flexible/schedulable settings.""" +# pylint: disable=consider-using-from-import +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.evaluators.common as eval_common +import big_vision.input_pipeline as input_pipeline +import big_vision.optax as bv_optax +import big_vision.trainers.proj.flexi.common as flexi +import big_vision.utils as u +from clu import parameter_overview +import flax +import jax +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + +# pylint: disable=logging-fstring-interpolation + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() + + +def main(argv): + del argv + tf.config.experimental.set_visible_devices([], "GPU") + + config = flags.FLAGS.config + workdir = flags.FLAGS.workdir + logging.info( + f"\u001b[33mHello from process {jax.process_index()} holding " + f"{jax.local_device_count()}/{jax.device_count()} devices and " + f"writing to workdir {workdir}.\u001b[0m") + + save_ckpt_path = None + if workdir: # Always create if requested, even if we may not write into it. 
+ gfile.makedirs(workdir) + save_ckpt_path = os.path.join(workdir, "checkpoint.npz") + + # The pool is used to perform misc operations such as logging in async way. + pool = multiprocessing.pool.ThreadPool() + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]): + importlib.import_module(f"big_vision.pp.{m}") + + # This seed makes the Jax part of things (like model init) deterministic. + # However, full training still won't be deterministic, for example due to the + # tf.data pipeline not being deterministic even if we would set TF seed. + # See (internal link) for a fun read on what it takes. + rng = jax.random.PRNGKey(config.get("seed", 0)) + + # These functions do more stuff internally, for OSS release we mock them by + # trivial alternatives in order to minimize disruptions in the code. + xid, wid = -1, -1 + fillin = lambda s: s + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + write_note("Initializing...") + + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + # First thing after above sanity checks, so we can log "start" ticks. + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + + write_note("Initializing train dataset...") + train_ds, ntrain_img = input_pipeline.training(config.input) + + # Start prefetching already. 
+ n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size, + measure=mw.measure, write_note=write_note) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + write_note(f"Initializing {config.model_name} model...") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model = model_mod.Model( + num_classes=config.num_classes, **config.get("model", {})) + + # We want all parameters to be created in host RAM, not on any device, they'll + # be sent there later as needed, otherwise we already encountered two + # situations where we allocate them twice. + @functools.partial(jax.jit, backend="cpu") + def init(rng): + shape = tuple(train_ds.element_spec["image"].shape[1:]) + bs = batch_size // jax.device_count() + dummy_input = jnp.zeros((bs,) + shape, jnp.float32) + params = flax.core.unfreeze(model.init(rng, dummy_input))["params"] + + # Set bias in the head to a low value, such that loss is small initially. 
+ if "init_head_bias" in config: + params["head"]["bias"] = jnp.full_like(params["head"]["bias"], + config["init_head_bias"]) + + return params + + rng, rng_init = jax.random.split(rng) + with u.chrono.log_timing("z/secs/init"): + params_cpu = init(rng_init) + + if jax.process_index() == 0: + num_params = sum(p.size for p in jax.tree_leaves(params_cpu)) + parameter_overview.log_parameter_overview(params_cpu, msg="init params") + mw.measure("num_params", num_params) + + write_note(f"Initializing {config.optax_name} optimizer...") + tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + + # We jit this, such that the arrays are created on the CPU, not device[0]. + opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu) + sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns] + + flexi_argnames = sorted(config.flexi) + + @functools.partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1), + static_broadcasted_argnums=tuple(range(5, 5 + len(flexi_argnames)))) + def update_fn(params, opt, rng, images, labels, *args): + """Update step.""" + + measurements = {} + + if config.get("mixup") and config.mixup.p: + rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup) + + # Get device-specific loss rng. 
+ rng, rng_model = jax.random.split(rng, 2) + rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch")) + + def loss_fn(params, images, labels): + logits, _ = model.apply( + {"params": params}, images, + train=True, rngs={"dropout": rng_model_local}, + **dict(zip(flexi_argnames, args))) + return getattr(u, config.get("loss", "sigmoid_xent"))( + logits=logits, labels=labels) + + l, grads = jax.value_and_grad(loss_fn)(params, images, labels) + l, grads = jax.lax.pmean((l, grads), axis_name="batch") + updates, opt = tx.update(grads, opt, params) + params = optax.apply_updates(params, updates) + + gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs])) + ps = jax.tree_leaves(params) + measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps])) + us = jax.tree_leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us])) + + return params, opt, rng, l, measurements + + # We do not jit/pmap this function, because it is passed to evaluator that + # does it later. We output as many intermediate tensors as possible for + # maximal flexibility. Later `jit` will prune out things that are not needed. + def predict_fn(params, image, **flexi_kw): + logits, out = model.apply({"params": params}, image, **flexi_kw) + return logits, out + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize model from something, e,g, start a fine-tuning job. + # 4. Train from scratch. 
+ resume_ckpt_path = None + if save_ckpt_path and gfile.exists(save_ckpt_path): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + if resume_ckpt_path: + write_note("Resume training from checkpoint...") + checkpoint = { + "params": params_cpu, + "opt": opt_cpu, + "chrono": u.chrono.save(), + } + checkpoint_tree = jax.tree_structure(checkpoint) + loaded = u.load_checkpoint_np(resume_ckpt_path, checkpoint_tree) + # bfloat16 type gets lost when data is saved to disk, so we recover it. + checkpoint = jax.tree_map(u.recover_dtype, loaded) + params_cpu, opt_cpu = checkpoint["params"], checkpoint["opt"] + u.chrono.load(checkpoint["chrono"]) + elif config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + params_cpu = model_mod.load( + params_cpu, config.model_init, config.get("model"), + **config.get("model_load", {})) + if jax.process_index() == 0: + parameter_overview.log_parameter_overview( + params_cpu, msg="restored params") + + write_note("Kicking off misc stuff...") + first_step = bv_optax.get_count(opt_cpu) + u.chrono.inform(first_step=first_step) + prof = None # Keeps track of start/stop of profiler state. + + write_note(f"Replicating...\n{u.chrono.note}") + params_repl = flax.jax_utils.replicate(params_cpu) + opt_repl = flax.jax_utils.replicate(opt_cpu) + + @functools.cache + def evaluators(): + return eval_common.from_config( + config, flexi.mkpredictfns(predict_fn, config.flexi, "predict_{x}"), + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + ) + + rng, rng_loop = jax.random.split(rng, 2) + rngs_loop = flax.jax_utils.replicate(rng_loop) + ckpt_writer = None + + write_note(f"First step compilations...\n{u.chrono.note}") + + # Note that training can be pre-empted during the final evaluation (i.e. 
+ # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. + if first_step in (total_steps, 0): + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + for key, value in evaluator.run(params_repl): + mw.measure(f"{prefix}{key}", value) + + # Using a python integer for step here, because opt.state.step is allocated + # on TPU during replication. + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + np_rng = flexi.mkrng(xm_xp.id, xm_wu.id, step) + flexi_args = [ + flexi.choice(config.flexi[n].v, config.flexi[n].p, np_rng) + for n in flexi_argnames + ] + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn( + params_repl, opt_repl, rngs_loop, batch["image"], batch["labels"], + *flexi_args) + + # On the first host, let's always profile a handful of early steps. 
+ if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) + l = mw.measure("training_loss", loss_value[0]) + for name, value in measurements.items(): + mw.measure(name, value[0]) + u.chrono.tick(step) + if not np.isfinite(l): + raise RuntimeError(f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + if (save_ckpt_path and + (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or + u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))): + u.chrono.pause(wait_for=(params_repl, opt_repl)) + u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1)) + # We need to transfer the weights over now or else we risk keeping them + # alive while they'll be updated in a future step, creating hard to debug + # memory errors (see (internal link)). Also, takes device 0's params only. + params_cpu = jax.tree_map(lambda x: np.array(x[0]), params_repl) + opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl) + + # Check whether we want to keep a copy of the current checkpoint. + copy_step = None + if u.itstime(step, get_steps("keep_ckpt", None), total_steps): + copy_step = step + + ckpt = {"params": params_cpu, "opt": opt_cpu, "chrono": u.chrono.save()} + ckpt_writer = pool.apply_async( + u.save_checkpoint, (ckpt, save_ckpt_path, copy_step)) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=params_repl) + u.chrono.tick(step) # Record things like epoch number, core hours etc. 
+ write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + for key, value in evaluator.run(params_repl): + mw.measure(f"{prefix}{key}", value) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/trainers/proj/givt/generative.py b/big_vision/trainers/proj/givt/generative.py new file mode 100644 index 0000000000000000000000000000000000000000..e571b759161d05c129bcb48719f27a2c57fddd1e --- /dev/null +++ b/big_vision/trainers/proj/givt/generative.py @@ -0,0 +1,719 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Training loop for GIVT-style autoregressive and masked models.""" + +# pylint: disable=consider-using-from-import +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.evaluators.common as eval_common +import big_vision.input_pipeline as input_pipeline +from big_vision.models.proj.givt import parallel_decode +import big_vision.models.proj.givt.decode as softar_decode +import big_vision.optax as bv_optax +import big_vision.sharding as bv_sharding +import big_vision.trainers.proj.givt.utils as trainer_utils +from big_vision.trainers.proj.uvim import panoptic_task +import big_vision.utils as u +from clu import parameter_overview +import flax +import jax +from jax.experimental import mesh_utils +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import serialization as array_serial +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + +# pylint: disable=logging-fstring-interpolation + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() +# Transfer guard will fail the program whenever that data between a host and +# a device is transferred implicitly. This often catches subtle bugs that +# cause slowdowns and memory fragmentation. Explicit transfers are done +# with jax.device_put and jax.device_get. +jax.config.update("jax_transfer_guard", "disallow") +# Fixes design flaw in jax.random that may cause unnecessary d2d comms. 
+jax.config.update("jax_threefry_partitionable", True) + + +NamedSharding = jax.sharding.NamedSharding +P = jax.sharding.PartitionSpec + + +def main(argv): + del argv + + jax.distributed.initialize() + + # Make sure TF does not touch GPUs. + tf.config.set_visible_devices([], "GPU") + + config = flags.FLAGS.config + +################################################################################ +# # +# Set up logging # +# # +################################################################################ + + # Set up work directory and print welcome message. + workdir = flags.FLAGS.workdir + logging.info( + f"\u001b[33mHello from process {jax.process_index()} holding " + f"{jax.local_device_count()}/{jax.device_count()} devices and " + f"writing to workdir {workdir}.\u001b[0m") + + save_ckpt_path = None + if workdir: # Always create if requested, even if we may not write into it. + gfile.makedirs(workdir) + save_ckpt_path = os.path.join(workdir, "checkpoint.bv") + + # The pool is used to perform misc operations such as logging in async way. + pool = multiprocessing.pool.ThreadPool() + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text", + "proj.uvim.pp_ops", "proj.givt.pp_ops"]): + importlib.import_module(f"big_vision.pp.{m}") + + # Setup up logging and experiment manager. + xid, wid = -1, -1 + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + + # Allow for things like timings as early as possible! 
+ u.chrono.inform(measure=mw.measure, write_note=write_note) + +################################################################################ +# # +# Set up Mesh # +# # +################################################################################ + + # We rely on jax mesh_utils to organize devices, such that communication + # speed is the fastest for the last dimension, second fastest for the + # penultimate dimension, etc. + config_mesh = config.get("mesh", [("data", jax.device_count())]) + + # Sharding rules with default + sharding_rules = config.get("sharding_rules", [("act_batch", "data")]) + + mesh_axes, mesh_size = tuple(zip(*config_mesh)) + + # Because jax.utils do not support `-1` shape size. + mesh_size = np.array(jax.devices()).reshape(mesh_size).shape + + device_mesh = mesh_utils.create_device_mesh(mesh_size) + + # Consistent device order is important to ensure correctness of various train + # loop components, such as input pipeline, update step, evaluators. The + # order presribed by the `devices_flat` variable should be used throughout + # the program. + devices_flat = device_mesh.flatten() + +################################################################################ +# # +# Input Pipeline # +# # +################################################################################ + + write_note("Initializing train dataset...") + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. 
With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + train_ds, ntrain_img = input_pipeline.training(config.input) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + # Start input pipeline as early as possible. + n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_global(train_ds, devices_flat, n_prefetch) + +################################################################################ +# # +# Create Model & Optimizer # +# # +################################################################################ + + write_note(f"Creating {config.vae.model_name} model...") + vae_mod = importlib.import_module( + f"big_vision.models.{config.vae.model_name}") + vae = vae_mod.Model(**config.vae.get("model", {})) + + write_note(f"Creating {config.model_name} model...") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model_config = config.get("model", {}) + model = model_mod.Model(**model_config) + + if config.get("adaptor_name"): + write_note(f"Creating {config.adaptor_name} model...") + adaptor_mod = importlib.import_module( + f"big_vision.models.{config.adaptor_name}") + adaptor = adaptor_mod.Model(num_channels=model_config.out_dim, + **config.adaptor.model) + else: + adaptor = None + + def init(rng): + def _get_dummy_input(input_name, dtype=jnp.int64): + if input_name in train_ds.element_spec: + return jnp.zeros(train_ds.element_spec[input_name].shape, dtype=dtype) 
+ return None + + dummy_img = _get_dummy_input("image", dtype=jnp.float32) + dummy_labels = _get_dummy_input("labels") + dummy_cond_img = _get_dummy_input("cond_image", dtype=jnp.float32) + local_batch_size = dummy_img.shape[0] # pytype: disable=attribute-error + + code_shape = ( + local_batch_size, model_config.seq_len, model_config.out_dim) + dummy_code = jnp.zeros(code_shape, jnp.float32) + + input_mask = model.get_input_mask_training( + jax.random.PRNGKey(0), (local_batch_size, model_config.seq_len) + ) + params = model.init(rng, dummy_code, dummy_labels, image=dummy_cond_img, + input_mask=input_mask)["params"] + + if adaptor is not None: + _, rng_adaptor = jax.random.split(rng) + adaptor_variables = adaptor.init(rng_adaptor, dummy_code) + params_adaptor = flax.core.unfreeze(adaptor_variables["params"]) + params["params_adaptor"] = params_adaptor # store in same dict + + return params + + rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0))) + + write_note("Inferring parameter shapes...") + rng, rng_init = jax.random.split(rng) + params_shape = jax.eval_shape(init, rng_init) + + write_note("Inferring optimizer state shapes...") + tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + opt_shape = jax.eval_shape(tx.init, params_shape) + # We jit this, such that the arrays are created on the CPU, not device[0]. + sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns] + + # Training a stage 2 model requires a pretrained stage 1 model. We treat this + # as a constant and do not shard the parameters. 
+ assert "model_init" in config.vae + params_vae = vae_mod.load(None, config.vae.model_init, + **config.vae.get("model_load", {})) + + def vae_encode(images, rng=None, reparametrize=True): + mu, logvar = vae.apply({"params": params_vae}, images, method=vae.encode) + if reparametrize: + assert rng is not None and "dropout" in rng + return vae.apply({"params": params_vae}, mu, logvar, + method=vae.reparametrize, rngs=rng) + return mu + + if jax.process_index() == 0: + num_params = sum(np.prod(p.shape) for p in jax.tree_leaves(params_shape)) + mw.measure("num_params", num_params) + +################################################################################ +# # +# Shard & Transfer # +# # +################################################################################ + + write_note("Creating device mesh...") + mesh = jax.sharding.Mesh(device_mesh, mesh_axes) + repl_sharding = jax.sharding.NamedSharding(mesh, P()) + + write_note("Inferring shardings...") + train_state_shape = {"params": params_shape, "opt": opt_shape} + + strategy = config.get("sharding_strategy", [(".*", "replicate")]) + train_state_sharding = bv_sharding.infer_sharding( + train_state_shape, strategy=strategy, mesh=mesh) + + write_note("Transferring train_state to devices...") + # RNG is always replicated + rng_init = u.reshard(rng_init, repl_sharding) + + # Parameters and the optimizer are now global (distributed) jax arrays. + params = jax.jit(init, out_shardings=train_state_sharding["params"])(rng_init) + opt = jax.jit(tx.init, out_shardings=train_state_sharding["opt"])(params) + + rng, rng_loop = jax.random.split(rng, 2) + rng_loop = u.reshard(rng_loop, repl_sharding) + del rng # not used anymore, so delete it. + + # At this point we have everything we need to form a train state. It contains + # all the parameters that are passed and updated by the main training step. + train_state = {"params": params, "opt": opt} + del params, opt # Delete to avoid memory leak or accidental reuse. 
+ + write_note("Logging parameter overview...") + parameter_overview.log_parameter_overview( + train_state["params"], msg="Init params", + include_stats="global", jax_logging_process=0) + +################################################################################ +# # +# Update Step # +# # +################################################################################ + + # Define the loss function + def loss_fn(params, images, labels, cond_images, rng): + rng, rng_dropout = jax.random.split(rng, 2) + rng, rng_mask = jax.random.split(rng, 2) + _, rng_droplabels = jax.random.split(rng, 2) + + rng_dropout = {"dropout": rng_dropout} + + sequence = vae_encode(images, rng_dropout) + if adaptor is not None: + # Use the (invertible) adaptor to map to a new latent sequence + sequence = adaptor.apply({"params": params["params_adaptor"]}, + sequence, method=adaptor.forward) + + b, s, _ = sequence.shape + # This is None for the non-mask style. Otherwise, shape (b, s). + input_mask = model.get_input_mask_training(rng_mask, (b, s)) + drop_labels = model.get_drop_labels(rng_droplabels, batch_size=b) + + _, pdf = model.apply( + {"params": params}, sequence, labels, + image=cond_images, + train=True, + input_mask=input_mask, + drop_labels=drop_labels, + rngs=rng_dropout) + + # Shape: (B, L, out_dim) + nll = -pdf.log_prob(sequence) + metrics = {"nll": nll} + if input_mask is not None: + metrics["fraction_masked_out"] = input_mask.astype(jnp.float32).mean( + axis=1 + ) + if nll.ndim == 3: + input_mask = input_mask[:, :, None] + # Note that `input_mask` is True where we mask out the input (ie replace + # with mask token), so we also only gather nlls at the corresponding + # points. + nll = jnp.where(input_mask, nll, 0.0) + # Take mean only of the spots we care about to smooth loss magnitute + # between examples, like in maskgit (ie this is + # sum(loss * input_mask) / sum(input_mask) in their code. 
+ loss = nll.mean(where=input_mask) + else: + loss = nll.mean() + + return loss, metrics + + @functools.partial( + jax.jit, + donate_argnums=(0,), + out_shardings=(train_state_sharding, repl_sharding)) + def update_fn(train_state, rng, batch): + """Update step.""" + + images = batch["image"] + labels, cond_images = batch.get("labels"), batch.get("cond_image") + + step_count = bv_optax.get_count(train_state["opt"], jittable=True) + rng = jax.random.fold_in(rng, step_count) + + measurements = {} + + # Get device-specific loss rng. + _, rng_model = jax.random.split(rng, 2) + params, opt = train_state["params"], train_state["opt"] + + (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)( + params, images, labels, cond_images, rng_model) + updates, opt = tx.update(grads, opt, params) + params = optax.apply_updates(params, updates) + train_state = {"params": params, "opt": opt} + + measurements["training_loss"] = loss + gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs])) + ps = jax.tree_leaves(params) + measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps])) + us = jax.tree_leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us])) + + if adaptor is not None: + ps_a = jax.tree_leaves(params["params_adaptor"]) + measurements["l2_params_adaptor"] = jnp.sqrt(sum([jnp.vdot(p, p) + for p in ps_a])) + + measurements.update({f"train/{k}": v.mean() for k, v in metrics.items()}) + + return train_state, measurements + +################################################################################ +# # +# Set up Evals # +# # +################################################################################ + + def validation_fn(train_state, batch, seed=0): + params = train_state["params"] + + local_rng = trainer_utils.get_local_rng(seed, batch) + + _, aux = loss_fn( + params, batch["image"], batch.get("labels"), + 
batch.get("cond_image"), local_rng) + return { + key: jnp.mean(value, axis=tuple(range(1, value.ndim))) + for key, value in aux.items() + } + + def predict_fn_teacher_forcing(train_state, batch, seed=0): + params = train_state["params"] + image, labels = batch["image"], batch.get("labels") + + local_rng = trainer_utils.get_local_rng(seed, batch) + + rng_dropout = {"dropout": local_rng} + sequence = vae_encode(image, rng_dropout) + + if adaptor is not None: + # Use the adaptor to map from VAE latent space to GIVT in/output space. + sequence = adaptor.apply({"params": params["params_adaptor"]}, + sequence, method=adaptor.forward) + + b, s, _ = sequence.shape + # This is None for the non-mask style. Otherwise, shape (b, s) of zeros + # (nothing masked). + input_mask = model.get_input_mask_teacher_forced((b, s)) + + _, pdf = model.apply( + {"params": params}, sequence, labels, + train=True, input_mask=input_mask, rngs=rng_dropout) + + rng_sample, _ = jax.random.split(local_rng, 2) + sampled_sequence = pdf.sample(seed=rng_sample) + + if adaptor is not None: + # Use the adaptor inverse to map back to the VAE latent space + sampled_sequence = adaptor.apply({"params": params["params_adaptor"]}, + sampled_sequence, method=adaptor.inverse) + logits = vae.apply( + {"params": params_vae}, sampled_sequence, method=vae.decode) + + return {"logits": logits} + + def predict_fn_rep(train_state, image, seed=0): + assert model.style == "ar" + assert model.drop_labels_probability == 1.0 + params = train_state["params"] + + local_rng = trainer_utils.get_local_rng(seed, batch) + + rng_dropout = {"dropout": local_rng} + sequence = vae_encode(image, rng_dropout) + placeholder_labels = jnp.zeros((sequence.shape[0],), dtype=jnp.int32) + + return model.apply({"params": params}, sequence, labels=placeholder_labels, + return_reps=True, method=model.decode) + + def predict_fn_sampling(train_state, batch, seed=0): + params = train_state["params"] + labels = batch.get("labels") + + local_rng = 
trainer_utils.get_local_rng(seed, batch) + code_logprobs = None + + if model.style == "ar": + if labels is None: + # Try to infer batch size if labels are not provided + if "image" in batch: + sampling_batch_size = batch["image"].shape[0] + elif "cond_image" in batch: + sampling_batch_size = batch["cond_image"].shape[0] + else: + sampling_batch_size = config.get("sampling_batch_size", 4) + else: + sampling_batch_size = None + sampled_codes, code_logprobs = softar_decode.generate( + params={"params": params}, + seed=local_rng, + model=model, + seq_len=config.model.seq_len, + feature_dim=config.model.out_dim, + labels=labels, + cond_image=batch.get("cond_image"), + batch_size=sampling_batch_size, + config=config.get("ar_generation_config"), + ) + elif model.style == "masked": + assert "cond_image" not in batch + sampled_codes = parallel_decode.decode_masked( # pytype: disable=wrong-arg-types + rng=local_rng, + labels=labels, + seq_len=config.model.seq_len, + feature_dim=config.model.out_dim, + model=model, + variables={"params": params}, + config=parallel_decode.MaskedGenerationConfig( + **config.get("masked_generation_config", {}) + ), + ).current_inputs_q + else: + raise NotImplementedError + + if adaptor is not None: + # Use the adaptor inverse to map back to the VAE latent space. 
+ sampled_codes = adaptor.apply({"params": params["params_adaptor"]}, + sampled_codes, method=adaptor.inverse) + + sampled_images = vae.apply( + {"params": params_vae}, sampled_codes, method=vae.decode) + + sampling_results = {"logits": sampled_images} + if code_logprobs is not None: + sampling_results["logprobs"] = code_logprobs + + return sampling_results + + def predict_fn_sampling_panoptic( + train_state, batch, seed=0, min_fraction=0.0): + logits = predict_fn_sampling(train_state, batch, seed)["logits"] + return panoptic_task.panoptic_predictions_from_logits( + logits["semantics"], logits["instances"], min_fraction=min_fraction) + + def predict_fn_sampling_depth(train_state, batch, seed=0): + depth = predict_fn_sampling(train_state, batch, seed)["logits"]["depth"] + depth = trainer_utils.unbin_depth( + depth, min_depth=config.min_depth, max_depth=config.max_depth, + num_bins=config.vae.model.inout_specs["depth"][1]) + return {"depth": depth} + + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, + { + "validation": validation_fn, + "sample_teacher_forced": predict_fn_teacher_forcing, + "sample": predict_fn_sampling, + "sample_panoptic": predict_fn_sampling_panoptic, + "sample_depth": predict_fn_sampling_depth, + "representation": predict_fn_rep, + }, + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + devices_flat, + ) + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize model from something, e,g, start a fine-tuning job. + # 4. Train from scratch. 
+ resume_ckpt_path = None + if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + + ckpt_mngr = None + if save_ckpt_path or resume_ckpt_path: + ckpt_mngr = array_serial.GlobalAsyncCheckpointManager() + + if resume_ckpt_path: + write_note(f"Resuming training from checkpoint {resume_ckpt_path}...") + jax.tree_map(lambda x: x.delete(), train_state) + del train_state + shardings = { + **train_state_sharding, + "chrono": jax.tree_map(lambda _: repl_sharding, + u.chrono.save()), + } + loaded = u.load_checkpoint_ts( + resume_ckpt_path, tree=shardings, shardings=shardings) + train_state = {key: loaded[key] for key in train_state_sharding.keys()} + + u.chrono.load(jax.device_get(loaded["chrono"])) + del loaded + elif config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + train_state["params"] = model_mod.load( + train_state["params"], config.model_init, config.get("model"), + **config.get("model_load", {})) + + # load has the freedom to return params not correctly sharded + train_state["params"] = u.reshard( + train_state["params"], train_state_sharding["params"]) + + parameter_overview.log_parameter_overview( + train_state["params"], msg="restored params", + include_stats="global", jax_logging_process=0) + + # At this point we need to know the current step to see whether to run evals. + write_note("Inferring the first step number...") + first_step_device = bv_optax.get_count(train_state["opt"], jittable=True) + first_step = int(jax.device_get(first_step_device)) + u.chrono.inform(first_step=first_step) + + # Note that training can be pre-empted during the final evaluation (i.e. + # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. 
+ if first_step in (total_steps, 0): + write_note("Running initial or final evals...") + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, flax.linen.logical_axis_rules(sharding_rules): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", value) + +################################################################################ +# # +# Train Loop # +# # +################################################################################ + + prof = None # Keeps track of start/stop of profiler state. + + write_note("Starting training loop, compiling the first step...") + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + # Skip training loop when running an eval-only config + if config.get("eval_only", False): + break + mw.step_start(step) + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + with mesh, flax.linen.logical_axis_rules(sharding_rules): + train_state, measurements = update_fn(train_state, rng_loop, batch) + + # On the first host, let's always profile a handful of early steps. 
+ if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", + sched_fn_cpu(u.put_cpu(step - 1))) + measurements = jax.device_get(measurements) + for name, value in measurements.items(): + mw.measure(name, value) + u.chrono.tick(step) + if not np.isfinite(measurements["training_loss"]): + raise RuntimeError(f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + keep_ckpt_steps = get_steps("keep_ckpt", None) or total_steps + if save_ckpt_path and ( + (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False)) + or u.itstime(step, get_steps("ckpt", None), total_steps, first=True) + ): + u.chrono.pause(wait_for=train_state) + + # Copy because we add extra stuff to the checkpoint. + ckpt = {**train_state} + + # To save chrono state correctly and safely in a multihost setup, we + # broadcast the state to all hosts and convert it to a global array. + with jax.transfer_guard("allow"): + chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save()) + chrono_shardings = jax.tree_map(lambda _: repl_sharding, chrono_ckpt) + ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)} + + u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=train_state) + u.chrono.tick(step) # Record things like epoch number, core hours etc. 
+ write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, flax.linen.logical_axis_rules(sharding_rules): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", jax.device_get(value)) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + + if ckpt_mngr: + ckpt_mngr.wait_until_finished() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/trainers/proj/givt/utils.py b/big_vision/trainers/proj/givt/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..449b43f37fb8f60fb1e41090ed91c70c92719f02 --- /dev/null +++ b/big_vision/trainers/proj/givt/utils.py @@ -0,0 +1,70 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utils for GIVT stage I and II trainers.""" + +from typing import Any + +import jax +import jax.numpy as jnp + + +def unbin_depth( + depth: jax.Array, + *, + min_depth: float, + max_depth: float, + num_bins: int, +) -> jax.Array: + """Transform a depth map with binned values into a float-valued depth map. + + Args: + depth: Depth map whose binned values are encoded in one-hot fashion along + the last dimension. + min_depth: Minimum binned depth value. + max_depth: Maximum value of binned depth. + num_bins: Number of depth bins. + + Returns: + Float-valued depth map. + """ + depth = jnp.argmax(depth, axis=-1) + depth = depth.astype(jnp.float32) + 0.5 # Undoes floor in expectation. + depth /= num_bins + return depth * (max_depth - min_depth) + min_depth + + +def get_local_rng( + seed: int | jax.Array, + batch: Any, +) -> jax.Array: + """Generate a per-image seed based on the image id or the image values. + + Args: + seed: Random seed from which per-image seeds should be derived. + batch: Pytree containing a batch of images (key "image") and optionally + image ids (key "image/id"). + + Returns: + Array containing per-image ids. + """ + fake_id = None + if "image" in batch: + fake_id = (10**6 * jax.vmap(jnp.mean)(batch["image"])).astype(jnp.int32) + return jax.lax.scan( + lambda k, x: (jax.random.fold_in(k, x), None), + jax.random.PRNGKey(seed), + batch.get("image/id", fake_id), + )[0] + diff --git a/big_vision/trainers/proj/givt/vae.py b/big_vision/trainers/proj/givt/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..242899e5ea00dd133fad1269382bb9309ba5dfca --- /dev/null +++ b/big_vision/trainers/proj/givt/vae.py @@ -0,0 +1,569 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Train loop for training a VAE or beta-VAE with a Gaussian encoder.""" +# pylint: disable=consider-using-from-import +import functools +import importlib +import math +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +from big_vision import input_pipeline +import big_vision.evaluators.common as eval_common +import big_vision.optax as bv_optax +import big_vision.sharding as bv_sharding +import big_vision.trainers.proj.givt.utils as trainer_utils +from big_vision.trainers.proj.uvim import panoptic_task +import big_vision.utils as u +from clu import parameter_overview +import flax.linen as nn +import jax +from jax.experimental import mesh_utils +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import serialization as array_serial +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + +# pylint: disable=logging-fstring-interpolation + +partial = functools.partial + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() +# Transfer guard will fail the program whenever that data between a host and +# a device is transferred implicitly. 
def main(argv):
  """Run VAE training as described by the `config` and `workdir` flags.

  In order: sets up logging, the device mesh and the input pipeline, creates
  the model and optimizer, shards and initializes the train state, optionally
  restores it from a checkpoint (or `config.model_init`), builds the
  evaluators, and then runs the training loop with periodic logging,
  checkpointing and evaluation.

  Args:
    argv: Unused command-line arguments.
  """
  del argv

  jax.distributed.initialize()

  # Make sure TF does not touch GPUs.
  tf.config.set_visible_devices([], "GPU")

  config = flags.FLAGS.config

################################################################################
#                                                                              #
#                                Set up logging                                #
#                                                                              #
################################################################################

  # Set up work directory and print welcome message.
  workdir = flags.FLAGS.workdir
  logging.info(
      f"\u001b[33mHello from process {jax.process_index()} holding "
      f"{jax.local_device_count()}/{jax.device_count()} devices and "
      f"writing to workdir {workdir}.\u001b[0m")

  save_ckpt_path = None
  if workdir:  # Always create if requested, even if we may not write into it.
    gfile.makedirs(workdir)
    save_ckpt_path = os.path.join(workdir, "checkpoint.bv")

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # Here we register preprocessing ops from modules listed on `pp_modules`.
  for m in config.get("pp_modules",
                      ["ops_general", "ops_image", "proj.uvim.pp_ops",
                       "proj.givt.pp_ops"]):
    importlib.import_module(f"big_vision.pp.{m}")

  # Set up logging and experiment manager.
  xid, wid = -1, -1
  fillin = lambda s: s
  def info(s, *a):
    """Log a printf-style message on this process, prefixed with a NOTE tag."""
    logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)
  def write_note(note):
    """Log `note` via `info`, but only on process 0."""
    if jax.process_index() == 0:
      info("%s", note)

  mw = u.BigVisionMetricWriter(xid, wid, workdir, config)

  # Allow for things like timings as early as possible!
  u.chrono.inform(measure=mw.measure, write_note=write_note)

################################################################################
#                                                                              #
#                                Set up Mesh                                   #
#                                                                              #
################################################################################

  # We rely on jax mesh_utils to organize devices, such that communication
  # speed is the fastest for the last dimension, second fastest for the
  # penultimate dimension, etc.
  config_mesh = config.get("mesh", [("data", jax.device_count())])

  # Sharding rules with default
  sharding_rules = config.get("sharding_rules", [("act_batch", "data")])

  mesh_axes, mesh_size = tuple(zip(*config_mesh))

  # Because jax.utils do not support `-1` shape size.
  mesh_size = np.array(jax.devices()).reshape(mesh_size).shape

  device_mesh = mesh_utils.create_device_mesh(mesh_size)

  # Consistent device order is important to ensure correctness of various train
  # loop components, such as input pipeline, update step, evaluators. The
  # order prescribed by the `devices_flat` variable should be used throughout
  # the program.
  devices_flat = device_mesh.flatten()

################################################################################
#                                                                              #
#                                Input Pipeline                                #
#                                                                              #
################################################################################

  write_note("Initializing train dataset...")
  batch_size = config.input.batch_size
  if batch_size % jax.device_count() != 0:
    raise ValueError(f"Batch size ({batch_size}) must "
                     f"be divisible by device number ({jax.device_count()})")
  info("Global batch size %d on %d hosts results in %d local batch size. With "
       "%d dev per host (%d dev total), that's a %d per-device batch size.",
       batch_size, jax.process_count(), batch_size // jax.process_count(),
       jax.local_device_count(), jax.device_count(),
       batch_size // jax.device_count())

  train_ds, ntrain_img = input_pipeline.training(config.input)

  total_steps = u.steps("total", config, ntrain_img, batch_size)
  def get_steps(name, default=ValueError, cfg=config):
    """Look up the step count `name` in `cfg` via `u.steps` for this run."""
    return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default)

  u.chrono.inform(total_steps=total_steps, global_bs=batch_size,
                  steps_per_epoch=ntrain_img / batch_size)

  info("Running for %d steps, that means %f epochs",
       total_steps, total_steps * batch_size / ntrain_img)

  # Start input pipeline as early as possible.
  n_prefetch = config.get("prefetch_to_device", 1)
  train_iter = input_pipeline.start_global(train_ds, devices_flat, n_prefetch)

################################################################################
#                                                                              #
#                           Create Model & Optimizer                           #
#                                                                              #
################################################################################

  write_note("Creating model...")
  model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
  model = model_mod.Model(**config.get("model", {}))

  def init(rng):
    """Initialize model params from an all-zeros batch shaped like the data."""
    batch = jax.tree_map(
        lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype),
        train_ds.element_spec)
    params = model.init(rng, batch["image"])["params"]

    return params

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we would set TF seed.
  rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0)))

  write_note("Inferring parameter shapes...")
  rng, rng_init = jax.random.split(rng)
  params_shape = jax.eval_shape(init, rng_init)

  write_note("Inferring optimizer state shapes...")
  tx, sched_fns = bv_optax.make(config, nn.unbox(params_shape), sched_kw=dict(
      total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))
  opt_shape = jax.eval_shape(tx.init, params_shape)
  # We jit this, such that the arrays are created on the CPU, not device[0].
  sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns]

  if jax.process_index() == 0:
    num_params = sum(np.prod(p.shape) for p in jax.tree_leaves(params_shape))
    mw.measure("num_params", num_params)

################################################################################
#                                                                              #
#                               Shard & Transfer                               #
#                                                                              #
################################################################################

  write_note("Creating device mesh...")
  mesh = jax.sharding.Mesh(device_mesh, mesh_axes)
  repl_sharding = jax.sharding.NamedSharding(mesh, P())

  write_note("Inferring shardings...")
  train_state_shape = {"params": params_shape, "opt": opt_shape}

  strategy = config.get("sharding_strategy", [(".*", "replicate")])
  with nn.logical_axis_rules(sharding_rules):
    train_state_sharding = bv_sharding.infer_sharding(
        train_state_shape, strategy=strategy, mesh=mesh)

  write_note("Transferring train_state to devices...")
  # RNG is always replicated
  rng_init = u.reshard(rng_init, repl_sharding)

  # Parameters and the optimizer are now global (distributed) jax arrays.
  params = jax.jit(init, out_shardings=train_state_sharding["params"])(rng_init)
  opt = jax.jit(tx.init, out_shardings=train_state_sharding["opt"])(params)

  rng, rng_loop = jax.random.split(rng, 2)
  rng_loop = u.reshard(rng_loop, repl_sharding)
  del rng  # not used anymore, so delete it.

  # At this point we have everything we need to form a train state. It contains
  # all the parameters that are passed and updated by the main training step.
  # From here on, we have no need for Flax AxisMetadata (such as partitioning).
  train_state = nn.unbox({"params": params, "opt": opt})
  del params, opt  # Delete to avoid memory leak or accidental reuse.

  write_note("Logging parameter overview...")
  parameter_overview.log_parameter_overview(
      train_state["params"], msg="Init params",
      include_stats="global", jax_logging_process=0)

  # Computing ELBO or beta-VAE loss for Gaussian encoder.
  def vae_loss_fn(logits, image, mu, logvar, beta=1.0,
                  keep_batch_dim=False):
    """Compute `loss_rec + beta * loss_kl` and return it with its parts.

    The reconstruction term is selected by `config.rec_loss_fn` ("l2" by
    default, or "xent", summed over the entries of `config.model.inout_specs`).

    Args:
      logits: Model output; indexed by spec key when `rec_loss_fn == "xent"`.
      image: Reconstruction target.
      mu: Encoder mean output.
      logvar: Encoder log-variance output.
      beta: Weight of the KL term.
      keep_batch_dim: If False, both loss terms are averaged with `jnp.mean`
        before being combined; if True they are combined without averaging.

    Returns:
      A tuple `(loss, aux)` where `aux` has keys "loss", "loss_rec", "loss_kl".

    Raises:
      ValueError: If `config.rec_loss_fn` is neither "l2" nor "xent".
    """
    rec_loss_fn = config.get("rec_loss_fn", "l2")
    if rec_loss_fn == "l2":
      loss_rec = 0.5 * jnp.sum(
          jnp.square(logits - image), axis=tuple(range(1, logits.ndim)))
    elif rec_loss_fn == "xent":
      loss_rec = 0.0
      for k, (in_ch, _) in config.model.inout_specs.items():
        cur_logits = logits[k]
        b, _, _, c = cur_logits.shape
        # This xent function expect a 1-D sequence of logits
        image_flat = image[..., in_ch].reshape((b, -1))
        if config.get("mask_zero_target", False):
          # Targets equal to zero get weight 0 in the cross-entropy.
          weights = (image_flat != 0).astype(jnp.int32)
        else:
          weights = None
        loss_rec += u.weighted_softmax_xent(
            logits=cur_logits.reshape((b, -1, c)),
            labels=image_flat,
            reduction=False,
            weights=weights)
    else:
      raise ValueError(f"Unknown reconstruction loss function: {rec_loss_fn}")

    # KL term, summed over all non-batch axes of `mu`.
    loss_kl = - 0.5 * jnp.sum(
        1 + logvar - jnp.square(mu) - jnp.exp(logvar),
        axis=tuple(range(1, mu.ndim)))

    if not keep_batch_dim:
      loss_rec, loss_kl = jnp.mean(loss_rec), jnp.mean(loss_kl)
    loss = loss_rec + beta * loss_kl
    return loss, {"loss": loss, "loss_rec": loss_rec, "loss_kl": loss_kl}

################################################################################
#                                                                              #
#                                 Update Step                                  #
#                                                                              #
################################################################################

  @functools.partial(
      jax.jit,
      donate_argnums=(0,),
      out_shardings=(train_state_sharding, repl_sharding))
  def update_fn(train_state, rng, batch):
    """Update step.

    Args:
      train_state: Dict with "params" and "opt"; donated to the jitted call.
      rng: Base PRNG key; the optimizer step count is folded into it.
      batch: Batch dict; only `batch["image"]` is read here.

    Returns:
      A tuple `(new_train_state, measurements)`; `measurements` holds the loss
      terms, "training_loss" and the l2 norms of grads, params and updates.
    """
    step_count = bv_optax.get_count(train_state["opt"], jittable=True)
    rng = jax.random.fold_in(rng, step_count)

    # Get device-specific loss rng.
    _, rng_model = jax.random.split(rng, 2)

    def loss_fn(params):
      """Run the model in train mode on the batch and compute the VAE loss."""
      logits, out = model.apply(
          {"params": params},
          batch["image"],
          train=True,
          rngs={"dropout": rng_model})
      mu = out["mu"]
      logvar = out["logvar"]

      loss, aux_loss = vae_loss_fn(
          logits, batch["image"], mu, logvar, config.get("beta", 1.0),
      )
      return loss, aux_loss

    params, opt = train_state["params"], train_state["opt"]
    (loss, measurements), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        params)
    updates, opt = tx.update(grads, opt, params)
    params = optax.apply_updates(params, updates)

    measurements["training_loss"] = loss
    # Grad norm is computed after `replace_frozen` substitutes 0. into grads.
    gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.))
    measurements["l2_grads"] = jnp.sqrt(sum([jnp.sum(g * g) for g in gs]))
    ps = jax.tree_leaves(params)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.sum(p * p) for p in ps]))
    us = jax.tree_leaves(updates)
    # NOTE: `u` below is the comprehension variable, not the `u` utils module.
    measurements["l2_updates"] = jnp.sqrt(sum([jnp.sum(u * u) for u in us]))

    return {"params": params, "opt": opt}, measurements

################################################################################
#                                                                              #
#                               Load Checkpoint                                #
#                                                                              #
################################################################################

  # Decide how to initialize training. The order is important.
  # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize model from something, e.g. start a fine-tuning job.
  # 4. Train from scratch.
  resume_ckpt_path = None
  if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"):
    resume_ckpt_path = save_ckpt_path
  elif config.get("resume"):
    resume_ckpt_path = fillin(config.resume)

  ckpt_mngr = None
  if save_ckpt_path or resume_ckpt_path:
    ckpt_mngr = array_serial.GlobalAsyncCheckpointManager()

  if resume_ckpt_path:
    write_note(f"Resuming training from checkpoint {resume_ckpt_path}...")
    # Free the freshly initialized state before loading the checkpointed one.
    jax.tree_map(lambda x: x.delete(), train_state)
    del train_state
    shardings = {
        **train_state_sharding,
        "chrono": jax.tree_map(lambda _: repl_sharding,
                               u.chrono.save()),
    }
    loaded = u.load_checkpoint_ts(
        resume_ckpt_path, tree=shardings, shardings=shardings)
    train_state = {key: loaded[key] for key in train_state_sharding.keys()}

    u.chrono.load(jax.device_get(loaded["chrono"]))
    del loaded
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    train_state["params"] = model_mod.load(
        train_state["params"], config.model_init, config.get("model"),
        **config.get("model_load", {}))

    # load has the freedom to return params not correctly sharded. Think of for
    # example ViT resampling position embeddings on CPU as numpy arrays.
    train_state["params"] = u.reshard(
        train_state["params"], train_state_sharding["params"])

    parameter_overview.log_parameter_overview(
        train_state["params"], msg="restored params",
        include_stats="global", jax_logging_process=0)


################################################################################
#                                                                              #
#                                 Setup Evals                                  #
#                                                                              #
################################################################################

  # We do not jit/pmap this function, because it is passed to evaluator that
  # does it later. We output as many intermediate tensors as possible for
  # maximal flexibility. Later `jit` will prune out things that are not needed.
  def validation_fn(train_state, batch, seed=0):
    """Compute per-example metrics."""
    local_rng = trainer_utils.get_local_rng(seed, batch)
    # Provide `dropout` rng for reparametrization and set `train=True` to have
    # the model actually do reparametrization.
    logits, out = model.apply({"params": train_state["params"]}, batch["image"],
                              rngs={"dropout": local_rng}, train=True)
    _, aux_loss = vae_loss_fn(
        logits, batch["image"], out["mu"], out["logvar"],
        config.get("beta", 1.0),
        keep_batch_dim=True)

    # Average every metric over all axes except the leading one.
    return jax.tree_map(
        lambda x: jnp.mean(x, axis=tuple(range(1, x.ndim))),
        aux_loss)

  def predict_fn(train_state, batch, seed=0):
    """Return the model's output logits; `batch` is a dict or the image array."""
    if isinstance(batch, dict):
      batch = batch["image"]
    local_rng = trainer_utils.get_local_rng(seed, {"image": batch})
    # Provide `dropout` rng and set `train=True` to perform reparametrization
    logits, _ = model.apply({"params": train_state["params"]}, batch,
                            rngs={"dropout": local_rng}, train=True)
    return {"logits": logits}

  def predict_fn_panoptic(train_state, batch):
    """Turn predicted "semantics"/"instances" logits into panoptic predictions."""
    # `predict_fn` is called without a seed, so its default (0) is used.
    logits = predict_fn({"params": train_state["params"]}, batch)["logits"]
    return panoptic_task.panoptic_predictions_from_logits(
        logits["semantics"], logits["instances"])

  def predict_fn_depth(train_state, batch):
    """Turn predicted "depth" logits into a float depth map via `unbin_depth`."""
    # `predict_fn` is called without a seed, so its default (0) is used.
    depth = predict_fn(
        {"params": train_state["params"]}, batch)["logits"]["depth"]
    depth = trainer_utils.unbin_depth(
        depth, min_depth=config.min_depth, max_depth=config.max_depth,
        num_bins=config.model.inout_specs["depth"][1])
    return {"depth": depth}

  # Only initialize evaluators when they are first needed.
  @functools.lru_cache(maxsize=None)
  def evaluators():
    """Build the evaluators from `config`; cached after the first call."""
    return eval_common.from_config(
        config, {
            "predict": predict_fn,
            "predict_panoptic": predict_fn_panoptic,
            "predict_depth": predict_fn_depth,
            "validation": validation_fn},
        lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"),
        lambda key, cfg: get_steps(key, default=None, cfg=cfg),
        devices_flat,
    )

  # At this point we need to know the current step to see whether to run evals.
  write_note("Inferring the first step number...")
  first_step_device = bv_optax.get_count(train_state["opt"], jittable=True)
  first_step = int(jax.device_get(first_step_device))
  u.chrono.inform(first_step=first_step)

  # Note that training can be pre-empted during the final evaluation (i.e.
  # just after the final checkpoint has been written to disc), in which case we
  # want to run the evals.
  if first_step in (total_steps, 0):
    write_note("Running initial or final evals...")
    mw.step_start(first_step)
    for (name, evaluator, _, prefix) in evaluators():
      # `skip_first` only suppresses the initial eval, never the final one.
      if config.evals[name].get("skip_first") and first_step != total_steps:
        continue
      write_note(f"{name} evaluation...\n{u.chrono.note}")
      with u.chrono.log_timing(f"z/secs/eval/{name}"):
        with mesh, nn.logical_axis_rules(sharding_rules):
          for key, value in evaluator.run(train_state):
            mw.measure(f"{prefix}{key}", value)

################################################################################
#                                                                              #
#                                  Train Loop                                  #
#                                                                              #
################################################################################

  prof = None  # Keeps track of start/stop of profiler state.

  write_note("Starting training loop, compiling the first step...")
  for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter):
    mw.step_start(step)

    with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
      # Timing is only recorded for the first executed step (noop afterwards).
      with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1):
        with mesh, nn.logical_axis_rules(sharding_rules):
          train_state, measurements = update_fn(train_state, rng_loop, batch)

    # On the first host, let's always profile a handful of early steps.
    if jax.process_index() == 0:
      prof = u.startstop_prof(prof, step, first_step, get_steps("log_training"))

    # Report training progress
    if (u.itstime(step, get_steps("log_training"), total_steps, host=0)
        or u.chrono.warmup and jax.process_index() == 0):
      for i, sched_fn_cpu in enumerate(sched_fns_cpu):
        mw.measure(f"global_schedule{i if i else ''}",
                   sched_fn_cpu(u.put_cpu(step - 1)))
      measurements = jax.device_get(measurements)
      for name, value in measurements.items():
        mw.measure(name, value)
      u.chrono.tick(step)
      if not np.isfinite(measurements["training_loss"]):
        raise RuntimeError(f"The loss became nan or inf somewhere within steps "
                           f"[{step - get_steps('log_training')}, {step}]")

    # Checkpoint saving
    keep_ckpt_steps = get_steps("keep_ckpt", None) or total_steps
    if save_ckpt_path and (
        (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False))
        or u.itstime(step, get_steps("ckpt", None), total_steps, first=True)
    ):
      u.chrono.pause(wait_for=train_state)

      # Copy because we add extra stuff to the checkpoint.
      ckpt = {**train_state}

      # To save chrono state correctly and safely in a multihost setup, we
      # broadcast the state to all hosts and convert it to a global array.
      with jax.transfer_guard("allow"):
        chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save())
      chrono_shardings = jax.tree_map(lambda _: repl_sharding, chrono_ckpt)
      ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)}

      u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep)
      u.chrono.resume()

    for (name, evaluator, log_steps, prefix) in evaluators():
      if u.itstime(step, log_steps, total_steps, first=False, last=True):
        u.chrono.pause(wait_for=train_state)
        u.chrono.tick(step)  # Record things like epoch number, core hours etc.
        write_note(f"{name} evaluation...\n{u.chrono.note}")
        with u.chrono.log_timing(f"z/secs/eval/{name}"):
          with mesh, nn.logical_axis_rules(sharding_rules):
            for key, value in evaluator.run(train_state):
              mw.measure(f"{prefix}{key}", jax.device_get(value))
        u.chrono.resume()
    mw.step_end()

  # Always give a chance to stop the profiler, no matter how things ended.
  if jax.process_index() == 0 and prof is not None:
    u.startstop_prof(prof)

  # Last note needs to happen before the pool's closed =)
  write_note(f"Done!\n{u.chrono.note}")

  pool.close()
  pool.join()
  mw.close()

  if ckpt_mngr:
    ckpt_mngr.wait_until_finished()

  # Make sure all hosts stay up until the end of main.
  u.sync()

  u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info)
'''This file provides jax implementation of GSAM.'''

import jax
import jax.numpy as jnp


def dual_vector(y):
  """Returns the solution of max_x y^T x s.t. ||x||_2 <= 1.

  Args:
    y: A pytree of arrays, the vector y in the equation above.

  Returns:
    A tuple `(normalized, norm)`: `y` with every leaf divided by the global
    L2 norm computed over all leaves, and that norm itself. No epsilon is
    added here, so an all-zero `y` produces a zero norm and non-finite leaves.
  """
  gradient_norm = jnp.sqrt(sum(
      jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)))
  normalized_gradient = jax.tree_util.tree_map(
      lambda x: x / gradient_norm, y)
  return normalized_gradient, gradient_norm


def gsam_gradient(loss_fn, params, inputs, targets,
                  rho_max, rho_min, alpha, lr, lr_max, lr_min, eps=1e-12,
                  adaptive_perturbation=False, minimize_fp=True):
  """Get the GSAM gradient (https://openreview.net/pdf?id=edONMAnhLu-).

  Args:
    loss_fn: the loss function, called as `loss_fn(params, inputs, targets)`.
    params: the model weights.
    inputs: the inputs to the loss function.
    targets: the targets to the loss function.
    rho_max: the maximum rho value for perturbation of weights.
    rho_min: the minimum rho value for perturbation of weights.
    alpha: the alpha value for the rho schedule, see Algorithm 1 in the paper.
    lr: current learning rate.
    lr_max: the maximum learning rate.
    lr_min: the minimum learning rate.
    eps: the epsilon value for numerical stability.
    adaptive_perturbation: if False, same perturbation as SAM,
      treat all parameters as a single vector,
      perturbation norm is calculated as the norm of the whole vector;
      If True, perturbation norm is proportional to parameter norm,
      this stabilizes training when different layers have weights
      of different scales.
      Empirically, setting it to True can handle 10x larger rho than
      setting it to False.
    minimize_fp: if True, min(f_p, h), original GSAM;
      if False, min(f, h), where f is the clean loss.
      f_p is the perturbed loss, h is the surrogate gap.
      If True, training dynamics is closer to SAM than conventional training,
      you might observe several loss spikes during training.
      If False, the training dynamics is closer to conventional training,
      and is often more stable (fewer loss spikes) during training.

  Returns:
    l_clean: the loss function value.
    g_gsam: the GSAM gradient. g_gsam is not averaged across workers,
      need to call "jax.lax.pmean" to average.

  Note:
    Setting `rho_max=rho_min` and `alpha=0` reduces GSAM to SAM.
  """
  l_clean, g_clean = jax.value_and_grad(loss_fn)(params, inputs, targets)
  g_clean_normalized, g_clean_length = dual_vector(g_clean)

  # Rho follows the learning-rate schedule linearly. With a constant schedule
  # (lr_max == lr_min) the interpolation would divide by zero, so use rho_max.
  if lr_max == lr_min:
    sam_rho = rho_max
  else:
    sam_rho = rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)

  # Per-worker perturbation.
  if adaptive_perturbation:
    def perturb(a, b):
      return a + jnp.abs(a) * sam_rho * b / (g_clean_length + eps)
  else:
    def perturb(a, b):
      return a + sam_rho * b / (g_clean_length + eps)
  param_sam = jax.tree_util.tree_map(perturb, params, g_clean)

  # Get gradients at perturbed weights; the perturbed loss value is not used.
  g_robust = jax.grad(loss_fn)(param_sam, inputs, targets)

  if minimize_fp:
    # Decompose g_clean onto parallel and vertical to g_robust.
    g_robust_normalized, _ = dual_vector(g_robust)
    g_clean_projection_norm = sum(
        jnp.vdot(p, q) for p, q in zip(
            jax.tree_util.tree_leaves(g_robust_normalized),
            jax.tree_util.tree_leaves(g_clean)))
    g_clean_residual = jax.tree_util.tree_map(
        lambda a, b: a - g_clean_projection_norm * b,
        g_clean, g_robust_normalized)

    # Get GSAM gradient.
    g_gsam = jax.tree_util.tree_map(
        lambda a, b: a - b * alpha, g_robust, g_clean_residual)
  else:
    # Decompose g_robust onto parallel and vertical to g_clean. The normalized
    # clean gradient computed above is reused here.
    g_robust_projection_norm = sum(
        jnp.vdot(p, q) for p, q in zip(
            jax.tree_util.tree_leaves(g_clean_normalized),
            jax.tree_util.tree_leaves(g_robust)))
    g_robust_residual = jax.tree_util.tree_map(
        lambda a, b: a - g_robust_projection_norm * b,
        g_robust, g_clean_normalized)

    # Get GSAM gradient.
    g_gsam = jax.tree_util.tree_map(
        lambda a, b: a + b * alpha, g_clean, g_robust_residual)

  # Always return the clean loss (rather than the perturbed loss).
  return l_clean, g_gsam
"""Training loop example.
Trainer that implements SAM/GSAM optimizers.
"""
# pylint: disable=consider-using-from-import
from functools import partial
import importlib
import multiprocessing.pool
import os

from absl import app
from absl import flags
from absl import logging
import big_vision.evaluators.common as eval_common
import big_vision.input_pipeline as input_pipeline
import big_vision.optax as bv_optax
import big_vision.pp.builder as pp_builder
from big_vision.trainers.proj.gsam.gsam import gsam_gradient
import big_vision.utils as u
from clu import parameter_overview
import flax
import jax
import jax.numpy as jnp
from ml_collections import config_flags
import numpy as np
import optax
import tensorflow as tf
import tensorflow.io.gfile as gfile

# pylint: disable=logging-fstring-interpolation


config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)

flags.DEFINE_string("workdir", default=None, help="Work unit directory.")
flags.DEFINE_boolean("cleanup", default=False,
                     help="Delete workdir (only) after successful completion.")

# Adds jax flags to the program.
jax.config.parse_flags_with_absl()


def main(argv):
  """Runs SAM/GSAM training as described by the `--config` flag.

  Builds the input pipeline, model and optimizer, optionally resumes from a
  checkpoint or initializes from `config.model_init`, then runs the pmapped
  update step until `total_steps`, logging metrics, saving checkpoints and
  running evaluators along the way.

  Args:
    argv: Unused positional command-line arguments passed by `app.run`.

  Raises:
    ValueError: If the batch size is not divisible by the device count.
    RuntimeError: After cleanup, if the training loss became nan or inf.
  """
  del argv
  tf.config.experimental.set_visible_devices([], "GPU")

  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  logging.info(
      f"\u001b[33mHello from process {jax.process_index()} holding "
      f"{jax.local_device_count()}/{jax.device_count()} devices and "
      f"writing to workdir {workdir}.\u001b[0m")

  assert not config.get("grad_accum_steps"), "Grad-acc not supported anymore."

  save_checkpoint_path = None
  if workdir and config.get("checkpoint_steps"):
    gfile.makedirs(workdir)
    save_checkpoint_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # Here we register preprocessing ops from modules listed on `pp_modules`.
  for m in config.get("pp_modules", ["ops_general", "ops_image"]):
    importlib.import_module(f"big_vision.pp.{m}")

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we would set TF seed.
  # See (internal link) for a fun read on what it takes.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  # These functions do more stuff internally, for OSS release we mock them by
  # trivial alternatives in order to minimize disruptions in the code.
  xid, wid = -1, -1
  fillin = lambda s: s
  def info(s, *a):
    logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)
  def write_note(note):
    if jax.process_index() == 0:
      info("%s", note)

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get("keep_checkpoint_steps"):
    assert config.get("checkpoint_steps"), "Specify `checkpoint_steps`."
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f"`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be "
        f"divisible by `checkpoint_steps` ({config.checkpoint_steps}).")

  batch_size = config.batch_size
  if batch_size % jax.device_count() != 0:
    raise ValueError(f"Batch size ({batch_size}) must "
                     f"be divisible by device number ({jax.device_count()})")
  info("Global batch size %d on %d hosts results in %d local batch size. With "
       "%d dev per host (%d dev total), that's a %d per-device batch size.",
       batch_size, jax.process_count(), batch_size // jax.process_count(),
       jax.local_device_count(), jax.device_count(),
       batch_size // jax.device_count())

  # First thing after above sanity checks, so we can log "start" ticks.
  mw = u.BigVisionMetricWriter(xid, wid, workdir)
  chrono = u.Chrono()

  write_note("Initializing train dataset...")
  train_ds = input_pipeline.make_for_train(
      dataset=config.dataset,
      split=config.train_split,
      batch_size=config.batch_size,
      preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train),
      shuffle_buffer_size=config.get("shuffle_buffer_size"),
      cache_raw=config.get("cache_raw", False),
      data_dir=fillin(config.get("dataset_dir")))

  n_prefetch = config.get("prefetch_to_device", 1)
  train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch)

  ntrain_img = input_pipeline.get_num_examples(
      config.dataset, config.train_split,
      data_dir=fillin(config.get("dataset_dir")))
  steps_per_epoch = ntrain_img / batch_size

  if config.get("num_epochs"):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get("total_steps"), "Set either num_epochs or total_steps"
  else:
    total_steps = config.total_steps

  info("Running for %d steps, that means %f epochs and %f steps per epoch",
       total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
  model = model_mod.Model(
      num_classes=config.num_classes, **config.get("model", {}))

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @partial(jax.jit, backend="cpu")
  def init(rng):
    shape = tuple(train_ds.element_spec["image"].shape[1:])
    bs = config.batch_size // jax.device_count()
    dummy_input = jnp.zeros((bs,) + shape, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input))["params"]

    # Set bias in the head to a low value, such that loss is small initially.
    if "init_head_bias" in config:
      params["head"]["bias"] = jnp.full_like(params["head"]["bias"],
                                             config["init_head_bias"])

    return params

  rng, rng_init = jax.random.split(rng)
  params_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_util.tree_leaves(params_cpu))
    parameter_overview.log_parameter_overview(params_cpu, msg="init params")
    mw.measure("num_params", num_params)

  write_note(f"Initializing {config.optax_name} optimizer...")
  tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict(
      global_batch_size=batch_size,
      total_steps=total_steps,
      steps_per_epoch=steps_per_epoch))

  assert len(sched_fns) == 1, "Current GSAM supports one global learning-rate."

  # We jit this, such that the arrays are created on the CPU, not device[0].
  opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu)
  sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]

  @partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1))
  def update_fn(params, opt, rng, images, labels, step):
    """Update step."""

    measurements = {}

    if config.get("mixup") and config.mixup.p:
      rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup)

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      logits, _ = model.apply(
          {"params": flax.core.freeze(params)}, images,
          train=True, rngs={"dropout": rng_model_local})
      return getattr(u, config.get("loss", "sigmoid_xent"))(
          logits=logits, labels=labels)

    learning_rate = sched_fns[0](step) * config.lr
    l, grads = gsam_gradient(loss_fn=loss_fn, params=params, inputs=images,
                             targets=labels, lr=learning_rate, **config.gsam)
    l, grads = jax.lax.pmean((l, grads), axis_name="batch")
    updates, opt = tx.update(grads, opt, params)
    params = optax.apply_updates(params, updates)

    gs = jax.tree_util.tree_leaves(
        bv_optax.replace_frozen(config.schedule, grads, 0.))
    measurements["l2_grads"] = jnp.sqrt(sum(jnp.vdot(g, g) for g in gs))
    ps = jax.tree_util.tree_leaves(params)
    measurements["l2_params"] = jnp.sqrt(sum(jnp.vdot(p, p) for p in ps))
    # The loop variable is deliberately not called `u`, which is the
    # `big_vision.utils` alias used elsewhere in this function.
    upds = jax.tree_util.tree_leaves(updates)
    measurements["l2_updates"] = jnp.sqrt(
        sum(jnp.vdot(upd, upd) for upd in upds))

    return params, opt, rng, l, measurements

  # We do not jit/pmap this function, because it is passed to evaluator that
  # does it later. We output as many intermediate tensors as possible for
  # maximal flexibility. Later `jit` will prune out things that are not needed.
  def predict_fn(params, image):
    logits, out = model.apply({"params": params}, image)
    return logits, out

  # Decide how to initialize training. The order is important.
  # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize model from something, e.g. start a fine-tuning job.
  # 4. Train from scratch.
  resume_checkpoint_path = None
  if save_checkpoint_path and gfile.exists(save_checkpoint_path):
    resume_checkpoint_path = save_checkpoint_path
  elif config.get("resume"):
    resume_checkpoint_path = fillin(config.resume)
  if resume_checkpoint_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {
        "params": params_cpu,
        "opt": opt_cpu,
        "chrono": chrono.save(),
    }
    checkpoint_tree = jax.tree_util.tree_structure(checkpoint)
    loaded = u.load_checkpoint(checkpoint_tree, resume_checkpoint_path)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree_util.tree_map(u.recover_dtype, loaded)
    params_cpu, opt_cpu = checkpoint["params"], checkpoint["opt"]
    chrono.load(checkpoint["chrono"])
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    params_cpu = model_mod.load(
        params_cpu, config.model_init, config.get("model"),
        **config.get("model_load", {}))
    if jax.process_index() == 0:
      parameter_overview.log_parameter_overview(
          params_cpu, msg="restored params")

  write_note("Kicking off misc stuff...")
  first_step = bv_optax.get_count(opt_cpu)
  chrono.inform(first_step, total_steps, batch_size, steps_per_epoch)
  prof = None  # Keeps track of start/stop of profiler state.

  write_note(f"Replicating...\n{chrono.note}")
  params_repl = flax.jax_utils.replicate(params_cpu)
  opt_repl = flax.jax_utils.replicate(opt_cpu)

  evaluators = eval_common.from_config(
      config, {"predict": predict_fn},
      lambda s: write_note(f"Initializing evaluator: {s}...\n{chrono.note}"))

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax.jax_utils.replicate(rng_loop)
  checkpoint_writer = None

  write_note(f"First step compilations...\n{chrono.note}")
  error = None  # For exiting with an error after cleanup. Avoids indentation.
  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch in zip(
      range(first_step + 1, total_steps + 1), train_iter):
    mw.step_start(step)

    with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
      params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn(
          params_repl, opt_repl, rngs_loop,
          train_batch["image"],
          train_batch["labels"],
          flax.jax_utils.replicate(step))

    # On the first host, let's always profile a handful of early steps.
    if jax.process_index() == 0:
      prof = u.startstop_prof(prof, step, first_step, config.log_training_steps)

    # Report training progress
    if (u.itstime(step, config.log_training_steps, total_steps, host=0)
        or (chrono.warmup and jax.process_index() == 0)):
      for i, sched_fn_cpu in enumerate(sched_fns_cpu):
        mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1))
      l = mw.measure("training_loss", loss_value[0])
      for name, value in measurements.items():
        mw.measure(name, value[0])
      chrono.tick(step, mw.measure, write_note)
      if not np.isfinite(l):
        error = (f"The loss became nan or inf somewhere within steps "
                 f"[{step - config.log_training_steps}, {step}]")
        break

    # Checkpoint saving
    if (save_checkpoint_path and
        u.itstime(step, config.get("checkpoint_steps"), total_steps, host=0)):
      chrono.pause(wait_for=(params_repl, opt_repl))
      u.checkpointing_timeout(checkpoint_writer,
                              config.get("checkpoint_timeout", 1))
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see (internal link)). Also, takes device 0's params only.
      params_cpu = jax.tree_util.tree_map(
          lambda x: np.array(x[0]), params_repl)
      opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, config.get("keep_checkpoint_steps"), total_steps):
        copy_step = step

      ckpt = {"params": params_cpu, "opt": opt_cpu, "chrono": chrono.save()}
      checkpoint_writer = pool.apply_async(
          u.save_checkpoint, (ckpt, save_checkpoint_path, copy_step))
      chrono.resume()

    for (name, evaluator, log_steps, prefix) in evaluators:
      if u.itstime(step, log_steps, total_steps):
        chrono.pause(wait_for=params_repl)
        write_note(f"{name} evaluation...\n{chrono.note}")
        for key, value in evaluator.run(params_repl):
          mw.measure(f"{prefix}{key}", value)
        chrono.resume()
    mw.step_end()

  # Always give a chance to stop the profiler, no matter how things ended.
  # TODO: can we also do this when dying of an exception like OOM?
  if jax.process_index() == 0 and prof is not None:
    u.startstop_prof(prof)

  # Last note needs to happen before the pool's closed =)
  if not error:
    write_note(f"Done!\n{chrono.note}")
  else:
    write_note(f"Failed!\n{error}\n{chrono.note}")

  pool.close()
  pool.join()
  mw.close()

  # Make sure all hosts stay up until the end of main.
  u.sync_all_hosts()

  # Before cleanup, as cleanup should only run for successful jobs.
  if error is not None:
    raise RuntimeError(error)

  u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info)


if __name__ == "__main__":
  app.run(main)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contrastive training loop. + +For models Like +- LiT (https://arxiv.org/abs/2111.07991) +- CLIP (https://arxiv.org/abs/2103.00020) +- SigLIP (https://arxiv.org/abs/2303.15343) +""" +# pylint: disable=consider-using-from-import +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.evaluators.common as eval_common +import big_vision.input_pipeline as input_pipeline +import big_vision.optax as bv_optax +import big_vision.utils as u +from clu import parameter_overview +import flax +import jax +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + +# pylint: disable=logging-fstring-interpolation + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. 
def clip(x, *, a_max=None, a_min=None):
  """Like jnp.clip, but allows all-None to mean don't clip.

  Args:
    x: The value to clip.
    a_max: Optional upper bound.
    a_min: Optional lower bound.

  Returns:
    `x` itself when both bounds are None, else the clipped value.
  """
  if a_max is None and a_min is None:
    return x
  # Bounds go in positionally (lower, then upper): the `a_min`/`a_max` keyword
  # names of `jnp.clip` are deprecated in newer JAX releases, while the
  # positional order is the same in old and new versions.
  return jnp.clip(x, a_min, a_max)


def all_gather(z, roll=False, only_others=False):
  """All gather and flatten first two dims.

  Must be called where the "batch" axis name is bound (e.g. inside a pmap
  over "batch").

  Args:
    z: A pytree of per-device arrays.
    roll: If True, rotate the gathered chunks so this device's chunk is first.
    only_others: If True, rotate as for `roll` and then drop this device's
      own chunk.

  Returns:
    A pytree like `z` whose leaves have the device and batch dims merged.
  """
  def gather_flat(x):
    x = jax.lax.all_gather(x, "batch")
    if roll or only_others:
      # Each device moves "its" chunk to the beginning. Simplifies loss/acc
      # calcs.
      x = jnp.roll(x, -jax.lax.axis_index("batch"), axis=0)
    if only_others:
      x = x[1:]
    return jnp.concatenate(x, 0)  # Fold in "device" and "batch" dims.
  return jax.tree_util.tree_map(gather_flat, z)


def softmax_loss(zimg, ztxt, temperature):
  """Softmax loss following the CLIP paper. Factorized to reduce memory cost.

  Args:
    zimg: Image embeddings computed on this device.
    ztxt: Text embeddings computed on this device.
    temperature: Multiplier applied to the similarity logits.

  Returns:
    A tuple `(loss, extras)`: the loss averaged over both directions and,
    via pmean, over the "batch" axis; and a dict with the per-direction
    `i2t`/`t2i` loss and accuracy of this device.
  """

  def unidirectional_loss(z1, z2, t):
    z2 = all_gather(z2, roll=True)
    logits = jnp.dot(z1, z2.T) * t
    # This is a softmax across the larger gathered axis, taking advantage of
    # the fact that positives are known to be on the diagonal.
    loss = -(jnp.diag(logits) - jax.scipy.special.logsumexp(logits, axis=-1))
    acc = jnp.argmax(logits, axis=1) == jnp.arange(z1.shape[0])
    return loss.mean(), acc.mean()

  extras = {}
  loss = 0
  for name, row, col in [("i2t", zimg, ztxt), ("t2i", ztxt, zimg)]:
    loss_dir, acc_dir = unidirectional_loss(row, col, temperature)
    loss += 0.5 * loss_dir
    extras[f"{name}_acc"] = acc_dir
    extras[f"{name}_loss"] = loss_dir

  loss = jax.lax.pmean(loss, "batch")
  return loss, extras


def _avg_pos_logit(x_me):
  """Returns the mean of the diagonal (positive-pair) logits of `x_me`."""
  return jnp.mean(jnp.diag(x_me))


def _avg_neg_logit(x_me, x_ot=None):
  """Returns the mean of all negative-pair logits.

  Args:
    x_me: Local logits matrix; its diagonal holds positives and is excluded.
    x_ot: Optional logits against other devices, all counted as negatives.
      Ignored when None or empty.
  """
  nom = jnp.sum(x_me) - jnp.sum(jnp.diag(x_me))
  den = x_me.size - len(x_me)
  if x_ot is not None and x_ot.size:
    nom += jnp.sum(x_ot)
    den += x_ot.size
  return nom / den


def sigmoid_loss(zimg, ztxt, temperature, bias=0.0):
  """Sigmoid loss from SigLIP: https://arxiv.org/abs/2303.15343.

  Args:
    zimg: Image embeddings computed on this device.
    ztxt: Text embeddings computed on this device.
    temperature: Multiplier applied to the similarity logits.
    bias: Offset added to the scaled logits.

  Returns:
    A tuple `(loss, metrics)` where `metrics` holds per-device min/max/avg
    statistics of the positive and negative logits.
  """
  # Sigmoid loss. Since it's unidirectional, image embeddings stick to
  # "me", i.e. the device they are computed on, and text embeddings travel.
  ztxt_me = ztxt  # Text embeddings on my devices: (n, D)
  ztxt_ot = all_gather(ztxt, only_others=True)  # Text emb from others: (N, D)

  logits_me = jnp.dot(zimg, ztxt_me.T)  # (n, D) . (D, n) -> (n, n)
  logits_ot = jnp.dot(zimg, ztxt_ot.T)  # (n, D) . (D, N) -> (n, N)
  logits_me = logits_me * temperature + bias
  logits_ot = logits_ot * temperature + bias

  eye = jnp.eye(zimg.shape[0])
  # Standard sigmoid computes everything twice, once assuming positive
  # labels and once assuming negative ones. But here we know exactly where
  # to find positives (on "me" diagonal) and negatives (everywhere else),
  # so compute each one's loss only once:
  m1_diag1 = -jnp.ones_like(logits_me) + 2 * eye
  loglik_me = jax.nn.log_sigmoid(m1_diag1 * logits_me)
  loglik_ot = jax.nn.log_sigmoid(-logits_ot)

  # Normalize by npos per column, but that's one, so just sum.
  nll_me = -loglik_me.sum(axis=-1)
  nll_ot = -loglik_ot.sum(axis=-1)
  l = nll_me.mean() + nll_ot.mean()  # == concat'ing me/ot along axis -1 above.

  return l, {
      # Only local device metrics for now, as last time I tried, there was
      # some funny unimplemented business with jax.lax.pmin/pmax!
      # So what's reported here is average of per-device min/max/avg.
      "pos_min_logit": jnp.min(jnp.diag(logits_me)),
      "pos_max_logit": jnp.max(jnp.diag(logits_me)),
      "pos_avg_logit": _avg_pos_logit(logits_me),
      "local_neg_min_logit": jnp.min(logits_me + 1e9 * eye),
      "local_neg_max_logit": jnp.max(logits_me - 1e9 * eye),
      "local_neg_avg_logit": _avg_neg_logit(logits_me),
      "neg_min_logit": jnp.minimum(
          jnp.min(logits_me + 1e9 * eye),
          jnp.min(logits_ot) if logits_ot.size else jnp.inf),
      "neg_max_logit": jnp.maximum(
          jnp.max(logits_me - 1e9 * eye),
          jnp.max(logits_ot) if logits_ot.size else -jnp.inf),
      "neg_avg_logit": _avg_neg_logit(logits_me, logits_ot),
  }


def _gather_from_device(x, device_id, axis_name="batch"):
  """Returns, on every device, the value of `x` held by device `device_id`.

  Implemented as a psum over `axis_name` in which every device except
  `device_id` contributes zeros.
  """
  return jax.lax.psum((jax.lax.axis_index(axis_name) == device_id) * x,
                      axis_name)


def chunked_sigmoid_loss(zimg, ztxt, temperature, bias=0.0):
  """Loss computation from section 3.1 of arxiv.org/abs/2303.15343.

  Args:
    zimg: Image embeddings computed on this device.
    ztxt: Text embeddings computed on this device.
    temperature: Multiplier applied to the similarity logits.
    bias: Offset added to the scaled logits.

  Returns:
    A tuple `(loss, metrics)` where `metrics` holds min/max/avg statistics of
    this device's positive and local-negative logits.
  """
  eye = jnp.eye(zimg.shape[0])

  # Calculate loss for representations on this device, which includes positives.
  logits_me = jnp.dot(zimg, ztxt.T)  # (n, D) . (D, n) -> (n, n)
  logits_me = logits_me * temperature + bias
  m1_diag1 = -jnp.ones_like(logits_me) + 2 * eye
  loglik_me = jax.nn.log_sigmoid(m1_diag1 * logits_me)
  nll_me = -loglik_me.sum(axis=-1).mean()

  def negative_loss(ztxt_other_device):
    logits_ot = jnp.dot(zimg, ztxt_other_device.T)  # (n, D) . (D, n) -> (n, n)
    logits_ot = logits_ot * temperature + bias
    loglik_ot = jax.nn.log_sigmoid(-logits_ot)
    return -jnp.sum(loglik_ot, axis=-1).mean()

  me = jax.lax.axis_index("batch")
  # All other devices are negatives. Hot-potato swap ztxt across devices.
  # Interestingly, ppermute based implementation was memory intensive, so using
  # all-reduce to gather representations.
  nll_others = 0
  for device_id in range(jax.device_count()):
    # 1 for every device but this one, whose chunk was handled in `nll_me`.
    is_other = jnp.not_equal(device_id, me)
    nll_others += is_other * negative_loss(
        _gather_from_device(ztxt, device_id))

  return nll_me + nll_others, {
      "pos_min_logit": jnp.min(jnp.diag(logits_me)),
      "pos_max_logit": jnp.max(jnp.diag(logits_me)),
      "pos_avg_logit": _avg_pos_logit(logits_me),
      "local_neg_min_logit": jnp.min(logits_me + 1e9 * eye),
      "local_neg_max_logit": jnp.max(logits_me - 1e9 * eye),
      "local_neg_avg_logit": _avg_neg_logit(logits_me),
  }
+ xid, wid = -1, -1 + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + write_note("Initializing...") + + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + # First thing after above sanity checks, so we can log "start" ticks. + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + + write_note("Initializing train dataset...") + train_ds, ntrain_img = input_pipeline.training(config.input) + + # Start prefetching already. + n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size, + measure=mw.measure, write_note=write_note) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + write_note(f"Initializing {config.model_name} model...") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model = model_mod.Model(**config.get("model", {})) + + # We want all parameters to be created in host RAM, not on any device, they'll + # be sent there later as needed, otherwise we already encountered two + # situations where we allocate them twice. 
+ @functools.partial(jax.jit, backend="cpu") + def init(rng): + bs = batch_size // jax.device_count() + image_size = tuple(train_ds.element_spec["image"].shape[1:]) + no_image = jnp.zeros((bs,) + image_size, jnp.float32) + text_size = tuple(train_ds.element_spec["labels"].shape[1:]) + no_text = jnp.zeros((bs,) + text_size, jnp.int32) + params = flax.core.unfreeze(model.init(rng, no_image, no_text))["params"] + return params + + rng, rng_init = jax.random.split(rng) + with u.chrono.log_timing("z/secs/init"): + params_cpu = init(rng_init) + + if jax.process_index() == 0: + num_params = sum(p.size for p in jax.tree_leaves(params_cpu)) + parameter_overview.log_parameter_overview(params_cpu, msg="init params") + mw.measure("num_params", num_params) + + write_note(f"Initializing {config.optax_name} optimizer...") + tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + + # We jit this, such that the arrays are created on the CPU, not device[0]. + opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu) + sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns] + + @functools.partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1)) + def update_fn(params, opt, rng, batch): + """Update step.""" + assert "mixup" not in config, "We still have to figure out mixup." + + # Get device-specific loss rng. 
+ rng, rng_model = jax.random.split(rng, 2) + rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch")) + + def loss_fn(params, images, labels): + zimg, ztxt, extras = model.apply( + {"params": params}, images, labels, + train=True, rngs={"dropout": rng_model_local}) + + match config.get("loss_fn", "softmax"): + case "softmax": + l, l_extras = softmax_loss(zimg, ztxt, extras["t"]) + case "sigmoid": + l, l_extras = sigmoid_loss(zimg, ztxt, extras["t"], bias=extras["b"]) + case "chunked_sigmoid": + l, l_extras = chunked_sigmoid_loss(zimg, ztxt, extras["t"], + bias=extras["b"]) + case _: + raise NotImplementedError(f"Unrecognized loss {config.loss_fn=}") + + return l, { + "t": extras["t"], + "t/parameter": extras["t/parameter"], + "train/nimg": jnp.mean(extras["img/norm"]), + "train/ntxt": jnp.mean(extras["txt/norm"]), + **{f"train/{k}": v for k, v in l_extras.items()}, + } + + (l, measurements), grads = jax.value_and_grad( + loss_fn, has_aux=True)(params, batch["image"], batch["labels"]) + l, measurements, grads = jax.lax.pmean((l, measurements, grads), + axis_name="batch") + updates, opt = tx.update(grads, opt, params) + params = optax.apply_updates(params, updates) + + gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs])) + ps = jax.tree_leaves(params) + measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps])) + us = jax.tree_leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us])) + + return params, opt, rng, l, measurements + + # We require hashable function reference for evaluator. + # We do not jit/pmap this function, because it is passed to evaluator that + # does it later. We output as many intermediate tensors as possible for + # maximal flexibility. Later `jit` will prune out things that are not needed. 
+ def predict_fn(params, image=None, text=None, **unused_kwargs): + del unused_kwargs # `unused_kwargs` is to be compatible with few-shot + zimg, ztxt, out = model.apply({"params": params}, image, text) + return zimg, ztxt, out + + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, {"predict": predict_fn}, + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + ) + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize model from something, e,g, start a fine-tuning job. + # 4. Train from scratch. + resume_ckpt_path = None + if save_ckpt_path and gfile.exists(save_ckpt_path): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = config.resume.format(wid=xm_wu.id) + if resume_ckpt_path: + write_note("Resume training from checkpoint...") + checkpoint = { + "params": params_cpu, + "opt": opt_cpu, + "chrono": u.chrono.save(), + } + checkpoint_tree = jax.tree_structure(checkpoint) + loaded = u.load_checkpoint_np(resume_ckpt_path, checkpoint_tree) + # bfloat16 type gets lost when data is saved to disk, so we recover it. 
+ checkpoint = jax.tree_map(u.recover_dtype, loaded) + params_cpu, opt_cpu = checkpoint["params"], checkpoint["opt"] + u.chrono.load(checkpoint["chrono"]) + elif config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + params_cpu = model_mod.load( + params_cpu, config.model_init, config.get("model"), + **config.get("model_load", {})) + if jax.process_index() == 0: + parameter_overview.log_parameter_overview( + params_cpu, msg="restored params") + + write_note("Kicking off misc stuff...") + first_step = bv_optax.get_count(opt_cpu) + u.chrono.inform(first_step=first_step) + prof = None # Keeps track of start/stop of profiler state. + + write_note(f"Replicating...\n{u.chrono.note}") + params_repl = flax.jax_utils.replicate(params_cpu) + opt_repl = flax.jax_utils.replicate(opt_cpu) + + rng, rng_loop = jax.random.split(rng, 2) + rngs_loop = flax.jax_utils.replicate(rng_loop) + ckpt_writer = None + + write_note(f"First step compilations...\n{u.chrono.note}") + + # Note that training can be pre-empted during the final evaluation (i.e. + # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. + if first_step in (total_steps, 0): + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + for key, value in evaluator.run(params_repl): + mw.measure(f"{prefix}{key}", value) + + # Using a python integer for step here, because opt.state.step is allocated + # on TPU during replication. 
+ for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn( + params_repl, opt_repl, rngs_loop, batch) + + # On the first host, let's always profile a handful of early steps. + if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) + l = mw.measure("training_loss", loss_value[0]) + for name, value in measurements.items(): + mw.measure(name, value[0]) + u.chrono.tick(step) + if not np.isfinite(l): + raise RuntimeError(f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + if (save_ckpt_path and + (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or + u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))): + u.chrono.pause(wait_for=(params_repl, opt_repl)) + u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1)) + # We need to transfer the weights over now or else we risk keeping them + # alive while they'll be updated in a future step, creating hard to debug + # memory errors (see (internal link)). Also, takes device 0's params only. + params_cpu = jax.tree_map(lambda x: np.array(x[0]), params_repl) + opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl) + + # Check whether we want to keep a copy of the current checkpoint. 
+ copy_step = None + if u.itstime(step, get_steps("keep_ckpt", None), total_steps): + copy_step = step + + ckpt = {"params": params_cpu, "opt": opt_cpu, "chrono": u.chrono.save()} + ckpt_writer = pool.apply_async( + u.save_checkpoint, (ckpt, save_ckpt_path, copy_step)) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=params_repl) + u.chrono.tick(step) # Record things like epoch number, core hours etc. + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + for key, value in evaluator.run(params_repl): + mw.measure(f"{prefix}{key}", value) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/trainers/proj/image_text/siglip.py b/big_vision/trainers/proj/image_text/siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..3ea6118be9f5d1f581667737cb774412672fcaa4 --- /dev/null +++ b/big_vision/trainers/proj/image_text/siglip.py @@ -0,0 +1,527 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trainer for "Sigmoid Loss for Language Image Pre-Training". + +SigLIP (https://arxiv.org/abs/2303.15343) + +TODO: implement chunked version with shard_map. +""" +# pylint: disable=consider-using-from-import +# pylint: disable=logging-fstring-interpolation + +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.evaluators.common as eval_common +import big_vision.input_pipeline as input_pipeline +import big_vision.optax as bv_optax +import big_vision.sharding as bv_sharding +import big_vision.utils as u +from clu import parameter_overview +import flax.linen as nn +import jax +from jax.experimental import mesh_utils +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import serialization as array_serial +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() +# Transfer guard will fail the program whenever that data between a host and +# a device is transferred implicitly. 
This often catches subtle bugs that +# cause slowdowns and memory fragmentation. Explicit transfers are done +# with jax.device_put and jax.device_get. +jax.config.update("jax_transfer_guard", "disallow") +# Fixes design flaw in jax.random that may cause unnecessary d2d comms. +jax.config.update("jax_threefry_partitionable", True) + + +NamedSharding = jax.sharding.NamedSharding +P = jax.sharding.PartitionSpec + + +def main(argv): + del argv + + jax.distributed.initialize() + + # Make sure TF does not touch GPUs. + tf.config.set_visible_devices([], "GPU") + + config = flags.FLAGS.config + +################################################################################ +# # +# Set up logging # +# # +################################################################################ + + # Set up work directory and print welcome message. + workdir = flags.FLAGS.workdir + logging.info( + f"\u001b[33mHello from process {jax.process_index()} holding " + f"{jax.local_device_count()}/{jax.device_count()} devices and " + f"writing to workdir {workdir}.\u001b[0m") + + save_ckpt_path = None + if workdir: # Always create if requested, even if we may not write into it. + gfile.makedirs(workdir) + save_ckpt_path = os.path.join(workdir, "checkpoint.bv") + + # The pool is used to perform misc operations such as logging in async way. + pool = multiprocessing.pool.ThreadPool() + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]): + importlib.import_module(f"big_vision.pp.{m}") + + # Setup up logging and experiment manager. + xid, wid = -1, -1 + fillin = lambda s: s + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + + # Allow for things like timings as early as possible! 
+ u.chrono.inform(measure=mw.measure, write_note=write_note) + +################################################################################ +# # +# Set up Mesh # +# # +################################################################################ + + # We rely on jax mesh_utils to organize devices, such that communication + # speed is the fastest for the last dimension, second fastest for the + # penultimate dimension, etc. + config_mesh = config.get("mesh", [("data", jax.device_count())]) + + # Sharding rules with default + sharding_rules = config.get("sharding_rules", [("act_batch", "data")]) + + mesh_axes, mesh_size = tuple(zip(*config_mesh)) + + # Because jax.utils do not support `-1` shape size. + mesh_size = np.array(jax.devices()).reshape(mesh_size).shape + + device_mesh = mesh_utils.create_device_mesh(mesh_size) + + # Consistent device order is important to ensure correctness of various train + # loop components, such as input pipeline, update step, evaluators. The + # order presribed by the `devices_flat` variable should be used throughout + # the program. + devices_flat = device_mesh.flatten() + +################################################################################ +# # +# Input Pipeline # +# # +################################################################################ + + write_note("Initializing train dataset...") + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. 
With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + train_ds, ntrain_img = input_pipeline.training(config.input) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + # Start input pipeline as early as possible. + n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_global(train_ds, devices_flat, n_prefetch) + +################################################################################ +# # +# Create Model & Optimizer # +# # +################################################################################ + + write_note("Creating model...") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model = model_mod.Model(**config.get("model", {})) + + def init(rng): + batch = jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype), + train_ds.element_spec) + params = model.init(rng, batch["image"], batch["labels"])["params"] + + # Set bias in the head to a low value, such that loss is small initially. + if "init_head_bias" in config: + params["head"]["bias"] = jnp.full_like(params["head"]["bias"], + config["init_head_bias"]) + + return params + + # This seed makes the Jax part of things (like model init) deterministic. + # However, full training still won't be deterministic, for example due to the + # tf.data pipeline not being deterministic even if we would set TF seed. + # See (internal link) for a fun read on what it takes. 
+ rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0))) + + write_note("Inferring parameter shapes...") + rng, rng_init = jax.random.split(rng) + params_shape = jax.eval_shape(init, rng_init) + + write_note("Inferring optimizer state shapes...") + tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + opt_shape = jax.eval_shape(tx.init, params_shape) + # We jit this, such that the arrays are created on the CPU, not device[0]. + sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns] + + if jax.process_index() == 0: + num_params = sum(np.prod(p.shape) for p in jax.tree_leaves(params_shape)) + mw.measure("num_params", num_params) + +################################################################################ +# # +# Shard & Transfer # +# # +################################################################################ + + write_note("Creating device mesh...") + mesh = jax.sharding.Mesh(device_mesh, mesh_axes) + repl_sharding = jax.sharding.NamedSharding(mesh, P()) + + write_note("Inferring shardings...") + train_state_shape = {"params": params_shape, "opt": opt_shape} + + strategy = config.get("sharding_strategy", [(".*", "replicate")]) + train_state_sharding = bv_sharding.infer_sharding( + train_state_shape, strategy=strategy, mesh=mesh) + + write_note("Transferring train_state to devices...") + # RNG is always replicated + rng_init = u.reshard(rng_init, repl_sharding) + + # Parameters and the optimizer are now global (distributed) jax arrays. + params = jax.jit(init, out_shardings=train_state_sharding["params"])(rng_init) + opt = jax.jit(tx.init, out_shardings=train_state_sharding["opt"])(params) + + rng, rng_loop = jax.random.split(rng, 2) + rng_loop = u.reshard(rng_loop, repl_sharding) + del rng # not used anymore, so delete it. + + # At this point we have everything we need to form a train state. 
It contains + # all the parameters that are passed and updated by the main training step. + train_state = {"params": params, "opt": opt} + del params, opt # Delete to avoid memory leak or accidental reuse. + + write_note("Logging parameter overview...") + parameter_overview.log_parameter_overview( + train_state["params"], msg="Init params", + include_stats="global", jax_logging_process=0) + +################################################################################ +# # +# Update Step # +# # +################################################################################ + + @functools.partial( + jax.jit, + donate_argnums=(0,), + out_shardings=(train_state_sharding, repl_sharding)) + def update_fn(train_state, rng, batch): + """Update step.""" + + images, labels = batch["image"], batch["labels"] + + step_count = bv_optax.get_count(train_state["opt"], jittable=True) + rng = jax.random.fold_in(rng, step_count) + assert "mixup" not in config, "Mixup is not supported for SigLIP." + + # Get device-specific loss rng. + rng, rng_model = jax.random.split(rng, 2) + + def loss_fn(params): + zimg, ztxt, extras = model.apply( + {"params": params}, images, labels, + train=True, rngs={"dropout": rng_model}) + logits = jnp.dot(zimg, ztxt.T) + logits = logits * extras["t"] + extras["b"] + eye = jnp.eye(zimg.shape[0]) + + # Standard sigmoid computes everything twice, once assuming positive + # labels and once assuming negative ones. But here we know exactly where + # to find positives (on "me" diagonal) and negatives (everywhere else), + # so compute each one's loss only once: + m1_diag1 = -jnp.ones_like(logits) + 2 * eye + loglik = jax.nn.log_sigmoid(m1_diag1 * logits) + + # Normalize by npos per column, but that's one, so just sum. + nll = -jnp.sum(loglik, axis=-1) + + # NOTE: same as concat'ing me/ot along axis -1 above. 
+ l = jnp.mean(nll) + + return l + + params, opt = train_state["params"], train_state["opt"] + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, opt = tx.update(grads, opt, params) + params = optax.apply_updates(params, updates) + + measurements = {"training_loss": loss} + gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.sum(g * g) for g in gs])) + ps = jax.tree_leaves(params) + measurements["l2_params"] = jnp.sqrt(sum([jnp.sum(p * p) for p in ps])) + us = jax.tree_leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.sum(u * u) for u in us])) + + return {"params": params, "opt": opt}, measurements + +################################################################################ +# # +# Load Checkpoint # +# # +################################################################################ + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize model from something, e,g, start a fine-tuning job. + # 4. Train from scratch. 
+ resume_ckpt_path = None + if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + + ckpt_mngr = None + if save_ckpt_path or resume_ckpt_path: + ckpt_mngr = array_serial.GlobalAsyncCheckpointManager() + + if resume_ckpt_path: + write_note(f"Resuming training from checkpoint {resume_ckpt_path}...") + jax.tree_map(lambda x: x.delete(), train_state) + del train_state + shardings = { + **train_state_sharding, + "chrono": jax.tree_map(lambda _: repl_sharding, + u.chrono.save()), + } + loaded = u.load_checkpoint_ts( + resume_ckpt_path, tree=shardings, shardings=shardings) + train_state = {key: loaded[key] for key in train_state_sharding.keys()} + + u.chrono.load(jax.device_get(loaded["chrono"])) + del loaded + elif config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + # TODO: when updating the `load` API soon, do pass and request the + # full `train_state` from it. Examples where useful: VQVAE, BN. + train_state["params"] = model_mod.load( + train_state["params"], config.model_init, config.get("model"), + **config.get("model_load", {})) + + # load has the freedom to return params not correctly sharded. Think of for + # example ViT resampling position embedings on CPU as numpy arrays. + train_state["params"] = u.reshard( + train_state["params"], train_state_sharding["params"]) + + parameter_overview.log_parameter_overview( + train_state["params"], msg="restored params", + include_stats="global", jax_logging_process=0) + + +################################################################################ +# # +# Setup Evals # +# # +################################################################################ + + # We do not jit/pmap this function, because it is passed to evaluator that + # does it later. We output as many intermediate tensors as possible for + # maximal flexibility. 
Later `jit` will prune out things that are not needed. + def eval_logits_fn(train_state, batch): + zimg, ztxt, out = model.apply( + {"params": train_state["params"]}, + batch.get("image", None), batch.get("labels", None)) + return zimg, ztxt, out + + def eval_loss_fn(train_state, batch): + logits, _ = model.apply({"params": train_state["params"]}, batch["image"]) + loss_fn = getattr(u, config.get("loss", "sigmoid_xent")) + return { + "loss": loss_fn(logits=logits, labels=batch["labels"], reduction=False) + } + + eval_fns = { + "predict": eval_logits_fn, + "loss": eval_loss_fn, + } + + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, eval_fns, + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + devices_flat, + ) + + # At this point we need to know the current step to see whether to run evals. + write_note("Inferring the first step number...") + first_step_device = bv_optax.get_count(train_state["opt"], jittable=True) + first_step = int(jax.device_get(first_step_device)) + u.chrono.inform(first_step=first_step) + + # Note that training can be pre-empted during the final evaluation (i.e. + # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. 
+ if first_step in (total_steps, 0): + write_note("Running initial or final evals...") + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules(sharding_rules): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", value) + +################################################################################ +# # +# Train Loop # +# # +################################################################################ + + prof = None # Keeps track of start/stop of profiler state. + + write_note("Starting training loop, compiling the first step...") + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + with mesh, nn.logical_axis_rules(sharding_rules): + train_state, measurements = update_fn(train_state, rng_loop, batch) + + # On the first host, let's always profile a handful of early steps. 
+ if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", + sched_fn_cpu(u.put_cpu(step - 1))) + measurements = jax.device_get(measurements) + for name, value in measurements.items(): + mw.measure(name, value) + u.chrono.tick(step) + if not np.isfinite(measurements["training_loss"]): + raise RuntimeError(f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + keep_ckpt_steps = get_steps("keep_ckpt", None) or total_steps + if save_ckpt_path and ( + (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False)) + or u.itstime(step, get_steps("ckpt", None), total_steps, first=True) + ): + u.chrono.pause(wait_for=train_state) + + # Copy because we add extra stuff to the checkpoint. + ckpt = {**train_state} + + # To save chrono state correctly and safely in a multihost setup, we + # broadcast the state to all hosts and convert it to a global array. + with jax.transfer_guard("allow"): + chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save()) + chrono_shardings = jax.tree_map(lambda _: repl_sharding, chrono_ckpt) + ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)} + + u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=train_state) + u.chrono.tick(step) # Record things like epoch number, core hours etc. 
+ write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules(sharding_rules): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", jax.device_get(value)) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + + if ckpt_mngr: + ckpt_mngr.wait_until_finished() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/trainers/proj/paligemma/__pycache__/predict_fns.cpython-310.pyc b/big_vision/trainers/proj/paligemma/__pycache__/predict_fns.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f623ac2503409f0d1caf5527d40be21eaf08d75 Binary files /dev/null and b/big_vision/trainers/proj/paligemma/__pycache__/predict_fns.cpython-310.pyc differ diff --git a/big_vision/trainers/proj/paligemma/predict_fns.py b/big_vision/trainers/proj/paligemma/predict_fns.py new file mode 100644 index 0000000000000000000000000000000000000000..48627d203a3b8ff4f99e8e26ca1eeb3859e0903a --- /dev/null +++ b/big_vision/trainers/proj/paligemma/predict_fns.py @@ -0,0 +1,466 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Prediction functions for PaliGemma.""" + +import collections +import functools + +from big_vision.pp import registry +import big_vision.utils as u +import einops +import jax +import jax.numpy as jnp +import numpy as np + + +P = jax.sharding.PartitionSpec + +# pylint: disable=missing-function-docstring + + +def get_all(model): + """Returns `predict_fns` for evaluators.""" + fns = { + "logits": _logits, + "image_avg_repr": _image_avg_repr, + "decode": _decode, + "decode_with_logp": _decode_with_logp, + "beam_decode": _beam_decode, + } + return {name: functools.partial(fn, model=model) for name, fn in fns.items()} + + +def _logits(train_state, batch, *, model): + images, text, mask = batch["image"], batch["text"], batch["mask_ar"] + text_logits, out = model.apply( + {"params": train_state["params"]}, + images, text[:, :-1], mask[:, :-1], + ) + return text_logits, out + + +def _image_avg_repr(train_state, batch, *, model, key="img/pre_logits"): + zimg, out = model.apply( + {"params": train_state["params"]}, + image=batch["image"], + method=model.embed_image, + ) + if key: + zimg = u.tree_get(out, key) + # At this point, zimg is a (batch of) sequence of image tokens, because we + # assume the model is a vit with "none" head. This predict-fn is for fewshot + # evaluator, so we need to turn it into reasonably-sized vector -> avg. 
+ zimg = jnp.mean(zimg, axis=range(1, zimg.ndim - 1)) + return zimg, out + + +def _decode_with_logp( + train_state, batch, *, model, devices, max_decode_len, eos_token, + best_of_n=1, sampler="greedy", replicate_out=False, eos_look_behind=0): + """Sample token continuations to the input sequences.""" + mesh = jax.sharding.Mesh(devices, ("devices",)) + replicate_sharding = jax.sharding.NamedSharding(mesh, P()) + out_sharding = jax.sharding.NamedSharding( + mesh, P() if replicate_out else P("devices") + ) + + # Prefill the model cache and generate logits for first token. + logits, cache = jax.jit( + _prefill_cache, + out_shardings=out_sharding, + static_argnames=("model", "max_decode_len"), + )( + train_state["params"], + { + "image": batch["image"], + "text": batch["text"], + "mask_input": batch["mask_input"], + "mask_ar": batch["mask_ar"], + }, + model=model, + max_decode_len=max_decode_len, + ) + + # Mask indicating real examples. False if example is used to pad the batch. + mask = batch["_mask"] + + # Repeat example in case we are picking the best of n. + logits, cache, mask = jax.jit( + _bon_repeat, + static_argnames=("n",) + )((logits, cache, mask), n=best_of_n) + + decode_sample_output = jax.jit( + _decode_sample_output, + static_argnames=("max_decode_len", "sampler"), + ) + decode_early_stop = jax.jit( + _decode_early_stop, + out_shardings=replicate_sharding, + static_argnames=("eos_token",), + ) + extend_cache = jax.jit( + _extend_cache, + donate_argnums=1, + static_argnames=("model",), + ) + + # Keep sampling tokens from last logits until EOS or max_decode_len. + state = None + # Setting `eos_look_behind>0` removes blocking transfer with small batches. 
+ stops = collections.deque(maxlen=1 + eos_look_behind) + for idx in range(max_decode_len): + tokens, state = decode_sample_output( + state, logits, max_decode_len=max_decode_len, sampler=sampler + ) + + if idx + 1 >= max_decode_len: + break + + stops.append(decode_early_stop(state, mask, eos_token=eos_token)) + if len(stops) == stops.maxlen and jax.device_get(stops[0]): + break + + # Compute logits for next token + logits, cache = extend_cache( + train_state["params"], cache, tokens, model=model + ) + + # Select the best of n sample for each example. + _, tokens, logp = jax.jit( + _bon_select, + out_shardings=out_sharding, + static_argnames=("n", "eos_token"), + )(state, n=best_of_n, eos_token=eos_token) + + return tokens, logp + + +def _decode(train_state, batch, **kwargs): + tokens, _ = _decode_with_logp(train_state, batch, **kwargs) + return tokens + + +def _bon_repeat(tree, *, n): + return jax.tree.map(lambda x: jnp.repeat(x, n, axis=0), tree) + + +def _compute_score(tokens, logp, eos_token): + """Compute log-probability of each sequence up to first eos (including it).""" + seqlen = jnp.sum(jnp.cumsum(tokens == eos_token, axis=-1) == 0, axis=-1) + 1 + token_mask = jnp.arange(tokens.shape[-1]) < seqlen[..., None] + scores = jnp.sum(logp * token_mask, axis=-1) + return scores + + +def _bon_select(state, *, n, eos_token): + """Pick the sampled sequence with the highest likelihood for each example.""" + (_, tokens, logp) = state + + # Filter state to only keep the best of each example. 
+ scores = _compute_score(tokens, logp, eos_token) + scores = einops.rearrange(scores, "(b n) -> b n", n=n) + state = jax.tree.map( + lambda x: einops.rearrange(x, "(b n) l -> b n l", n=n), state) + best_indices = jnp.argmax(scores, -1) # [b] + state = jax.tree.map( + lambda x: jnp.take_along_axis(x, best_indices[:, None, None], axis=1), + state) + state = jax.tree.map(lambda x: x[:, 0], state) + + return state + + +def _decode_sample_output(state, logits, *, max_decode_len, sampler): + if state is None: + # Decode state keeps track of sampled tokens and their logp. + bs = logits.shape[0] + seqlen = jnp.zeros((bs, 1), dtype=jnp.int32) + tokens = jnp.zeros((bs, max_decode_len), dtype=jnp.int32) + logp = jnp.zeros((bs, max_decode_len), dtype=logits.dtype) + else: + (seqlen, tokens, logp) = state + + # Sample tokens. + sampled_tokens, sampled_logp = _sample_logits(logits, sampler=sampler) + + # Update state with sampled outputs. + new_len = seqlen + 1 + new_tokens = _put_along_last_axis(tokens, seqlen, sampled_tokens) + new_logp = _put_along_last_axis(logp, seqlen, sampled_logp) + new_state = (new_len, new_tokens, new_logp) + + return sampled_tokens, new_state + + +def _decode_early_stop(state, mask, *, eos_token): + (seqlen, tokens, unused_logp) = state + token_mask = jnp.arange(tokens.shape[-1])[None, :] < seqlen + has_eos = jnp.any(jnp.logical_and(tokens == eos_token, token_mask), axis=-1) + done = jnp.logical_or(has_eos, jnp.logical_not(mask)) + return jnp.all(done) + + +def _put_along_last_axis(arr, indices, values): + """Like np.put_along_axis(..., axis=-1), since jax is missing it.""" + assert arr.ndim == indices.ndim == values.ndim, ( + arr.ndim, indices.ndim, values.ndim) + onehot = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype) + put_mask = jnp.einsum("...i,...in->...n", + jnp.ones(values.shape, jnp.int32), onehot) + put_values = jnp.einsum("...i,...in->...n", values, onehot) + return jnp.where(put_mask, put_values, arr) + + +def 
def _extend_cache(params, cache, tokens, *, model):
  """Run one decode step: embed `tokens` and extend the model cache.

  Args:
    params: Model parameters.
    cache: Current value of the model's `"cache"` collection.
    tokens: Tokens to feed, one step per sequence.
    model: Model exposing `embed_text` and `extend_cache` methods.

  Returns:
    A `(last_logits, cache)` pair with the logits returned by
    `model.extend_cache` and the updated `"cache"` collection.
  """
  model_vars = {"params": params, "cache": cache}

  # Token embeddings; the second output of `embed_text` is not needed here.
  embedded, _ = model.apply(model_vars, tokens, method=model.embed_text)

  # Only the "cache" collection is allowed to change during this call.
  last_logits, updated = model.apply(
      model_vars, embedded, method=model.extend_cache, mutable=("cache",))

  return last_logits, updated["cache"]
@registry.Registry.register("paligemma_sampler.nucleus")
def _nucleus_sampling(p: float, t: float = 1.0, *, logits, rng):
  """Sample token ids with nucleus (top-p) sampling.

  Logits are divided by the temperature `t`, then every logit that is smaller
  than the cutoff logit is replaced by a large negative value before sampling.
  The cutoff is the sorted logit at the first position where the cumulative
  probability is no longer below `p`.

  Args:
    p: Cumulative probability threshold of the nucleus.
    t: Temperature the logits are divided by.
    logits: Unnormalized scores; sampling happens over the last axis.
    rng: PRNG key passed to `jax.random.categorical`.

  Returns:
    The sampled indices along the last axis of `logits`.
  """
  logits = logits / t
  neg_inf = np.array(-1.0e7)  # Effective negative infinity.
  logits_sorted = jnp.sort(logits, axis=-1, descending=True)
  sorted_cum_probs = jnp.cumsum(
      jax.nn.softmax(logits_sorted, axis=-1), axis=-1)
  cutoff_index = jnp.sum(sorted_cum_probs < p, axis=-1, keepdims=True)
  # If no cumulative probability reaches `p` (e.g. p >= 1, or rounding in the
  # cumsum), the count equals the vocabulary size and would index out of
  # bounds. Clamp to the last entry so the cutoff is the smallest logit and
  # nothing gets filtered, instead of relying on out-of-bounds gather rules.
  cutoff_index = jnp.minimum(cutoff_index, logits.shape[-1] - 1)
  cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1)
  logits = jnp.where(logits < cutoff_logit,
                     jnp.full_like(logits, neg_inf), logits)
  return jax.random.categorical(rng, logits)
def _beam_early_stop(state, mask, eos_token):
  """Return True when no live beam can still beat the finished sequences.

  Args:
    state: Beam state `(best_tokens, best_logp, seqlen, tokens, logp)` as
      produced by `_beam_sample_output`.
    mask: Per-example flags; False marks examples that only pad the batch and
      are therefore ignored.
    eos_token: Token id that marks the end of a sequence.

  Returns:
    A scalar boolean: True if, for every real example, the best live
    candidate scores below that example's best finished sequence.
  """
  (best_tokens, best_logp, seqlen, unused_tokens, logp) = state

  # Scores of finalized sequences. `_compute_score` only reduces the token
  # axis, so the kept-candidate axis is reduced here to get one score per
  # example.
  best_scores = _compute_score(best_tokens, best_logp, eos_token)
  best_scores = jnp.max(best_scores, axis=1)

  # Scores of live sequences: sum logp over the tokens written so far and
  # take the best candidate of each example.
  live_mask = jnp.arange(logp.shape[-1])[None, None] < seqlen
  live_scores = jnp.sum(logp * live_mask, axis=-1)
  live_scores = jnp.max(live_scores, axis=1)

  # Both sides are per-example vectors now. Previously `best_scores` kept a
  # trailing axis of size 1, so the comparison broadcast to [batch, batch]
  # and compared live scores of one example against the best scores of all
  # others (including padding), which delayed or prevented early stopping.
  done = live_scores < best_scores
  return jnp.all(jnp.logical_or(done, jnp.logical_not(mask)))
+ best_tokens = jnp.zeros((bs, 0, max_decode_len), dtype=jnp.int32) + best_logp = jnp.zeros((bs, 0, max_decode_len), dtype=logits.dtype) + # B) N candidate sequences for each example. At initialization these have + # beam_size=1, but end up with correct beam_size when expanded. + seqlen = jnp.zeros((bs, 1, 1), dtype=jnp.int32) + tokens = jnp.zeros((bs, 1, max_decode_len), dtype=jnp.int32) + logp = jnp.zeros((bs, 1, max_decode_len), dtype=logits.dtype) + else: + (best_tokens, best_logp, seqlen, tokens, logp) = state + bs = logits.shape[0] // beam_size + assert best_tokens.shape[0] == bs + + # Reshape cache to [example, candidate, ...]. + # Note: on first call the number of candidates is 1. Later it is beam_size. + cache, logits = jax.tree.map( + lambda x: einops.rearrange(x, "(b n) ... -> b n ...", b=bs), + (cache, logits)) + + # Consider a live sequence could end now and update the best finished + # sequences so far for each example. This strategy is found in some beam + # implementations such as in praxis. + # The code below also adjusts the best shape[1]=0 -> 1 during first call. + eos_tokens = jnp.array(eos_token)[None, None, None] + new_tokens = _put_along_last_axis(tokens, seqlen, eos_tokens) + new_logp = _put_along_last_axis(logp, seqlen, logits[:, :, eos_token, None]) + + best_tokens = jnp.concatenate([best_tokens, new_tokens], axis=1) + best_logp = jnp.concatenate([best_logp, new_logp], axis=1) + best_scores = _compute_score(best_tokens, best_logp, eos_token=eos_token) + _, top_indices = jax.lax.top_k(best_scores, k=1) + + best_tokens = jnp.take_along_axis(best_tokens, top_indices[..., None], axis=1) + best_logp = jnp.take_along_axis(best_logp, top_indices[..., None], axis=1) + + # To find the next best N live candidates we expand each candidate and keep + # the best N (ignoring EOS tokens). In this case we expand into (N+1) + # candidates and set their likelihood to "-inf" (if EOS) after the fact. 
+ live_mask = jnp.arange(logp.shape[-1])[None, None] < seqlen + live_scores = jnp.sum(logp * live_mask, axis=-1) + topk_logits, topk_tokens = jax.lax.top_k(logits, beam_size+1) + scores = live_scores[..., None] + topk_logits + scores = jnp.where( + topk_tokens != eos_token, scores, jnp.finfo(scores.dtype).min) + + # From the N*(N+1) candidates find the top N for each example. + topk_logits, topk_tokens, scores = jax.tree.map( + lambda x: einops.rearrange(x, "b n np1 -> b (n np1)"), + (topk_logits, topk_tokens, scores)) + _, topk_indices = jax.lax.top_k(scores, k=beam_size) + sampled_indices = topk_indices // (beam_size+1) + sampled_tokens = jnp.take_along_axis( + topk_tokens, topk_indices, axis=-1)[..., None] + sampled_logits = jnp.take_along_axis( + topk_logits, topk_indices, axis=-1)[..., None] + + # Adjust cache and state so it matches the selected top N input candidates. + # This also adjusts the beam_size=1->n during first call. + def take_candidates(x): + one_hot_matrix = jax.nn.one_hot(sampled_indices, x.shape[1], dtype=x.dtype) + return jnp.einsum("bi...,boi->bo...", x, one_hot_matrix) + cache, seqlen, tokens, logp = jax.tree.map( + take_candidates, (cache, seqlen, tokens, logp)) + + # Write the sampled tokens/logits on the reshuffled state. + tokens = _put_along_last_axis(tokens, seqlen, sampled_tokens) + logp = _put_along_last_axis(logp, seqlen, sampled_logits) + seqlen = seqlen + 1 + + state = (best_tokens, best_logp, seqlen, tokens, logp) + + # Reshape to [(example, candidate), ...]. + sampled_tokens, cache = jax.tree.map( + lambda x: einops.rearrange(x, "b n ... 
-> (b n) ..."), + (sampled_tokens, cache)) + + return sampled_tokens, state, cache diff --git a/big_vision/trainers/proj/paligemma/run.py b/big_vision/trainers/proj/paligemma/run.py new file mode 100644 index 0000000000000000000000000000000000000000..b4b1f5d18507aab37d3c15293d9e3f15faf1d28a --- /dev/null +++ b/big_vision/trainers/proj/paligemma/run.py @@ -0,0 +1,141 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Load and run the PaliGemma model.""" +import functools +import sys + +from absl import app +from absl import flags +from absl import logging + +# pylint: disable=all +import jax +from jax.sharding import Mesh, NamedSharding, PartitionSpec +import ml_collections +import numpy as np + +import big_vision.models.proj.paligemma.gemma_bv +import big_vision.models.proj.paligemma.paligemma as model_mod +import big_vision.models.vit +import big_vision.pp.builder +import big_vision.pp.tokenizer +import big_vision.pp.ops_image +import big_vision.pp.ops_general +import big_vision.pp.ops_text +import big_vision.pp.proj.paligemma.ops +import big_vision.sharding +import big_vision.trainers.proj.paligemma.predict_fns +import big_vision.utils as u +# pylint: enable=all + +# We always want to be explicit about any host-device transfers. 
def main(argv):
  """Load the model, then decode one answer per prompt line read from stdin.

  The checkpoint and image paths come from the `--ckpt` and `--image` flags.
  The inference function is compiled once with the prompt "caption en"; after
  that, every stripped line from stdin is used as a prompt for the same image
  and the decoded text is printed to stderr.

  Args:
    argv: Command-line arguments as passed by `app.run`; only logged.
  """
  info(f"{argv=}")
  info("Loading model...")
  model, params = load_model(CKPT.value)

  predict_fns = big_vision.trainers.proj.paligemma.predict_fns.get_all(model)

  info("Loading tokenizer...")
  tokzr = big_vision.pp.tokenizer.get_tokenizer(TOKENIZER)

  info("Creating mesh and sharding params...")
  mesh = Mesh(jax.devices(), ("data",))
  repl_sharding = NamedSharding(mesh, PartitionSpec())
  params_sharding = big_vision.sharding.infer_sharding(
      params, strategy=[(".*", "fsdp(axis='data')")], mesh=mesh)

  # Ship the params to device(s)
  params = jax.tree.map(lambda x, sh: u.reshard(x, sh), params, params_sharding)

  # Mostly go through pp ops to build our batch:
  pp_fn = big_vision.pp.builder.get_preprocess_fn("|".join([
      f"decode|resize({RES.value})|value_range(-1, 1)",
      f"tok(key='prefix', bos='yes', model={repr(TOKENIZER)})",
      f"tok(key='septok', text='\\n', model={repr(TOKENIZER)})",
      'masked_concat(["prefix", "septok"], mask_ar=[0, 0], mask_input=[1, 1])',
      f'tolen({PREFILL_LEN.value}, pad_value=0, key="text")',
      f'tolen({PREFILL_LEN.value}, pad_value=1, key="mask_ar")',
      f'tolen({PREFILL_LEN.value}, pad_value=0, key="mask_input")',
      'keep("image", "text", "mask_ar", "mask_input")',
  ]), log_data=False)

  decode = functools.partial(
      predict_fns["decode"], devices=jax.devices(),
      eos_token=tokzr.eos_token, max_decode_len=MAX_DECODE_LEN.value,
      sampler=SAMPLER.value)

  def make_batch(fname, prompt):
    """Build a replicated batch of one example from an image file and prompt."""
    # Read the raw bytes and close the file right away instead of leaving the
    # handle to the garbage collector.
    with open(fname, "rb") as f:
      image = f.read()

    # Create an example
    example = pp_fn({"image": image, "prefix": np.array(prompt)})
    example["_mask"] = np.array(True)  # True means valid non-pad example

    # Add the leading batch axis to every entry.
    batch = jax.tree.map(lambda x: x[None], example)
    return u.reshard(batch, repl_sharding)  # Move to device(s)

  info("Precompiling inference function...")
  decode({"params": params}, batch=make_batch(IMAGE.value, "caption en"))

  info("Type a prompt and press enter, for example 'caption en': ")
  for line in map(str.strip, sys.stdin):
    tokens = decode({"params": params}, batch=make_batch(IMAGE.value, line))
    tokens = jax.device_get(tokens)[0]  # First batch entry.

    # TODO: b/lbeyer - flip around: output on stdout, logs on stderr.
    print(tokzr.to_str(tokens), file=sys.stderr, flush=True)
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training loop for PaliGemma-style VLM.""" +# pylint: disable=consider-using-from-import +# pylint: disable=logging-fstring-interpolation + +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.datasets.core as ds_core +import big_vision.evaluators.common as eval_common +import big_vision.input_pipeline as input_pipeline +import big_vision.optax as bv_optax +import big_vision.sharding as bv_sharding +import big_vision.trainers.proj.paligemma.predict_fns as predict_fns +import big_vision.utils as u +from clu import parameter_overview +import flax +import flax.linen as nn +import jax +from jax.experimental import mesh_utils +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import serialization as array_serial +import jax.numpy as jnp +import ml_collections as mlc +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. 
+jax.config.parse_flags_with_absl() +# Transfer guard will fail the program whenever that data between a host and +# a device is transferred implicitly. This often catches subtle bugs that +# cause slowdowns and memory fragmentation. Explicit transfers are done +# with jax.device_put and jax.device_get. +jax.config.update("jax_transfer_guard", "disallow") + + +NamedSharding = jax.sharding.NamedSharding +P = jax.sharding.PartitionSpec + + +def main(argv): + del argv + + # This is needed on multihost systems, but crashes on non-TPU single-host. + if os.environ.get("BV_JAX_INIT"): + jax.distributed.initialize() + + # Make sure TF does not touch GPUs. + tf.config.set_visible_devices([], "GPU") + +################################################################################ +# # +# Set up logging # +# # +################################################################################ + + # Set up work directory and print welcome message. + config = flags.FLAGS.config + workdir = flags.FLAGS.workdir + logging.info( + f"\u001b[33mHello from process {jax.process_index()} holding " + f"{jax.local_device_count()}/{jax.device_count()} devices and " + f"writing to workdir {workdir}.\u001b[0m") + + save_ckpt_path = None + if workdir: # Always create if requested, even if we may not write into it. + gfile.makedirs(workdir) + save_ckpt_path = os.path.join(workdir, "checkpoint.bv") + + # The pool is used to perform misc operations such as logging in async way. + pool = multiprocessing.pool.ThreadPool() + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]): + importlib.import_module(f"big_vision.pp.{m}") + + # Setup up logging and experiment manager. 
+ xid, wid = -1, -1 + fillin = lambda s: s + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + + # Allow for things like timings as early as possible! + u.chrono.inform(measure=mw.measure, write_note=write_note) + +################################################################################ +# # +# Set up Mesh # +# # +################################################################################ + + # We rely on jax mesh_utils to organize devices, such that communication + # speed is the fastest for the last dimension, second fastest for the + # penultimate dimension, etc. + config_mesh = config.get("mesh", [("data", jax.device_count())]) + + # Sharding rules with the default of doing full data sharding. + sharding_rules = config.get("sharding_rules", [("act_batch", "data")]) + + mesh_axes, mesh_size = tuple(zip(*config_mesh)) + + # Because jax.utils do not support `-1` shape size. + mesh_size = np.array(jax.devices()).reshape(mesh_size).shape + + device_mesh = mesh_utils.create_device_mesh( + mesh_size, allow_split_physical_axes=config.get( + "mesh_allow_split_physical_axes", False)) + + # Consistent device order is important to ensure correctness of various train + # loop components, such as input pipeline, update step, evaluators. The + # order prescribed by the `devices_flat` variable should be used throughout + # the program. 
+ devices_flat = device_mesh.flatten() + +################################################################################ +# # +# Input Pipeline # +# # +################################################################################ + + write_note("Initializing train dataset...") + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + train_ds, ntrain_img = input_pipeline.training(config.input) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + # Start input pipeline as early as possible, this will kick-start filling + # shuffle buffers and get the first batch in a background thread. + n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_global( + train_ds, devices_flat, n_prefetch, warmup=n_prefetch > 0) + + # For mixed data, add per-dataset epoch and examples seen measurements. 
+ if isinstance(config.input.data.get("name"), str): + measure_per_dataset_times = lambda step: None # No-op + else: + nexamples = { + name: ds_core.get(**config.input[name].data).total_examples + for name in config.input.data + } + def measure_per_dataset_times(step): + total = sum(config.input.data.values()) + for name, w in config.input.data.items(): + w = w / total + mw.measure(f"examples_seen_{name}", u.chrono.accum_examples_seen * w) + mw.measure(f"epoch_{name}", step * batch_size * w / nexamples[name]) + +################################################################################ +# # +# Create Model & Optimizer # +# # +################################################################################ + + write_note(f"Initializing {config.model_name} model...") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model = model_mod.Model(**mlc.FrozenConfigDict(config.get("model", {}))) + + def init(rng, partial_params=None): + batch = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype), + train_ds.element_spec) + _, variables = model.apply( # flax init is just apply with mutable. + {"params": partial_params or {}}, + batch["image"], batch["text"][:, :-1], batch["mask_ar"][:, :-1], + rngs={"params": rng, "dropout": rng}, + mutable=["params"]) + return flax.core.unfreeze(variables["params"]) + + # This seed makes the Jax part of things (like model init) deterministic. + # However, full training still won't be deterministic, for example due to the + # tf.data pipeline not being deterministic even if we would set TF seed. + # See (internal link) for a fun read on what it takes. 
+ rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0))) + + write_note("Inferring parameter shapes...") + rng, rng_init = jax.random.split(rng) + params_shape = jax.eval_shape(init, rng_init) + params_shape = nn.unbox(params_shape) + + write_note("Inferring optimizer state shapes...") + tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + opt_shape = jax.eval_shape(tx.init, params_shape) + # We jit this, such that the arrays are created on the CPU, not device[0]. + sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns] + + if jax.process_index() == 0: + num_params = sum(np.prod(p.shape) for p in jax.tree.leaves(params_shape)) + mw.measure("num_params", num_params) + +################################################################################ +# # +# Init and/or load model onto devices # +# # +################################################################################ + + write_note("Creating device mesh...") + mesh = jax.sharding.Mesh(device_mesh, mesh_axes) + repl_sharding = jax.sharding.NamedSharding(mesh, P()) + + write_note("Inferring shardings...") + train_state_shape = {"params": params_shape, "opt": opt_shape} + + strategy = config.get("sharding_strategy", [(".*", "replicate")]) + train_state_sharding = bv_sharding.infer_sharding( + train_state_shape, strategy=strategy, mesh=mesh) + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize model from scratch or from something, e.g. fine-tuning job. 
+ resume_ckpt_path = None + if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + + if resume_ckpt_path: + write_note(f"Resuming training from checkpoint {resume_ckpt_path}...") + shardings = { + **train_state_sharding, + "chrono": jax.tree.map(lambda _: repl_sharding, u.chrono.save()), + } + loaded = u.load_checkpoint_ts( + resume_ckpt_path, tree=shardings, shardings=shardings) + train_state = {key: loaded[key] for key in train_state_sharding.keys()} + u.chrono.load(jax.device_get(loaded["chrono"])) + del loaded + else: + write_note( + f"Initialize model from {config.get('model_init') or 'scratch'}...") + + # To avoid holding two copies of parameters we first call `model.load` + # and then initialize the missing variables. + if config.get("model_init"): + # We call `model.load` with params shape, so it can know all model params + # including their shapes and dtypes (also shardings once wired). + params = model_mod.load( + params_shape, config.model_init, config.get("model"), + **config.get("model_load", {})) + + # Keep only params loaded by `model.load` and shard them into devices. + mask = jax.tree.map( + lambda x: not isinstance(x, jax.ShapeDtypeStruct), params) + params = u.reshard(u.tree_filter(params, mask), + u.tree_filter(train_state_sharding["params"], mask)) + + parameter_overview.log_parameter_overview( + params, msg="Restored params", + include_stats="global", jax_logging_process=0) + else: + params = {} + + # Init will initialize any missing params. + rng_init = u.reshard(rng_init, repl_sharding) + params = jax.jit( + init, donate_argnums=1, out_shardings=train_state_sharding["params"])( + rng_init, params) + params = nn.unbox(params) + + # Initialize optimizer and construct train_state. 
+ opt = jax.jit(tx.init, out_shardings=train_state_sharding["opt"])(params) + train_state = {"params": params, "opt": opt} + del params, opt # Delete to avoid memory leak or accidental reuse. + + parameter_overview.log_parameter_overview( + train_state["params"], msg="Parameter overview", + include_stats="global", jax_logging_process=0) + + rng, rng_loop = jax.random.split(rng, 2) + rng_loop = u.reshard(rng_loop, repl_sharding) + del rng, rng_init # not used anymore, so delete it. + +################################################################################ +# # +# Update Step # +# # +################################################################################ + + @functools.partial( + jax.jit, + donate_argnums=(0,), + out_shardings=(train_state_sharding, repl_sharding)) + def update_fn(train_state, rng, batch): + """Update step.""" + + step_count = bv_optax.get_count(train_state["opt"], jittable=True) + rng = jax.random.fold_in(rng, step_count) + assert "mixup" not in config, "Mixup is not supported for SigLIP." + + # Get device-specific loss rng. + _, rng_model = jax.random.split(rng, 2) + + imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"] + + def loss_fn(params): + text_logits, _ = model.apply( + {"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], + train=True, rngs={"dropout": rng_model}) + + logp = jax.nn.log_softmax(text_logits, axis=-1) + targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1]) + off_value = config.get("label_smoothing", 0.0) + if off_value > 0: + denom = text_logits.shape[-1] - 1 + targets = jnp.where( + targets == 1.0, 1.0 - off_value, off_value / denom) + + # Sum across vocab. + token_pplx = jnp.sum(logp * targets, axis=-1) + + # Shift by one since the loss is on the _next_ token. 
+ mask_loss = batch["mask_loss"][:, 1:] + token_pplx = token_pplx * mask_loss + pplx = -jnp.sum(token_pplx, axis=-1) + pplx /= jnp.clip(jnp.sum(mask_loss, axis=-1), 1) + + # In this dict the (outer) reduction is along batch. + measurements = dict( + training_loss=jnp.mean(pplx), + avg_sup_seqlen=jnp.mean(jnp.sum(mask_loss, axis=-1)), + max_sup_seqlen=jnp.max(jnp.sum(mask_loss, axis=-1)), + ) + + return measurements["training_loss"], measurements + + params, opt = train_state["params"], train_state["opt"] + (_, measurements), grads = jax.value_and_grad(loss_fn, has_aux=True)(params) + updates, opt = tx.update(grads, opt, params) + params = optax.apply_updates(params, updates) + + gs = jax.tree.leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.sum(g * g) for g in gs])) + ps = jax.tree.leaves(params) + measurements["l2_params"] = jnp.sqrt(sum([jnp.sum(p * p) for p in ps])) + us = jax.tree.leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.sum(u * u) for u in us])) + + return {"params": params, "opt": opt}, measurements + +################################################################################ +# # +# Setup Evals # +# # +################################################################################ + + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, + predict_fns.get_all(model), + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + devices_flat, + ) + + # At this point we need to know the current step to see whether to run evals. 
+ write_note("Inferring the first step number...") + first_step_device = bv_optax.get_count(train_state["opt"], jittable=True) + first_step = int(jax.device_get(first_step_device)) + u.chrono.inform(first_step=first_step) + + # Note that training can be pre-empted during the final evaluation (i.e. + # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. + if first_step in (total_steps, 0): + write_note("Running initial or final evals...") + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules(sharding_rules): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", value) + +################################################################################ +# # +# Train Loop # +# # +################################################################################ + + prof = None # Keeps track of start/stop of profiler state. + ckpt_mngr = None + + write_note("Starting training loop, compiling the first step...") + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + with mesh, nn.logical_axis_rules(sharding_rules): + train_state, measurements = update_fn(train_state, rng_loop, batch) + + # On the first host, let's always profile a handful of early steps. 
+ if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", + sched_fn_cpu(u.put_cpu(step - 1))) + measurements = jax.device_get(measurements) + for name, value in measurements.items(): + mw.measure(name, value) + u.chrono.tick(step) + measure_per_dataset_times(step) + + for k in ("training_loss", "l2_params", "l2_grads"): + if not np.isfinite(measurements.get(k, 0.0)): + raise RuntimeError(f"{k} became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + keep_last = total_steps if get_steps("ckpt", None) else None + keep_ckpt_steps = get_steps("keep_ckpt", None) or keep_last + if save_ckpt_path and ( + (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False)) + or u.itstime(step, get_steps("ckpt", None), total_steps, first=True) + ): + u.chrono.pause(wait_for=train_state) + + # Copy because we add extra stuff to the checkpoint. + ckpt = {**train_state} + + # To save chrono state correctly and safely in a multihost setup, we + # broadcast the state to all hosts and convert it to a global array. 
+ with jax.transfer_guard("allow"): + chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save()) + chrono_shardings = jax.tree.map(lambda _: repl_sharding, chrono_ckpt) + ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)} + + ckpt_mngr = ckpt_mngr or array_serial.GlobalAsyncCheckpointManager() + u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=train_state) + u.chrono.tick(step) # Record things like epoch number, core hours etc. + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules(sharding_rules): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", jax.device_get(value)) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + if ckpt_mngr: + ckpt_mngr.wait_until_finished() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/trainers/proj/uvim/coco_utils.py b/big_vision/trainers/proj/uvim/coco_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..70183030ac9d441e0e9b49007be7775b16cab233 --- /dev/null +++ b/big_vision/trainers/proj/uvim/coco_utils.py @@ -0,0 +1,75 @@ +# Copyright 2022 Big Vision Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities to inspect coco data and predictions in notebooks.""" +# pylint: disable=consider-using-from-import +import functools +import json + +import numpy as np +from panopticapi import utils as pycoco_utils +from skimage import segmentation + +import tensorflow.io.gfile as gfile + + +import os +ROOT = os.environ.get('COCO_DATA_DIR', '.') + + +PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json' + + +@functools.lru_cache(maxsize=None) +def _coco_panoptic_categories(): + with gfile.GFile(PANOPTIC_COCO_CATS_FILE, 'r') as f: + categories_list = json.load(f) + return tuple(categories_list) + + +def rgb_panoptic_from_twochannels(twochannels, boundaries: bool = False): + """Makes a RGB panoptic output and segments_info from a twochannels view.""" + semantics = twochannels[..., 0] + instances = twochannels[..., 1] + max_instances = np.max(instances) + 1 + merged = semantics * max_instances + instances + merged = np.where(semantics < 0, semantics, merged) + + categories_list = _coco_panoptic_categories() + categories = {category['id']: category for category in categories_list} + id_generator = pycoco_utils.IdGenerator(categories) + segments_info = {} + rgb = np.zeros((*instances.shape[:2], 3), dtype=np.uint8) + + for merged_id in np.unique(merged): + if merged_id // max_instances > 0: + category = categories_list[int(merged_id // max_instances) - 1] + segment_id, color = 
id_generator.get_id_and_color(category['id']) + else: + category = {'id': -1, 'name': 'void', 'isthing': False} + segment_id, color = -1, np.array([0, 0, 0]) + segments_info[segment_id] = { + 'id': segment_id, + 'color': color, + 'category_id': category['id'], + 'name': category['name'], + 'isthing': category['isthing'], + } + rgb[merged == merged_id] = color + + if boundaries: + boundaries = segmentation.find_boundaries( + pycoco_utils.rgb2id(rgb), mode='thick') + rgb[boundaries] = 0 + return rgb, segments_info diff --git a/big_vision/trainers/proj/uvim/colorization_task.py b/big_vision/trainers/proj/uvim/colorization_task.py new file mode 100644 index 0000000000000000000000000000000000000000..1624e1b2d3465f65f69d0634a3268d682a1c1ccc --- /dev/null +++ b/big_vision/trainers/proj/uvim/colorization_task.py @@ -0,0 +1,62 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inputs, outputs and losses for colorization task.""" +import einops +import jax.numpy as jnp +import numpy as np + +ONE_HOT_AXIS = -2 + + +def input_pp(batch, config): + """Make inputs for colorization task.""" + if "labels" not in batch: + # During predict of phase2 there is no 'labels' field. 
+ x = None + else: + hp, wp = config.model.patch_size + x = { + "color": batch["labels"], + } + # Convert labels from (B, H, W) to (B, num_patches, C, patch_size) + x["color"] = einops.rearrange( + x["color"], "b (hn hp) (wn wp) c -> b (hn wn) c (hp wp)", hp=hp, wp=wp) + ctx = batch.get("image_ctx", batch.get("image", None)) + return {"ctx": ctx, "x": x} + + +def loss_fn(logits, batch, config): + """Compute loss for colorization task.""" + labels = input_pp(batch, config)["x"] + error = logits["color"] - labels["color"] + loss = jnp.square(error) + return loss, {"loss_color": loss} + + +def predict_outputs(logits, config): + """Make outputs for colorization task.""" + # Map logits to (height, width, channels). + hp, wp = config.model.patch_size + hn, wn = np.array(config.model.input_size) // np.array((hp, wp)) + assert ONE_HOT_AXIS == -2, "Rearrange below depends on this." + output = einops.rearrange( + logits["color"], + "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c", + hn=hn, + wn=wn, + hp=hp, + wp=wp) + output = jnp.clip(output, -1., 1.) + return {"color": output} diff --git a/big_vision/trainers/proj/uvim/depth_task.py b/big_vision/trainers/proj/uvim/depth_task.py new file mode 100644 index 0000000000000000000000000000000000000000..e4768b769e586da2397e56d6312e61268e479830 --- /dev/null +++ b/big_vision/trainers/proj/uvim/depth_task.py @@ -0,0 +1,91 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Inputs, outputs and losses for depth prediction task.""" +import big_vision.utils as u +import einops +import jax +import jax.numpy as jnp +import numpy as np + + +ONE_HOT_AXIS = -2 + + +def input_pp(batch, config): + """Makes inputs for depth prediction task.""" + if "labels" not in batch: + x = None + else: + hp, wp = config.model.patch_size + depth = batch["labels"][..., 0] + + # Discretize to [0, ..., bins - 1]. + nbins = config.model.inputs.depth[ONE_HOT_AXIS] + mind = config.min_depth + maxd = config.max_depth + depth = (depth - mind) / (maxd - mind) + depth *= nbins + depth = jnp.floor(depth).astype(jnp.int32) + depth = jnp.minimum(depth, nbins - 1) + depth = jnp.maximum(depth, 0) + + # Converts labels from (B, H, W, c) to (B, num_patches, c, patch_size). + depth = jax.nn.one_hot( + einops.rearrange( + depth, "b (hn hp) (wn wp) -> b (hn wn) (hp wp)", hp=hp, wp=wp), + num_classes=config.model.inputs.depth[ONE_HOT_AXIS], + axis=ONE_HOT_AXIS) + x = {"depth": depth} + ctx = batch.get("image_ctx", batch.get("image", None)) + return {"ctx": ctx, "x": x} + + +def loss_fn(predictions, batch, config): + """Computes loss for depth prediction task.""" + labels = input_pp(batch, config)["x"] + losses = {} + loss = u.softmax_xent( + logits=predictions["depth"], labels=labels["depth"], reduction=False, + axis=ONE_HOT_AXIS) + # Do not train on the closest class; usually regions of the image with + # depth==0, which is the default for regions with no depth signal. + # TODO: Encode depth==0 as class==-1. + mask = jnp.argmax(labels["depth"], ONE_HOT_AXIS) != 0 + loss = loss * mask + losses["loss_depth"] = loss + return sum(losses.values()), losses + + +def predict_outputs(predictions, config): + """Makes outputs for depth predictin tasks.""" + # Maps predictions to (height, width, channels). 
+ hp, wp = config.model.patch_size + hn, wn = np.array(config.model.input_size) // np.array((hp, wp)) + depth = einops.rearrange( + predictions["depth"], + "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c", + hn=hn, wn=wn, hp=hp, wp=wp) + + depth = jnp.argmax(depth, axis=-1) # [B, H, W] + + # Revert discretization. + nbins = config.model.inputs.depth[ONE_HOT_AXIS] + mind = config.min_depth + maxd = config.max_depth + depth = depth.astype(jnp.float32) + 0.5 # Undoes floor in expectation. + depth /= nbins + depth = depth * (maxd - mind) + mind + + return {"depth": depth} diff --git a/big_vision/trainers/proj/uvim/panoptic_task.py b/big_vision/trainers/proj/uvim/panoptic_task.py new file mode 100644 index 0000000000000000000000000000000000000000..4049c496252105ff283e8f0120dbce06fb24ef2b --- /dev/null +++ b/big_vision/trainers/proj/uvim/panoptic_task.py @@ -0,0 +1,87 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inputs, outputs and losses for panoptic task.""" +import big_vision.utils as u +import einops +import jax +import jax.numpy as jnp +import numpy as np + +ONE_HOT_AXIS = -2 + + +def input_pp(batch, config): + """Make inputs for panoptic segmentation task.""" + if "labels" not in batch: + # During predict of phase2 there is no 'labels' field. 
+ x = None + else: + hp, wp = config.model.patch_size + x = { + "semantics": batch["labels"][..., 0], + "instances": batch["labels"][..., 1], + } + # Convert labels from (B, H, W) to (B, num_patches, num_classes, patch_size) + for key in ["semantics", "instances"]: + x[key] = jax.nn.one_hot( + einops.rearrange( + x[key], "b (hn hp) (wn wp) -> b (hn wn) (hp wp)", hp=hp, wp=wp), + num_classes=config.model.inputs[key][ONE_HOT_AXIS], axis=ONE_HOT_AXIS) + ctx = batch.get("image_ctx", batch.get("image", None)) + return {"ctx": ctx, "x": x} + + +def loss_fn(logits, batch, config): + """Compute loss for panoptic task.""" + labels = input_pp(batch, config)["x"] + losses = {} + for key in ["semantics", "instances"]: + losses[f"loss_{key}"] = u.softmax_xent( + logits=logits[key], labels=labels[key], reduction=False, + axis=ONE_HOT_AXIS) + return sum(losses.values()), losses + + +def predict_outputs(logits, config, min_fraction=0.0): + """Make outputs for panoptic segmentation task.""" + # Map logits to (height, width, channels). + hp, wp = config.model.patch_size + hn, wn = np.array(config.model.input_size) // np.array((hp, wp)) + outputs = {} + for key in ["semantics", "instances"]: + assert ONE_HOT_AXIS == -2, "Rearrange below depends on this." + outputs[key] = einops.rearrange( + logits[key], + "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c", + hn=hn, wn=wn, hp=hp, wp=wp) + return panoptic_predictions_from_logits( + **outputs, min_fraction=min_fraction) + + +def panoptic_predictions_from_logits(semantics, instances, min_fraction=0.0): + """Make panoptic prediction from logits.""" + ins = jnp.argmax(instances, axis=-1) + # Note: Make sure each instance has all pixels annotated with same label. + # Otherwise they are further split into more instances and greatly affect + # the number of unmatched predicted segments (FP) and RQ. 
+ masks = jax.nn.one_hot(ins, instances.shape[-1], dtype=jnp.int32) + label = jnp.argmax(jnp.einsum("bhwk,bhwn->bnk", semantics, masks), axis=-1) + sem = jnp.einsum("bhwn,bn->bhw", masks, label) + out = jnp.stack([sem, ins], axis=-1) + # Filter out small objects + fraction = jnp.sum(masks, axis=(1, 2), keepdims=True)/np.prod(ins.shape[1:3]) + mask_big = (fraction > min_fraction).astype("int32") + mask_big_spatial = jnp.sum(masks * mask_big, axis=-1, keepdims=True) > 0 + return out * mask_big_spatial.astype("int32") diff --git a/big_vision/trainers/proj/uvim/train.py b/big_vision/trainers/proj/uvim/train.py new file mode 100644 index 0000000000000000000000000000000000000000..bbaec203f748197ca7be2fd76a0c8c4aa58c804a --- /dev/null +++ b/big_vision/trainers/proj/uvim/train.py @@ -0,0 +1,440 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Train loop for training the stage-II model.""" +# pylint: disable=consider-using-from-import +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +from big_vision import input_pipeline +import big_vision.datasets.core as ds_core +import big_vision.evaluators.common as eval_common +import big_vision.models.proj.uvim.decode as decode +import big_vision.optax as bv_optax +import big_vision.pp.builder as pp_builder +import big_vision.utils as u +from clu import parameter_overview +import flax +import jax +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax + +import tensorflow.io.gfile as gfile + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() + + +FLAGS = flags.FLAGS +ONE_HOT_AXIS = -2 +partial = functools.partial + + +def get_model(config): + mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model = mod.Model(**config.model) + return model, mod + + +def setup_task(config): + """Get functions and params to encode and decode labels as token sequences.""" + config = config.oracle + + # Define task input and predict functions. 
+ task_module = importlib.import_module(f"big_vision.trainers.{config.task}") + input_fn = partial(task_module.input_pp, config=config) + predict_outputs_fn = partial(task_module.predict_outputs, config=config) + + oracle, mod = get_model(config) + if config.get("model_init", None): + params, state = mod.load(None, config.model_init) + params = {"params": params, "state": state} + else: + params = {} + + def encode_labels(params, batch): + inputs = input_fn(batch) + code = oracle.apply(params, **inputs, method=oracle.encode)[1]["code"] + return code + 1 # To avoid padding symbol. + + def decode_labels(params, code, batch, **kwargs): + code = code - 1 + inputs = input_fn(batch) + inputs["x"] = code + logits, _ = oracle.apply( + params, **inputs, discrete_input=True, **kwargs, method=oracle.decode) + return logits + + return encode_labels, decode_labels, predict_outputs_fn, params + + +def main(argv): + del argv + + config = FLAGS.config + workdir = FLAGS.workdir + logging.info("\u001b[33mHello from process %i holding %i/%i devices and " + "writing to workdir %s.\u001b[0m", jax.process_index(), + jax.local_device_count(), jax.device_count(), workdir) + + save_ckpt_path = None + if workdir: # Always create if requested, even if we may not write into it. + gfile.makedirs(workdir) + save_ckpt_path = os.path.join(workdir, "checkpoint.npz") + + # The pool is used to perform misc operations such as logging in async way. + pool = multiprocessing.pool.ThreadPool() + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", + ["ops_general", "ops_image", "proj.uvim.pp_ops"]): + importlib.import_module(f"big_vision.pp.{m}") + + # This seed makes the Jax part of things (like model init) deterministic. + # However, full training still won't be deterministic, for example due to the + # tf.data pipeline not being deterministic even if we would set TF seed. + # See (internal link) for a fun read on what it takes. 
+ rng = jax.random.PRNGKey(config.get("seed", 0)) + + # These functions do more stuff internally, for OSS release we mock them by + # trivial alternatives in order to minize disruptions in the code. + xid, wid = -1, -1 + fillin = lambda s: s + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + write_note("Initializing...") + + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + # First thing after above sanity checks, so we can log "start" ticks. + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + chrono = u.Chrono() + + write_note("Initializing train dataset...") + train_data = ds_core.get(**config.input.data) + train_ds = input_pipeline.make_for_train( + data=train_data.get_tfdata(ordered=False), + batch_size=batch_size, + preprocess_fn=pp_builder.get_preprocess_fn(config.input.get("pp")), + shuffle_buffer_size=config.input.get("shuffle_buffer_size"), + cache_raw=config.input.get("cache_raw", False), + filter_fn=config.input.get("filter_fn"), + ) + + # Start prefetching already. + n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch) + ntrain_img = train_data.total_examples + + def get_steps(name, default=ValueError): # partial doesn't work well here. 
+ return u.steps(name, config, ntrain_img, batch_size, default) + total_steps = get_steps("total") + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + write_note(f"Initializing {config.model_name} model...") + model, model_mod = get_model(config) + + encode_labels, decode_labels, predict_outputs_fn, task_params = ( + setup_task(config)) + + # We want all parameters to be created in host RAM, not on any device, they'll + # be sent there later as needed, otherwise we already encountered two + # situations where we allocate them twice. + @partial(jax.jit, backend="cpu") + def init(rng): + batch = jax.tree_map( + lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype), + train_ds.element_spec) + images = batch["image"] + labels = encode_labels(task_params, batch) + variables = model.init(rng, images, labels) + params = flax.core.unfreeze(variables["params"]) + return params + + rng, init_rng = jax.random.split(rng) + params_cpu = init(init_rng) + + if jax.process_index() == 0: + num_params = sum(p.size for p in jax.tree_leaves(params_cpu)) + parameter_overview.log_parameter_overview(params_cpu, msg="init params") + mw.measure("num_params", num_params) + + write_note(f"Initializing {config.optax_name} optimizer...") + tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + + # We jit this, such that the arrays are created on the CPU, not device[0]. 
+ opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu) + sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns] + + @partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1)) + def update_fn(params, opt, batch, update_rng, task_params): + """Update step.""" + images = batch["image"] + labels = encode_labels(task_params, batch) + + measurements = {} + + rng, new_rng = jax.random.split(update_rng) + # bind the rng key to the device id (which is unique across hosts) + rng_local = jax.random.fold_in(rng, jax.lax.axis_index("batch")) + + def loss_fn(params, images, labels): + logits = model.apply({"params": params}, images, labels, train=True, + rngs={"dropout": rng_local}) + loss = u.weighted_softmax_xent( + logits=logits, labels=labels, + reduction=True, normalize=True) + return loss + + l, grads = jax.value_and_grad(loss_fn)(params, images, labels) + l, grads = jax.lax.pmean((l, grads), axis_name="batch") + updates, opt = tx.update(grads, opt, params) + params = optax.apply_updates(params, updates) + + gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs])) + ps = jax.tree_leaves(params) + measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps])) + us = jax.tree_leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us])) + + return params, opt, l, new_rng, measurements + + # Define evaluators. 
+ def validation_fn(params, batch): + """Compute per-example metrics.""" + params, task_params = params["params"], params["task_params"] + images = batch["image"] + labels = encode_labels(task_params, batch) + logits = model.apply({"params": params}, images, labels, train=False) + loss = u.weighted_softmax_xent( + logits=logits, labels=labels, + reduction=False, normalize=True) + losses = {"loss": loss} + return jax.tree_map( + lambda x: jnp.mean(x, axis=tuple(range(1, x.ndim))), + losses) + + def predict_fn(params, batch, seed=0, temperature=1e-7, **extra): + params, task_params = params["params"], params["task_params"] + + # Derive a rng key from the inputs so that all batches use different keys. + if "image/id" in batch: + key = batch["image/id"] + else: + key = batch["image"].sum(axis=[1, 2, 3]).astype(jnp.int32) + local_rng = jax.lax.scan( + lambda k, x: (jax.random.fold_in(k, x), None), + jax.random.PRNGKey(seed), + key, + )[0] + + images = batch["image"] + batch_size = images.shape[0] + prompts = jnp.zeros((batch_size, config.model.seq_len), dtype=jnp.int32) + seqs, _, _ = decode.temperature_sampling( + params={"params": params}, model=model, seed=local_rng, + inputs=images, prompts=prompts, + num_samples=1, eos_token=-1, prefill=False, + temperature=temperature) + seqs = jnp.squeeze(seqs, 1) + logits = decode_labels(task_params, seqs, batch) + return predict_outputs_fn(logits, **extra) + + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, {"predict": predict_fn, "validation": validation_fn}, + lambda s: write_note(f"Initializing evaluator: {s}...\n{chrono.note}") + ) + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. 
Initialize model from something, e,g, start a fine-tuning job. + # 4. Initialize part of the model from something, eg. only encoder or decoder. + # 5. Train from scratch. + resume_ckpt_path = None + if save_ckpt_path and gfile.exists(save_ckpt_path): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + if resume_ckpt_path: + write_note("Resume training from checkpoint...") + checkpoint = { + "params": params_cpu, + "opt": opt_cpu, + "chrono": chrono.save(), + } + checkpoint_tree = jax.tree_structure(checkpoint) + loaded = u.load_checkpoint(checkpoint_tree, resume_ckpt_path) + # bfloat16 type gets lost when data is saved to disk, so we recover it. + checkpoint = jax.tree_map(u.recover_dtype, loaded) + params_cpu, opt_cpu = checkpoint["params"], checkpoint["opt"] + chrono.load(checkpoint["chrono"]) + elif config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + params_cpu = model_mod.load( + params_cpu, config.model_init, config.model, + **config.get("model_load", {})) + if jax.process_index() == 0: + parameter_overview.log_parameter_overview( + params_cpu, msg="restored params") + + write_note("Kicking off misc stuff...") + first_step = bv_optax.get_count(opt_cpu) + chrono.inform(first_step, total_steps, batch_size, ntrain_img / batch_size) + prof = None # Keeps track of start/stop of profiler state. + + write_note(f"Replicating...\n{chrono.note}") + params_repl = flax.jax_utils.replicate(params_cpu) + opt_repl = flax.jax_utils.replicate(opt_cpu) + task_params = flax.jax_utils.replicate(task_params) + update_rngs = flax.jax_utils.replicate(rng) + + ckpt_writer = None + + write_note(f"First step compilations...\n{chrono.note}") + error = None # For exiting with an error after cleanup. Avoids indentation. + + # Using a python integer for step here, because opt.state.step is allocated + # on TPU during replication. 
+ for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + params_repl, opt_repl, loss_value, update_rngs, measurements = ( + update_fn( + params_repl, + opt_repl, + batch, + update_rng=update_rngs, + task_params=task_params)) + + # On the first host, let's always profile a handful of early steps. + if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) + l = mw.measure("training_loss", loss_value[0]) + for name, value in measurements.items(): + mw.measure(name, value[0]) + chrono.tick(step, mw.measure, write_note) + if not np.isfinite(l): + error = (f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + break + + # Checkpoint saving + if (save_ckpt_path and + (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or + u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))): + chrono.pause(wait_for=(params_repl, opt_repl)) + u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1)) + # We need to transfer the weights over now or else we risk keeping them + # alive while they'll be updated in a future step, creating hard to debug + # memory errors (see (internal link)). Also, takes device 0's params only. + opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl) + params_cpu = jax.tree_map(lambda x: np.array(x[0]), params_repl) + + # Check whether we want to keep a copy of the current checkpoint. 
+ copy_step = None + if u.itstime(step, get_steps("keep_ckpt", None), total_steps): + copy_step = step + + ckpt = {"params": params_cpu, "opt": opt_cpu, "chrono": chrono.save()} + ckpt_writer = pool.apply_async( + u.save_checkpoint, (ckpt, save_ckpt_path, copy_step)) + chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=log_steps < total_steps, + last=False): + chrono.pause(wait_for=(params_repl, task_params)) + write_note(f"{name} evaluation...\n{chrono.note}") + for key, value in evaluator.run( + {"params": params_repl, "task_params": task_params}): + mw.measure(f"{prefix}{key}", value) + chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Run final evalution, also used for eval only jobs (when total_steps == 0). + for (name, evaluator, _, prefix) in evaluators(): + write_note(f"{name} evaluation...\n{chrono.note}") + for key, value in evaluator.run( + {"params": params_repl, "task_params": task_params}): + mw.measure(f"{prefix}{key}", value) + + # Last note needs to happen before the pool's closed =) + if not error: + write_note(f"Done!\n{chrono.note}") + else: + write_note(f"Failed!\n{error}\n{chrono.note}") + + pool.close() + pool.join() + mw.close() + + # Make sure all hosts stay up until the end of main. + u.sync() + + # Before cleanup, as cleanup should only run for successful jobs. 
+ if error is not None: + raise RuntimeError(error) + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/trainers/proj/uvim/vqvae.py b/big_vision/trainers/proj/uvim/vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..13467d02ca3c95c4bd8c9a38b9970db24c0f4e3a --- /dev/null +++ b/big_vision/trainers/proj/uvim/vqvae.py @@ -0,0 +1,414 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def main(argv):
  """Runs the stage-I training loop described by `--config`.

  The function builds the input pipeline, initializes (or restores) the model,
  optimizer and mutable `state` collection, runs the pmapped update step for
  `total_steps` steps while periodically logging, checkpointing and evaluating,
  and finally raises if the loss became non-finite.

  Args:
    argv: unused positional command-line arguments passed by `app.run`.

  Raises:
    ValueError: if the global batch size is not divisible by the device count.
    RuntimeError: after cleanup, if the training loss became nan or inf.
  """
  del argv

  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  logging.info("Workdir: %s", workdir)

  logging.info("\u001b[33mHello from process %i holding %i/%i devices and "
               "writing to workdir %s.\u001b[0m", jax.process_index(),
               jax.local_device_count(), jax.device_count(), workdir)

  # Define task input, loss and predict functions.
  task_module = importlib.import_module(f"big_vision.trainers.{config.task}")
  input_pp_fn = partial(task_module.input_pp, config=config)
  task_loss_fn = partial(task_module.loss_fn, config=config)
  predict_outputs_fn = partial(task_module.predict_outputs, config=config)

  save_ckpt_path = None
  if workdir:  # Always create if requested, even if we may not write into it.
    gfile.makedirs(workdir)
    save_ckpt_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # Here we register preprocessing ops from modules listed on `pp_modules`.
  for m in config.get("pp_modules",
                      ["ops_general", "ops_image", "proj.uvim.pp_ops"]):
    importlib.import_module(f"big_vision.pp.{m}")

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we would set TF seed.
  # See (internal link) for a fun read on what it takes.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  # These functions do more stuff internally, for OSS release we mock them by
  # trivial alternatives in order to minimize disruptions in the code.
  xid, wid = -1, -1
  fillin = lambda s: s
  def info(s, *a):
    logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)
  def write_note(note):
    if jax.process_index() == 0:
      info("%s", note)

  write_note("Initializing...")

  batch_size = config.input.batch_size
  if batch_size % jax.device_count() != 0:
    raise ValueError(f"Batch size ({batch_size}) must "
                     f"be divisible by device number ({jax.device_count()})")
  info("Global batch size %d on %d hosts results in %d local batch size. With "
       "%d dev per host (%d dev total), that's a %d per-device batch size.",
       batch_size, jax.process_count(), batch_size // jax.process_count(),
       jax.local_device_count(), jax.device_count(),
       batch_size // jax.device_count())

  # First thing after above sanity checks, so we can log "start" ticks.
  mw = u.BigVisionMetricWriter(xid, wid, workdir, config)
  chrono = u.Chrono()

  write_note("Initializing train dataset...")
  train_data = ds_core.get(**config.input.data)
  train_ds = input_pipeline.make_for_train(
      data=train_data.get_tfdata(ordered=False),
      batch_size=batch_size,
      preprocess_fn=pp_builder.get_preprocess_fn(config.input.get("pp")),
      shuffle_buffer_size=config.input.get("shuffle_buffer_size"),
      cache_raw=config.input.get("cache_raw", False),
      filter_fn=config.input.get("filter_fn"),
  )

  # Start prefetching already.
  n_prefetch = config.get("prefetch_to_device", 1)
  train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch)
  ntrain_img = train_data.total_examples

  def get_steps(name, default=ValueError):  # partial doesn't work well here.
    return u.steps(name, config, ntrain_img, batch_size, default)
  total_steps = get_steps("total")

  info("Running for %d steps, that means %f epochs",
       total_steps, total_steps * batch_size / ntrain_img)

  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
  model = model_mod.Model(**config.model)

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @partial(jax.jit, backend="cpu")
  def init(rng):
    batch = jax.tree.map(
        lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype),
        train_ds.element_spec)
    init_res = flax.core.unfreeze(model.init(rng, **input_pp_fn(batch)))
    params, state = init_res["params"], init_res["state"]

    # Set bias in the heads to a low value, such that loss is small initially.
    for key in config.model.outputs:
      params[f"head_{key}"]["bias"] = jnp.full_like(
          params[f"head_{key}"]["bias"], config.get("init_head_bias", 0))

    return params, state

  rng, rng_init = jax.random.split(rng)

  rng_init_params, rng_init_state = jax.random.split(rng_init)
  params_cpu, state_cpu = init({"params": rng_init_params,
                                "state": rng_init_state})

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree.leaves(params_cpu))
    parameter_overview.log_parameter_overview(params_cpu, msg="init params")
    mw.measure("num_params", num_params)

  write_note(f"Initializing {config.optax_name} optimizer...")
  tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict(
      total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))

  # We jit this, such that the arrays are created on the CPU, not device[0].
  opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu)
  sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]

  @partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1, 2),
           static_broadcasted_argnums=(5,))
  def update_fn(params, opt, state, batch, rng, update_dict=True):
    """Update step."""
    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params, state, batch):
      (logits, out), mutated_col = model.apply(
          {"params": params, "state": state},
          **input_pp_fn(batch),
          train=True, update_dict=update_dict,
          rngs={"dropout": rng_model_local, "vqvae": rng_model},
          mutable=["state"])
      btlneck = out["bottleneck"]
      btlneck_q = out["bottleneck_q"]

      loss_rec, logs = jax.tree.map(jnp.mean, task_loss_fn(logits, batch))
      # Gradient flows into the bottleneck only; the quantized value is fixed.
      loss_commitment = jnp.mean(jnp.square(btlneck - SG(btlneck_q)))
      loss = loss_rec + config.get("w_commitment", 0.25) * loss_commitment
      aux = {
          "loss_rec": jax.lax.pmean(loss_rec, axis_name="batch"),
          "loss_commitment": jax.lax.pmean(loss_commitment, axis_name="batch"),
          "codebook_zeros_ratio": out["codebook_zeros_ratio"],
          "codebook_max_ratio": out["codebook_max_ratio"],
          "state": mutated_col["state"],
          **jax.tree.map(partial(jax.lax.pmean, axis_name="batch"), logs),
      }
      return loss, aux

    (l, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        params, state, batch)
    l, grads = jax.lax.pmean((l, grads), axis_name="batch")
    updates, opt = tx.update(grads, opt, params)
    params = optax.apply_updates(params, updates)
    state = aux.pop("state")
    measurements = {**measurements, **aux}

    gs = jax.tree.leaves(bv_optax.replace_frozen(config.schedule, grads, 0.))
    measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs]))
    ps = jax.tree.leaves(params)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps]))
    us = jax.tree.leaves(updates)
    # `upd` rather than `u`, to not shadow the `big_vision.utils` alias.
    measurements["l2_updates"] = jnp.sqrt(
        sum([jnp.vdot(upd, upd) for upd in us]))

    return params, opt, state, l, rng, measurements

  # Define evaluators.
  def validation_fn(params, batch):
    """Compute per-example metrics."""
    logits, out = model.apply(params, **input_pp_fn(batch))
    _, losses = task_loss_fn(logits, batch)
    btlneck = out["bottleneck"]
    btlneck_q = out["bottleneck_q"]
    losses["loss_commitment"] = jnp.square(btlneck - btlneck_q)
    return jax.tree.map(
        lambda x: jnp.mean(x, axis=tuple(range(1, x.ndim))),
        losses)

  def predict_fn(params, batch):
    logits, _ = model.apply(params, **input_pp_fn(batch))
    outputs = predict_outputs_fn(logits)
    return outputs

  # Only initialize evaluators when they are first needed.
  @functools.lru_cache(maxsize=None)
  def evaluators():
    return eval_common.from_config(
        config, {"predict": predict_fn, "validation": validation_fn},
        lambda s: write_note(f"Initializing evaluator: {s}...\n{chrono.note}")
    )

  # Decide how to initialize training. The order is important.
  # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize model from something, e.g. start a fine-tuning job.
  # 4. Train from scratch.
  resume_ckpt_path = None
  if save_ckpt_path and gfile.exists(save_ckpt_path):
    resume_ckpt_path = save_ckpt_path
  elif config.get("resume"):
    resume_ckpt_path = fillin(config.resume)
  if resume_ckpt_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {
        "params": params_cpu,
        "state": state_cpu,
        "opt": opt_cpu,
        "chrono": chrono.save(),
    }
    checkpoint_tree = jax.tree.structure(checkpoint)
    loaded = u.load_checkpoint(checkpoint_tree, resume_ckpt_path)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree.map(u.recover_dtype, loaded)
    params_cpu = checkpoint["params"]
    state_cpu = checkpoint["state"]
    opt_cpu = checkpoint["opt"]
    chrono.load(checkpoint["chrono"])
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    params_cpu, state_cpu = model_mod.load(
        {"params": params_cpu, "state": state_cpu},
        config.model_init, config.model,
        **config.get("model_load", {}))
    if jax.process_index() == 0:
      parameter_overview.log_parameter_overview(
          params_cpu, msg="restored params")

  write_note("Kicking off misc stuff...")
  first_step = bv_optax.get_count(opt_cpu)
  # `Chrono.inform` only accepts keyword arguments, so name each one.
  chrono.inform(first_step=first_step, total_steps=total_steps,
                global_bs=batch_size,
                steps_per_epoch=ntrain_img / batch_size)
  prof = None  # Keeps track of start/stop of profiler state.

  write_note(f"Replicating...\n{chrono.note}")
  params_repl = flax.jax_utils.replicate(params_cpu)
  opt_repl = flax.jax_utils.replicate(opt_cpu)
  state_repl = flax.jax_utils.replicate(state_cpu)

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax.jax_utils.replicate(rng_loop)
  ckpt_writer = None

  write_note(f"First step compilations...\n{chrono.note}")
  error = None  # For exiting with an error after cleanup. Avoids indentation.

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter):
    mw.step_start(step)

    with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
      params_repl, opt_repl, state_repl, loss_value, rngs_loop, measurements = (
          update_fn(
              params_repl,
              opt_repl,
              state_repl,
              batch,
              rngs_loop,
              not config.get("freeze_dict", True)))

    # On the first host, let's always profile a handful of early steps.
    if jax.process_index() == 0:
      prof = u.startstop_prof(prof, step, first_step, get_steps("log_training"))

    # Report training progress
    if (u.itstime(step, get_steps("log_training"), total_steps, host=0)
        or chrono.warmup and jax.process_index() == 0):
      for i, sched_fn_cpu in enumerate(sched_fns_cpu):
        mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1))
      l = mw.measure("training_loss", loss_value[0])
      for name, value in measurements.items():
        mw.measure(name, value[0])
      chrono.tick(step, mw.measure, write_note)
      if not np.isfinite(l):
        error = (f"The loss became nan or inf somewhere within steps "
                 f"[{step - get_steps('log_training')}, {step}]")
        break

    # Checkpoint saving
    if (save_ckpt_path and
        (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or
         u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))):
      chrono.pause(wait_for=(params_repl, opt_repl, state_repl))
      u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1))
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see (internal link)). Also, takes device 0's params only.
      params_cpu, opt_cpu, state_cpu = jax.tree.map(
          lambda x: np.array(x[0]), (params_repl, opt_repl, state_repl))

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, get_steps("keep_ckpt", None), total_steps):
        copy_step = step

      ckpt = {
          "params": params_cpu,
          "state": state_cpu,
          "opt": opt_cpu,
          "chrono": chrono.save(),
      }
      ckpt_writer = pool.apply_async(
          u.save_checkpoint, (ckpt, save_ckpt_path, copy_step))
      chrono.resume()

    for (name, evaluator, log_steps, prefix) in evaluators():
      if u.itstime(step, log_steps, total_steps):
        chrono.pause(wait_for=(params_repl, state_repl))
        write_note(f"{name} evaluation...\n{chrono.note}")
        for key, value in evaluator.run(
            {"params": params_repl, "state": state_repl}):
          mw.measure(f"{prefix}{key}", value)
        chrono.resume()
    mw.step_end()

  # Always give a chance to stop the profiler, no matter how things ended.
  # TODO: can we also do this when dying of an exception like OOM?
  if jax.process_index() == 0 and prof is not None:
    u.startstop_prof(prof)

  # Support eval only runs: run evaluation if total_steps (or num_epochs) is 0.
  if total_steps == 0:
    for (name, evaluator, _, prefix) in evaluators():
      write_note(f"{name} evaluation...\n{chrono.note}")
      for key, value in evaluator.run(
          {"params": params_repl, "state": state_repl}):
        mw.measure(f"{prefix}{key}", value)

  # Last note needs to happen before the pool's closed =)
  if not error:
    write_note(f"Done!\n{chrono.note}")
  else:
    write_note(f"Failed!\n{error}\n{chrono.note}")

  pool.close()
  pool.join()
  mw.close()

  # Make sure all hosts stay up until the end of main.
  u.sync()

  # Before cleanup, as cleanup should only run for successful jobs.
  if error is not None:
    raise RuntimeError(error)

  u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info)
Signature is `params, *args, *kwargs`. + static_argnums: indices of arguments to `wrapped` that should _not_ be + padded and sharded, but instead be forwarded as-is. The default is (0,) + because by far the most common use-case is to pass `params` first. + static_argnames: names of kwargs to `wrapped` that should _not_ be padded + and sharded, but instead be forwarded as-is. + + Returns: + A new function that pads and shards its arguments before passing them to + the wrapped function, and un-shards and un-pads the returned pytree. + + This is useful for calling a pmap'ed function with inputs that aren't + divisible by the number of devices. A typical use is: + @pad_shard_unpad + @jax.pmap + def forward(params, x): ... + + Notes: + The padding is done in host-memory before being passed to the function, and + the values returned by the function are transferred back to host memory. + + The returned function is augmented with a new keyword-only argument + `min_device_batch` that, if specified, forces padding inputs to at least + this size per device. This can be useful to avoid recompiles for the last + batch and reduce memory fragmentation. + """ + + def pad_shard_unpad_wrapper(*args, min_device_batch=None, **kw): + d = jax.local_device_count() # d = devices, b = batch + + # Find the batch-sizes of all non-static arguments. 
def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
  """Returns a float32 one-hot encoding of `labels` along a new last axis.

  Args:
    labels: array of class indices; a trailing axis of size `num_classes` is
      appended to its shape.
    num_classes: size of the appended axis.
    on_value: value written where the index matches the label.
    off_value: value written everywhere else.

  Returns:
    A float32 array of shape `labels.shape + (num_classes,)`.
  """
  class_ids = jnp.arange(num_classes)[None]
  hits = labels[..., None] == class_ids
  on = jnp.full(hits.shape, on_value)
  off = jnp.full(hits.shape, off_value)
  encoded = jax.lax.select(hits, on, off)
  return encoded.astype(jnp.float32)
def load_checkpoint_np(npz, tree=None):
  """Builds a jax pytree from the contents of a npz checkpoint.

  Args:
    npz: Either path to the checkpoint file (.npz), or a dict-like that has
      already been loaded.
    tree: deprecated, use None.
      Bwd-compat for old format that only stored values: the pytree structure
      used to unflatten the stored values.

  Returns:
    A pytree that is the checkpoint.
  """
  # A string is taken to be a path that still needs loading.
  contents = npload(npz) if isinstance(npz, str) else npz
  names, leaves = zip(*contents.items())
  if tree:
    return tree.unflatten(leaves)
  return recover_tree(names, leaves)
def prefetch_scalar(it, nprefetch=1, devices=None):
  """Replicates each item of `it` once per local device and prefetches it.

  Args:
    it: iterable of scalars.
    nprefetch: forwarded to `flax_utils.prefetch_to_device`.
    devices: optional device list; when empty or None, the local device count
      reported by jax is used for the replication size.

  Returns:
    Whatever `flax_utils.prefetch_to_device` returns for the replicated stream.
  """
  if devices:
    replicas = len(devices)
  else:
    replicas = jax.local_device_count()
  # One vector of `replicas` copies per incoming scalar.
  replicated = (np.ones(replicas) * scalar for scalar in it)
  return flax_utils.prefetch_to_device(replicated, nprefetch, devices)
def bidirectional_contrastive_loss(zimg, ztxt, t, mask=None, reduction=False):
  """Bidirectional contrastive loss (e.g. for contrastive trainer/evaluator).

  Args:
    zimg: image embeddings, one row per example.
    ztxt: text embeddings, one row per example, paired with `zimg` by index.
    t: multiplier applied to the similarity logits. It is assumed to be in a
      good range already, e.g. passed through exp/softplus.
    mask: optional per-example array; examples where it is 0/False are excluded
      from both the loss and the `ncorrect` measurement.
    reduction: if True, return averages (over unmasked examples when `mask` is
      given) instead of per-example values.

  Returns:
    A tuple `(loss, measurements)` where `measurements["ncorrect"]` tells
    whether the image->txt argmax picked the paired text.
  """
  # BF.FB = BB
  logits = jnp.dot(zimg, ztxt.T) * t

  if mask is not None:
    # Set to negative infinity where mask = 0. Masked examples will disappear
    # under softmax, and can never win argmax for an unmasked example.
    exclude = jnp.logical_not(mask)  # Now 1 if we don't want to keep.
    exclude = jnp.logical_or(exclude[:, None], exclude[None, :])
    logits = jnp.where(exclude, -jnp.inf, logits)

  # Note: assumed t is in a good range e.g. already passed through exp/softplus.
  l1 = -jnp.diag(jax.nn.log_softmax(logits, axis=1))  # NLL img->txt
  l2 = -jnp.diag(jax.nn.log_softmax(logits, axis=0))  # NLL txt->img
  l = 0.5 * (l1 + l2)
  ncorrect = jnp.argmax(logits, axis=1) == jnp.arange(len(logits))

  if mask is not None:
    l = jnp.where(mask, l, 0)
    # The row of a masked example is all -inf, for which argmax returns 0; a
    # masked example 0 would hence be reported as correct. Mask it like the
    # loss so unreduced outputs are consistent.
    ncorrect = jnp.logical_and(ncorrect, mask)

  redux = jnp.mean if reduction else lambda x: x
  if reduction and mask is not None:
    redux = lambda x: jnp.sum(x * mask) / (jnp.sum(mask) + 1e-8)

  # Also return extra measurements.
  return redux(l), {
      "ncorrect": redux(ncorrect),
  }
def accumulate_gradient(loss_and_grad_fn, params, images, labels, accum_steps):
  """Accumulate gradient over multiple steps to save on memory.

  The batch is split along its leading axis into `accum_steps` equally sized
  chunks; loss and gradients are computed per chunk and averaged. `images` and
  `labels` may have any rank, only their leading (batch) axis is sliced.

  Args:
    loss_and_grad_fn: callable `(params, images, labels) -> (loss, grads)`.
    params: pytree forwarded unchanged to `loss_and_grad_fn`.
    images: array whose leading axis is the batch.
    labels: array whose leading axis is the batch.
    accum_steps: number of chunks. A falsy value or a value <= 1 disables
      accumulation and calls `loss_and_grad_fn` once on the full batch.

  Returns:
    The `(loss, grads)` pair, averaged over the chunks.
  """
  # See (internal link) for details and experiments.
  if not accum_steps or accum_steps <= 1:
    return loss_and_grad_fn(params, images, labels)

  assert images.shape[0] % accum_steps == 0, (
      f"Bad accum_steps {accum_steps} for batch size {images.shape[0]}")
  step_size = images.shape[0] // accum_steps

  def get_chunk(x, i):
    # Slice the batch axis only, so that inputs of any rank are supported.
    return jax.lax.dynamic_slice_in_dim(x, i * step_size, step_size, axis=0)

  l, g = loss_and_grad_fn(params, images[:step_size], labels[:step_size])

  def acc_grad_and_loss(i, l_and_g):
    li, gi = loss_and_grad_fn(
        params, get_chunk(images, i), get_chunk(labels, i))
    l, g = l_and_g
    return (l + li, jax.tree.map(lambda x, y: x + y, g, gi))

  # Chunk 0 was computed above, to provide the carry's initial value.
  l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
  return jax.tree.map(lambda x: x / accum_steps, (l, g))
def hms(s):
  """Format time in hours/minutes/seconds.

  The duration is rounded to whole seconds before being split into units, so
  that a unit never shows its overflow value (e.g. "1m0s", not "60s" or
  "0m60s").

  Args:
    s: duration in seconds.

  Returns:
    A compact string such as "42s", "3m7s", "5h12m" or "2d3h4m". Durations of
    up to and including 24 hours are shown in hours; days are only used from
    25 hours on.
  """
  if not np.isfinite(s):
    return f"{s:.0f}s"  # Nothing sensible to split for nan/inf.
  s = int(round(s))
  if s < 60:
    return f"{s}s"
  m, s = divmod(s, 60)
  if m < 60:
    return f"{m}m{s}s"
  h, m = divmod(m, 60)
  if h < 25:
    return f"{h}h{m}m"  # Seconds intentionally omitted.
  d, h = divmod(h, 24)
  return f"{d}d{h}h{m}m"  # Seconds intentionally omitted.
integrates) timings, and save/load them across + restarts. + """ + + def __init__(self): + self._timing_history = collections.defaultdict(list) + self._measure = None + self._write_note = None + + self.program_start_time = time.monotonic() + self.train_start_time = None + self.train_start_step = None # When we started timing (after warmup) + + self.prev_time = None + self.prev_step = None + + self.pause_start = None + self.paused_time = 0 + + self.total_steps = None + self.global_bs = None + self.steps_per_epoch = None + + self.warmup = 2 # How many calls to `tick` to skip. + self.load() # Inits accum integrators. + self.note = "Chrono n/a" + + def inform(self, *, first_step=None, total_steps=None, global_bs=None, + steps_per_epoch=None, measure=None, write_note=None): + """Provide some extra info that's only known later in the program.""" + # The pattern of `self.x = x or self.x` allows one to call `inform` various + # times with various subset of information (args), as they become available. + # Except for `first_step` which can be 0 so is a bit more verbose. + self.prev_step = first_step if first_step is not None else self.prev_step + self.total_steps = total_steps or self.total_steps + self.steps_per_epoch = steps_per_epoch or self.steps_per_epoch + self.global_bs = global_bs or self.global_bs + self._measure = measure or self._measure + self._write_note = write_note or self._write_note + if self.total_steps and self.prev_step is not None: + self.note = (f"Steps:{self.prev_step}/{self.total_steps} " + f"[{self.prev_step/self.total_steps:.1%}]") + + def tick(self, step, measure=None, write_note=None): + """A chronometer tick.""" + if step == self.prev_step: return # Can happen from evals for example. 
+ + measure = measure or self._measure + write_note = write_note or self._write_note + + now = time.monotonic() + measure("uptime", now - self.program_start_time) + self.flush_timings() + + # We do always count examples, regardless of the timing-related warmup that + # happens a few lines below. + ds = step - self.prev_step # Steps between ticks + self.prev_step = step + self.accum_examples_seen += ds * self.global_bs + measure("examples_seen", self.accum_examples_seen) + measure("progress", step / self.total_steps) + if self.steps_per_epoch: + measure("epoch", step / self.steps_per_epoch) + + # We take the start as the second time `tick` is called, so we avoid + # measuring the overhead of compilation and don't include it in time + # estimates. + if self.warmup > 1: + self.warmup -= 1 + write_note(self.note) # This can help debugging. + return + if self.warmup == 1: + self.train_start_time = self.prev_time = now + self.train_start_step = step + self.accum_program_time += now - self.program_start_time + self.paused_time = 0 # Drop pauses that happened before timing starts. + self.warmup = 0 + write_note(self.note) # This can help debugging. + return + + # Measurement with micro-timings of current training steps speed. + # Time between ticks (ignoring pause) + dt = now - self.prev_time - self.paused_time + ncores = jax.device_count() # Global device count + measure("img/sec/core", self.global_bs * ds / dt / ncores) + + # Accumulate (integrate) times, good for plots. + self.accum_train_time += dt + self.accum_pause_time += self.paused_time + self.accum_program_time += dt + self.paused_time + + # Convert to, and log as, core hours. + core_hours = self.accum_train_time * ncores / 60 / 60 + devtype = jax.devices()[0].device_kind + measure(f"core_hours_{devtype}", core_hours) + measure("core_hours", core_hours) # For convenience as x-axis in sweeps. 
+ + # Progress note with "global" full-program average timings + # (eg in program-time minus warmup) + dt = now - self.train_start_time # Time elapsed since end of warmup. + steps_timed = step - self.train_start_step + steps_todo = self.total_steps - step + self.note = f"Steps:{step}/{self.total_steps} [{step/self.total_steps:.1%}]" + self.note += f"\nWalltime:{hms(self.accum_program_time)}" + self.note += f" ({hms(self.accum_pause_time)} eval)" + self.note += f"\nETA:{hms(dt / steps_timed*steps_todo)}" + self.note += f"\nTotal train time:{hms(dt / steps_timed*self.total_steps)}" + write_note(self.note) + + self.prev_time = now + self.paused_time = 0 + + def pause(self, wait_for=()): + assert self.pause_start is None, "Don't pause twice." + jax.block_until_ready(wait_for) + self.pause_start = time.monotonic() + + def resume(self): + self.paused_time += time.monotonic() - self.pause_start + self.pause_start = None + + def save(self): + return dict( + accum_program_time=self.accum_program_time, + accum_train_time=self.accum_train_time, + accum_pause_time=self.accum_pause_time, + accum_examples_seen=self.accum_examples_seen, + ) + + def load(self, ckpt={}): # pylint: disable=dangerous-default-value + self.accum_program_time = float(ckpt.get("accum_program_time", 0.0)) + self.accum_train_time = float(ckpt.get("accum_train_time", 0.0)) + self.accum_pause_time = float(ckpt.get("accum_pause_time", 0.0)) + self.accum_examples_seen = int(ckpt.get("accum_examples_seen", 0)) + + @contextlib.contextmanager + def log_timing(self, name, *, noop=False): + """Use this when you time sth once per step and want instant flushing.""" + t0 = time.monotonic() + yield + dt = time.monotonic() - t0 + if not noop: + if self._measure: # So that timed things still work in colab. 
+ self._measure(name, dt) + logging.info("TIMING[%s]: %s", name, dt) + logging.flush() + + @contextlib.contextmanager + def log_timing_avg(self, name, *, noop=False): + """Use this when you time sth multiple times per step (eg in a loop).""" + t0 = time.monotonic() + yield + dt = time.monotonic() - t0 + if not noop: + self._timing_history[name].append(dt) + logging.info("TIMING[%s]: avg %s current %s", + name, np.mean(self._timing_history[name]), dt) + logging.flush() + + def flush_timings(self): + assert self._measure is not None + for name, times in self._timing_history.items(): + self._measure(name, np.mean(times)) + self._timing_history.clear() + + +# Singleton to use from everywhere. https://stackoverflow.com/a/6760726/2366315 +chrono = Chrono() + + +def _traverse_with_names(tree, with_inner_nodes=False): + """Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val).""" + if dataclasses.is_dataclass(tree): + tree = flax.serialization.to_state_dict(tree) + # Don't output the non-leaf nodes. If the optimizer doesn't have a state + # the tree leaves can be Nones which was interpreted as a leaf by this + # function but not by the other functions (like jax.tree.map). + if tree is None: + return + elif isinstance(tree, Mapping): + keys = sorted(tree.keys()) + for key in keys: + for path, v in _traverse_with_names(tree[key], with_inner_nodes): + yield (key + "/" + path).rstrip("/"), v + if with_inner_nodes: + yield "", tree + elif isinstance(tree, (list, tuple)): + for idx in range(len(tree)): + for path, v in _traverse_with_names(tree[idx], with_inner_nodes): + yield (str(idx) + "/" + path).rstrip("/"), v + if with_inner_nodes: + yield "", tree + else: + yield "", tree + + +def tree_flatten_with_names(tree): + """Populates tree_flatten with leaf names. + + This function populates output of tree_flatten with leaf names, using a + custom traversal that produces names is provided. 
The custom traversal does + NOT have to traverse tree in the same order as jax, as we take care of + automatically aligning jax' and custom traversals. + + Args: + tree: python tree. + + Returns: + A list of values with names: [(name, value), ...] + """ + vals, tree_def = jax.tree.flatten(tree) + + # "Fake" token tree that is use to track jax internal tree traversal and + # adjust our custom tree traversal to be compatible with it. + tokens = range(len(vals)) + token_tree = tree_def.unflatten(tokens) + val_names, perm = zip(*_traverse_with_names(token_tree)) + inv_perm = np.argsort(perm) + + # Custom traverasal should visit the same number of leaves. + assert len(val_names) == len(vals) + + return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def + + +def tree_unflatten(names_and_vals): + """Reverses `tree_flatten_with_names(tree)[0]`.""" + return recover_tree(*zip(*names_and_vals)) + + +def tree_map_with_names(f, tree, *rest): + """Like jax.tree.map but with a filter on the leaf path name. + + Args: + f: A function with first parameter `name` (path-like "a/b/c") and remaining + parameters values of `tree` and `*rest` corresponding to the given `name` + Should return a new value for parameter `name`. + tree: The tree of parameters `f` should be applied to. + *rest: more trees of the exact same structure. + + Returns: + A tree identical in structure to `tree` and `*rest` but with the leaves the + result of calling `f` on corresponding name/leaves in `tree` and `*rest`. + """ + names_and_vals, tree_def = tree_flatten_with_names(tree) + names, vals = zip(*names_and_vals) + rest_vals = [list(zip(*tree_flatten_with_names(t)[0]))[1] for t in rest] + vals = [f(*name_and_vals) for name_and_vals in zip(names, vals, *rest_vals)] + return tree_def.unflatten(vals) + + +def tree_map_with_regex(f, tree, regex_rules, not_f=lambda x: x, name=None): + """Apply jax-style tree_map based on regex rules. + + Args: + f: a function that is being applied to every variable. 
+ tree: jax tree of arrays. + regex_rules: a list of tuples `(pattern, args)`, where `pattern` is a regex + which used for variable matching and `args` are positional arguments + passed to `f`. If some variable is not matched, we apply `not_f` transform + which is id by default. If multiple patterns match, then only the first + rule is applied. + not_f: optional function which is applied to variables that do not match any + pattern. + name: a name of transform for logging purposes. + + Returns: + a tree, transformed by `f` according to the given rules. + """ + def _f(vname, v): + for pattern, arg in regex_rules: + if re.fullmatch(pattern, vname): + if name and jax.process_index() == 0: + logging.info("Applying %s to %s with %s due to `%s`", + name, vname, arg, pattern) + return f(v, arg) + return not_f(v) + return tree_map_with_names(_f, tree) + + +def tree_get(tree, name): + """Get an entry of pytree by flattened key name, eg a/b/c, with nice error. + + Args: + tree: the pytree to be queried. + name: the path to extract from the tree, see below for examples. + + Returns: + A few examples: + tree = {'a': 1, 'b': {'c': 2, 'd': 3}} + tree_get(tree, 'a') == 1 + tree_get(tree, 'b/c') == 2 + tree_get(tree, 'b') == {'c': 2, 'd': 3} + """ + flattened = dict(_traverse_with_names(tree, with_inner_nodes=True)) + try: + return flattened[name] + except KeyError as e: + class Msg(str): # Reason: https://stackoverflow.com/a/70114007/2366315 + def __repr__(self): + return str(self) + msg = "\n".join([name, "Available keys:", *flattened, ""]) + # Turn into configdict to use its "did you mean?" error message! + msg = mlc.ConfigDict(flattened)._generate_did_you_mean_message(name, msg) # pylint: disable=protected-access + raise KeyError(Msg(msg)) from e + + +def tree_replace(tree, replacements): + """Renames/removes (nested) keys. 
+ + Example usage: + + tree = {'a': {'b': 2, 'c': 3}, 'c': 4} + replacements = { + 'a/b': 'a/b/x', # replaces 'a/b' with 'a/b/x' + '.*c': 'C', # replaces 'c' with 'C' ('a/c' is removed) + 'C': 'D', # replaces 'C' (which was 'c') with 'D' + '.*/c': None, # removes 'a/c' + } + tree2 = rename_remove(tree, replacements) + assert tree2 == {'D': 4, 'a': {'b': {'x': 2}}} + + Args: + tree: A nested dictionary. + replacements: Rules specifying `regex` as keys and `replacement` as values + to be used with `m = re.match(regex, key)` and `m.expand(replacement)` + for every `key` independently. + + Note that: + 1. If any rule matches with `replacement=None`, then the key is removed. + 2. The rules are applied in order. It's possible to have multiple + transformations on a single key. + + Returns: + Updated `tree` according to rules defined in `replacements`. + """ + replacements = { + re.compile(kk): vv for kk, vv in replacements.items() + } + + def rename(k): + for kk, vv in replacements.items(): + m = kk.match(k) + if m: + k = k[:m.start()] + m.expand(vv) + k[m.end():] + return k + + def should_remove(k): + return any(vv is None and kk.match(k) for kk, vv in replacements.items()) + + names_and_vals, _ = tree_flatten_with_names(tree) + names_and_vals = [ + (rename(k), v) for k, v in names_and_vals if not should_remove(k) + ] + return tree_unflatten(names_and_vals) + + +def tree_compare(tree1, tree2): + """Returns `(tree1_only, tree2_only, dtype_shape_mismatch)`.""" + tree1 = flax.traverse_util.flatten_dict(tree1, sep="/") + tree2 = flax.traverse_util.flatten_dict(tree2, sep="/") + return set(tree1) - set(tree2), set(tree2) - set(tree1), { + k: [(v.dtype, v.shape), (tree2[k].dtype, tree2[k].shape)] + for k, v in tree1.items() + if k in tree2 and (v.dtype != tree2[k].dtype or v.shape != tree2[k].shape) + } + + +def tree_filter(tree, mask): + """Returns nested dict structure with only a subset of children.""" + # TODO: The code below only works for nested-dict and only when they 
+ # have same structure. Consider relax this. + if not isinstance(tree, dict): + assert isinstance(mask, bool), f"Mask leaves must be boolean! {mask}" + return tree + assert sorted(tree.keys()) == sorted(mask.keys()), ( + f"Keys in tree and mask are not equal! {tree.keys()} != {mask.keys()}") + return {k: tree_filter(v, mask[k]) for k, v in tree.items() + if mask[k] is not False} + + +def recover_dtype(a): + """Numpy's `save` stores bfloat16 type as "void" type, so we recover it.""" + if hasattr(a, "dtype") and a.dtype.type is np.void: + assert a.itemsize == 2, "Unknown dtype!" + return a.view(jax.numpy.bfloat16) + else: + return a + + +def recover_tree(keys, values): + """Recovers a tree as a nested dict from flat names and values. + + This function is useful to analyze checkpoints that are saved by our programs + without need to access the exact source code of the experiment. In particular, + it can be used to extract an reuse various subtrees of the scheckpoint, e.g. + subtree of parameters. + + Args: + keys: a list of keys, where '/' is used as separator between nodes. + values: a list of leaf values. + + Returns: + A nested tree-like dict. + """ + tree = {} + sub_trees = collections.defaultdict(list) + for k, v in zip(keys, values): + if "/" not in k: + tree[k] = v + else: + k_left, k_right = k.split("/", 1) + sub_trees[k_left].append((k_right, v)) + for k, kv_pairs in sub_trees.items(): + k_subtree, v_subtree = zip(*kv_pairs) + tree[k] = recover_tree(k_subtree, v_subtree) + return tree + + +def tssave(mngr, pytree, path, on_commit=lambda *_, **__: None): + """Save pytree using jax tensorstore-based checkpoint manager. + + NOTE: When overwriting an existing checkpoint with a different pytree, the + result is, counterintuitively, the union of both, not only the new one. + + Args: + mngr: An instance of GlobalAsyncCheckpointManager. + pytree: What to store; any pytree of arrays. + path: Where to save the pytree. Creates subfolders as needed. 
+ on_commit: A callback when writing is done, see `mngr.serialize`. + """ + names, vals = zip(*tree_flatten_with_names(pytree)[0]) + + for name in names: + if "~" in name: + raise ValueError(f"Symbol '~' is not allowed in names. Found in {name}.") + + gfile.makedirs(path) + with jax.transfer_guard("allow"): + names = [name.replace("/", "~") for name in names] + mngr.serialize_with_paths( + list(vals), [os.path.join(path, name) for name in names], + on_commit_callback=functools.partial(on_commit, array_names=names)) + + +def save_checkpoint_ts(mngr, checkpoint, path, step, keep=True): + """Preemption-safe saving of checkpoints using tssave.""" + # The tensorstore checkpoint format is a folder with (potentially) many files. + # On some file-systems, operations on these (copy, rename, delete) are slow, + # so we implement a flow that's both robust to pre-emptions/crashes during + # checkpointing and makes minimal use of these slow operations. + + # The logic goes as follows. It's infaillible :) + # (...if file move is atomic, which it is.) + # We always write the current checkpoint to a new folder, which contains the + # step number in its name. If we don't need to keep it indefinitely, we append + # "-tmp" to its name. + # After writing the next checkpoint, we remove the previous one if it had + # "-tmp" in its name. + # We also have a -LAST file that contains a pointer to the latest complete + # checkpoint. File operations are cheap to make atomic, that's why. + + def _on_commit_callback(array_names): # Runs after writing ckpt is done. + with gfile.GFile(f"{path}-CUR", "w") as f: + f.write(curr) + + last = "" + if gfile.exists(f"{path}-LAST"): + with gfile.GFile(f"{path}-LAST", "r") as f: + last = f.read().strip() + + gfile.rename(f"{path}-CUR", f"{path}-LAST", overwrite=True) + + if last.endswith("-tmp"): + # If pre-emption happens here, some old checkpoints may not be deleted. 
+ multiprocessing.pool.ThreadPool().map( + gfile.rmtree, + [f"{path}-{last}/{name}" for name in array_names]) + gfile.rmtree(f"{path}-{last}") + + # NOTE: The jax checkpoint manager automatically waits for the previous save + # to be finished before writing again, so we don't need to do it here. + + # Always write to path with step number in it. + curr = f"{step:09d}{'-tmp' if not keep else ''}" + tssave(mngr, checkpoint, f"{path}-{curr}", _on_commit_callback) + + +def load_checkpoint_ts(path, **tsload_kw): + """Loads a big_vision checkpoint saved by `save_checkpoint_ts`.""" + to_load = path + + try: + # When passing a general path (not a specific step), get the last available. + with gfile.GFile(f"{path}-LAST", "r") as f: + to_load = f"{path}-{f.read().strip()}" + except Exception: # Differs based on backend, so blanket catch. pylint:disable=broad-exception-caught + pass + + return tsload(to_load, **tsload_kw) + + +def tsload(path, *, tree=None, shardings=None, regex=None): + """Loads tensorstore-based array-tree from disk. + + If `tree` argument is provided, then array names to load and target structure + is derived from the tree. If `tree` is None, then array names to load are + derived from array filenames on the disk, and, optionally, `regex` is applied + to filter these names. The`tree` argument is then automatically derived from + array names with `recover_tree` util. + + Arrays are loaded to CPU/TPU/GPU memory as specified by the `shardings` + argument, which is a pytree of CPU/TPU/GPU shardings (can be mixed within a + single pytree). `shardings` should a prefix tree of the `tree` argument. We + automatically broadcast `shardings` to a full `tree`. For example, a user can + specify `shardings=jax.sharding.SingleDeviceSharing(jax.devices('cpu')[0])`, + which will be broadcasted to a full tree. + + Args: + path: a directory where the checkpoint arrays are stored. + tree: a target pytree, which defines array names to load and the target tree + structure. 
If tree is None, then `tree` is inferred from the names of + arrays stored on the disk. + shardings: a prefix pytree (with respect to `tree`) of the target shardings. + regex: regex to filter array names from the disk, if `tree` is not provided. + + Returns: + A pytree of loaded arrays that has the same structure as `shardings` arg. + """ + if (tree is not None) and (regex is not None): + raise ValueError("If tree is specified, regex filtering is not allowed.") + + if tree is None: + # Some file-systems (gs://) list folders with a trailing /, get rid of it. + path_names = set([p.rstrip("/").replace("~", "/") + for p in gfile.listdir(path)]) + regex = re.compile(regex) if regex is not None else re.compile(".*") + path_names = [p for p in path_names if regex.match(p)] + tree = recover_tree(path_names, [0] * len(path_names)) + + names_and_vals, tree_def = tree_flatten_with_names(tree) + names_to_load, _ = zip(*names_and_vals) + + if shardings is None: + shardings = jax.sharding.SingleDeviceSharding( + jax.local_devices(backend="cpu")[0] + ) + shardings = list(jax.tree.leaves(tree_broadcast(shardings, tree))) + + names_to_load = [os.path.join(path, name.replace("/", "~")) + for name in names_to_load] + specs = [array_serial.get_tensorstore_spec(n) for n in names_to_load] + arrays = array_serial.run_deserialization(shardings, specs) + return tree_def.unflatten(arrays) + + +def steps(prefix, config, data_size=None, batch_size=None, total_steps=None, + default=ValueError): + """Gets duration named `prefix` out of `config` and converts it to steps. + + Using this function to access a configuration value that denotes some kind + of duration (eg training time, warmup, checkpoint frequency, ...) allows the + duration to be specified in terms of steps, epochs, examples, or percent of + training time, and converts any of these into steps, such that the training + code only deals with steps. + If the result is not an integer step number, it is rounded to the nearest one. 
+ + Args: + prefix: The name of the duration to query. The actual config fields can + then be one of `prefix_steps`, `prefix_examples`, or `prefix_epochs`. + config: The dictionary (config) from which to read the duration. + data_size: The total number of training examples in one epoch. + batch_size: The number of examples processed per step. + total_steps: The total number of training steps to run. + default: The default value to return when no duration of the name `prefix` + is found in the `config`. Set to `ValueError` (the default) to raise an + error instead of returning a default value. + + Returns: + The number of steps from the config, or the default value. + + Raises: + ValueError if there is no such duration in the config and no default is set. + """ + # Be helpful and make sure only match one of the following suffixes. + suffixes = {"steps", "examples", "epochs", "percent"} + matches = {f"{prefix}_{s}" for s in suffixes if f"{prefix}_{s}" in config + and config[f"{prefix}_{s}"] is not None} + # Note that steps=0 is also a valid value (e.g. to only run evaluators). + assert len(matches) <= 1, f"Only one of '{matches}' should be defined." + + if f"{prefix}_steps" in config: + return config[f"{prefix}_steps"] + + def to_integer(x): + # Round to nearest but always executed at least one step unless explictily + # asked for 0. E.g. total_epochs=0 vs total_epochs=0.0001 + return max(1, round(x)) if x else 0 + + if batch_size and f"{prefix}_examples" in config: + return to_integer(config[f"{prefix}_examples"] / batch_size) + + if batch_size and data_size and f"{prefix}_epochs" in config: + steps_per_epoch = data_size / batch_size + return to_integer(config[f"{prefix}_epochs"] * steps_per_epoch) + + if total_steps and f"{prefix}_percent" in config: + pct = config[f"{prefix}_percent"] + assert 0.0 <= pct <= 1.0, ( # Be helpful, since it's not obvious. 
+ f"Percents should lie in [0.0, 1.0], but {prefix}_percent is {pct}") + return to_integer(pct * total_steps) + + if default is ValueError: + raise ValueError( + f"Cannot convert {prefix} to steps, due to missing batch_size " + f"({batch_size}), data_size ({data_size}), total_steps ({total_steps})" + ", or corresponding entry in config:\n" + "\n".join(config.keys())) + + return default + + +def create_learning_rate_schedule( + total_steps, batch_size=None, data_size=None, + base=1.0, decay_type="stair", + scale_with_batchsize=False, **kw): + """Creates learning rate schedule, see (internal link). + + Args: + total_steps: The total number of steps to run. + batch_size: The global batch-size optionally used for scaling. + data_size: Number of examples in the training data (for epoch conversion). + base: The starting learning-rate (without warmup). + decay_type: 'linear' or 'cosine', 'rsqrt', 'stair'. + scale_with_batchsize: Whether or not to scale lr automatically. + **kw: extra arguments specific to individual decay_types. Also contains + declaration of `{warmup,cooldown}_{steps,epochs,examples}` that applies + on top of any/all decay_type. + + Returns: + A function learning_rate(step): float -> {"learning_rate": float}. + """ + + warmup_steps = steps( + "warmup", kw, data_size, batch_size, total_steps, default=0) + cooldown_steps = steps( + "cooldown", kw, data_size, batch_size, total_steps, default=0) + + # Early catch hard to backtrack errors due to warmup_steps >= total_steps, + # but let it run for 0 and 1 steps used to eval and debug runs. + assert (total_steps <= 1) or (warmup_steps < total_steps), ( + "warmup_steps is >= total_steps") + + def step_fn(step): + """Step to learning rate function.""" + lr = base + + # This implements the linear scaling rule following + # Goyal et al. at arxiv.org/abs/1706.02677. + # The reference batch size in literature is 256, so we scale the lr to + # adjust to the literature lr when bach_size changes. 
+ if scale_with_batchsize: + lr = lr * batch_size / 256.0 + + progress = (step - warmup_steps) / float(total_steps - warmup_steps) + progress = jnp.clip(progress, 0.0, 1.0) + if decay_type in ("linear", "polynomial"): + power = kw.get("power", 1) + zero = kw.get("end", kw.get("linear_end", 0)) + lr = zero + (lr - zero) * (1.0 - progress) ** power + elif decay_type == "cosine": + lr = lr * 0.5 * (1. + jnp.cos(jnp.pi * progress)) + elif decay_type == "rsqrt": + # See (internal link) for details, especially how to set timescale + # and shift in order to continue smoothly when changing batch-size. + if "timescale_examples" in kw: + t = kw["timescale_examples"] / batch_size + else: + t = kw.get("timescale", 10_000) # bwd-compat default. + shift = kw.get("shift", 0) + lr = jnp.where( + warmup_steps <= step, + lr / jnp.sqrt(1 + (step + shift - warmup_steps) / t), # In decay + lr / jnp.sqrt(1 + shift / t)) # In warmup. + elif decay_type == "stair": + i = jnp.searchsorted(jnp.array(kw.get("steps", [])), step + 1) + lr = lr * jnp.take(jnp.array([1.0] + list(kw.get("mults", []))), i) + else: + raise ValueError(f"Unknown lr type {decay_type}") + + if warmup_steps: + lr = lr * jnp.minimum(1., step / warmup_steps) + if cooldown_steps: + lr = lr * jnp.minimum(1., (total_steps - step) / cooldown_steps) + + return jnp.asarray(lr, dtype=jnp.float32) + + return step_fn + + +def get_mixup(rng, p): + """Perform mixup https://arxiv.org/abs/1710.09412.""" + rng, rng_mixup = jax.random.split(rng) + a = jax.random.beta(rng_mixup, p, p) + a = jnp.maximum(a, 1.0 - a) # see (internal link) for the context. + def _mixup(*things, **more_things): + mix = lambda thing: a * thing + (1 - a) * jnp.roll(thing, shift=1, axis=0) + return rng, *jax.tree.map(mix, (things, more_things)) + return _mixup + + +# For backwards compatability with legacy code. 
+def mixup(rng, *things, p, **more_things): + return get_mixup(rng, p)(*things, **more_things) + + +def sync(): + """Syncs hosts and empties async computation queue.""" + x = reshard(np.ones(jax.device_count()), + jax.sharding.PositionalSharding(jax.devices())) + jax.jit(jnp.sum)(x).block_until_ready() + + +def check_and_compile_patterns(patterns): + """Validates and compiles a list of param-patterns. + + The validation consists of checking for common mistakes, currently only that + the pattern does not start with a slash, because unlike FLAX, our parameter + names don't start with a slash. + + Args: + patterns: a single (string) pattern (regex), or a list of patterns. + + Returns: + A list of compiled and verified regexes. + """ + if isinstance(patterns, str): + patterns = [patterns] + + assert isinstance(patterns, (list, tuple)), patterns + + def check_and_compile(pattern): + assert not pattern.startswith("/"), ( + f"Big vision parameter names never start with '/': '{pattern}") + return re.compile(pattern) + + return list(map(check_and_compile, patterns)) + + +def make_mask_trees(tree, patterns, *, log=None): + """Returns a boolean mask tree for every pattern (only first match).""" + compiled_patterns = check_and_compile_patterns(patterns) + + def matchfirst(name, _): + matches = [] + for pattern in compiled_patterns: + matches.append(not any(matches) and bool(pattern.fullmatch(name))) + if log is not None and True in matches and jax.process_index() == 0: + logging.info("%s: %s - matched by %s", log, name, + patterns[matches.index(True)]) + return np.array(matches) + + multimask = tree_map_with_names(matchfirst, tree) + return [ + jax.tree.map(lambda matches, i=idx: matches[i], multimask) + for idx in range(len(patterns)) + ] + + +@contextlib.contextmanager +def profile(name, ttl=3 * 365 * 24 * 3600, noop=False): + if not noop: + sess = startstop_prof_at_steps(None, name=name, ttl=ttl) + yield + if not noop: + startstop_prof_at_steps(sess, name=name, ttl=ttl) + + 
+def startstop_prof(sess, step=None, first_step=0, + log_steps=1, surround=20, **kw): + """Runs the profiler for `surround` steps around the next `log_steps`.""" + first_log = first_step + log_steps - (first_step % log_steps) + # don't start before first! + start = max(first_log - surround//2, first_step + 1) + return startstop_prof_at_steps(sess, step, start, start + surround, **kw) + + +def startstop_prof_at_steps( + sess, step=None, first_step=None, last_step=None, + name="steps", ttl=3 * 365 * 24 * 3600): + del sess, step, first_step, last_step, name, ttl + pass # TODO: implement using `jax.profiler` API. Needs workdir. + + +# This is a very minimal variant for open-sourcing. Our internal code makes use +# of multiple internal logging tools instead. +class BigVisionMetricWriter: + """A class for logging metrics.""" + + def __init__(self, xid=-1, wid=-1, workdir=None, config=None): + self.step_start(0) + if jax.process_index() != 0: return # Only one host shall write stuff. + + self.pool = multiprocessing.pool.ThreadPool(1) # 1 is important here. + self.fname = None + if workdir: + if xid != -1 and wid != -1: + self.fname = os.path.join(workdir, + f"big_vision_{xid}_{wid}_metrics.txt") + else: + self.fname = os.path.join(workdir, "big_vision_metrics.txt") + if config: + with gfile.GFile(os.path.join(workdir, "config.json"), "w") as f: + f.write(config.to_json()) + + def step_start(self, step): + self.step = step + self.step_metrics = {} + + def measure(self, name, value): + """Logs the metric value.""" + if jax.process_index() != 0: return # Only one host shall write stuff. + + # Convenience for accepting scalar np/DeviceArrays, as well as N-d single + # scalars, like [[[123]]] or similar, avoiding silly mistakes. + value = np.array(value).squeeze() + + # If the value is a scalar, we keep it in mind to append a line to the logs. + # If it has any structure, we instead just log its shape. 
+ value = float(value) if value.ndim == 0 else value.shape + + logging.info(f"\u001b[35m[{self.step}]\u001b[0m {name} = {value}") + logging.flush() + self.step_metrics[name] = value + + return value # Just for convenience + + def step_end(self): + """Ends a training step, write its full row.""" + if not self.step_metrics: return + + def write(metrics): + with gfile.GFile(self.fname, "a") as f: + f.write(json.dumps({"step": self.step, **metrics}) + "\n") + + if self.fname: + self.pool.apply(lambda: None) # Potentially wait for past writes. + self.pool.apply_async(write, (self.step_metrics,)) + + def close(self): + self.step_end() + if jax.process_index() == 0: + self.pool.close() + self.pool.join() + + +def maybe_cleanup_workdir(workdir, cleanup, info): + """Potentially removes workdirs at end of run for cleanup.""" + if not workdir: + return + + if not cleanup: + info("Logs/checkpoints are in %s", workdir) + elif jax.process_index() == 0: + gfile.rmtree(workdir) + try: # Only need this on the last work-unit, if already empty. + gfile.remove(os.path.join(workdir, "..")) + except tf.errors.OpError: + pass + + +def tree_broadcast(prefix, target): + """Broadcasts a prefix tree to a full tree. + + Input-output examples: + 1. prefix: {"x": 10, "y": 20} + target: {"x": {"a": 1, "b": 2}, "y": 3} + + Result: {"x": {"a": 10, "b": 10}, "y": 20} + + 2. prefix: 100 + target: {"x": {"a": 1, "b": 2}, "y": 3} + + Result: {"x": {"a": 100, "b": 100}, "y": 100} + + 3. prefix: {"x": 10} + target: {"x": {"a": 1, "b": 2}, "y": 3} + + Result: ValueError + + Args: + prefix: prefix pytree. + target: boradcast target for a prefix tree. + + Returns: + prefix tree broadcasted to a target tree. + """ + def _broadcast(leaf, subtree): + return jax.tree.map(lambda _: leaf, subtree) + return jax.tree.map(_broadcast, prefix, target) + + +def reshard(tree, shardings): + """Take an arbitrarily* sharded pytree and shard it according to `shardings`. 
+ + This is a no-op for tree elements which are already sharded as requested. + + *Arrays that are fully addressable (for example, CPU arrays) are assumed to be + identical (i.e. replicated) across hosts. + + *It does not work if an element of `tree` is not fully-addressable, unless its + sharding is already consistent with the target sharding. + If this is needed, please ping lbeyer@ or akolesnikov@. + + Args: + tree: a pytree of arrays. + shardings: a (prefix) pytree of jax array shardings. + Returns: + A pytree of global jax arrays that follows provided shardings. + """ + def _make_global_arr(x, shard, shape): + # Avoid unnecessary copies and transfers: + if hasattr(x, "sharding") and x.sharding.is_equivalent_to(shard, len(shape)): # pylint: disable=line-too-long + return x + if not getattr(x, "is_fully_addressable", True): + raise RuntimeError("Trying to reshard a non-fully-addressable array. " + "Please see the doc-comment for detailed explanation.") + x = jax.device_get(x) # Might be on local devices. + xs = [jax.device_put(x[s], device=d) + for d, s in shard.addressable_devices_indices_map(shape).items()] + return jax.make_array_from_single_device_arrays(shape, shard, xs) + + shapes = jax.tree.map(np.shape, tree) + shardings = tree_broadcast(shardings, tree) + return jax.tree.map(_make_global_arr, tree, shardings, shapes) + + +def put_cpu(x): + """Places array/pytree on a CPU device.""" + return jax.device_put(x, jax.local_devices(backend="cpu")[0]) + + +def make_fsarray_from_local_slice(local_slice, global_devices): + """Create a fully-sharded global device array from local host arrays. + + Args: + local_slice: Something convertible to a numpy array (eg also TF tensors) + that is this host's slice of the global array. + global_devices: The list of global devices. Needed for consistent ordering. + + Returns: + The global on-device array which consists of all local slices stacked + together in the order consistent with the devices. 
def assert_local_slices_same(*global_arrays):
  """Check whether all `global_arrays` have local slices at the same indices.

  Args:
    *global_arrays: objects exposing `addressable_shards`, where every shard
      has an `index` made of `slice` objects (one per array dimension).

  Raises:
    AssertionError: if the arrays do not all share the same local slices.
  """
  # Turn every slice into a plain (start, stop, step) tuple so that the
  # per-array signatures can be collected in a set. Note that the attribute of
  # a `slice` is called `stop`; there is no `end`.
  slices = [
      tuple(
          tuple((idx.start, idx.stop, idx.step) for idx in s.index)
          for s in a.addressable_shards)
      for a in global_arrays]
  assert len(set(slices)) == 1, f"Not all slices are the same: {slices}"
def jit_cpu(**extra_kwargs):
  """Returns a decorator that jits a function with its outputs on local CPU.

  The TODO that precedes this function in the file marks it as temporary logic
  tied to https://github.com/google/jax/issues/15600.

  Args:
    **extra_kwargs: forwarded to `jax.jit` in addition to `out_shardings`.

  Returns:
    A decorator. The decorated function keeps the name and docstring of the
    original and places all outputs on the first local CPU device.
  """
  import functools  # pylint: disable=g-import-not-at-top

  def _decorator(fun):
    @functools.wraps(fun)  # Keep __name__/__doc__ of the wrapped function.
    def _wrapped(*args, **kwargs):
      # The device is looked up at call time, not at decoration time.
      sh = jax.sharding.SingleDeviceSharding(
          jax.local_devices(backend="cpu")[0]
      )
      return jax.jit(fun, **extra_kwargs, out_shardings=sh)(*args, **kwargs)
    return _wrapped
  return _decorator
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
def test_static_argnames(self, dtype, bs):
  """Exercises `static_argnames` with `static_argnums` left at its default."""
  # `static_argnums` is intentionally not passed, so this covers the default,
  # most canonical path in which `params` is the first positional argument.
  def scale_and_shift(params, a, *, b):
    return params * a + b

  wrapped = utils.pad_shard_unpad(scale_and_shift, static_argnames=('b',))

  inputs = np.arange(bs, dtype=dtype)
  outputs = wrapped(5, inputs, b=10)
  expected = 5 * inputs + 10

  chex.assert_type(outputs.dtype, inputs.dtype)
  np.testing.assert_allclose(np.float64(outputs), np.float64(expected))
def test_tree_map_with_names(self):
  """`tree_map_with_names` passes each leaf's name to the mapped function."""
  def negate_w2(name, x):
    return -x if 'w2' in name else x

  def double_w2(name, x1, x2):
    return x1 + x2 if 'w2' in name else x1

  # Single-tree case: only the leaf whose name contains 'w2' is changed.
  negated = utils.tree_map_with_names(negate_w2, self.d1)
  self.assertEqual(negated, {'w1': 1, 'w2': -2, 'w34': (3, 4)})

  # Two-tree case: leaves of both trees are passed after the name.
  summed = utils.tree_map_with_names(double_w2, self.d1, self.d1)
  self.assertEqual(summed, {'w1': 1, 'w2': 4, 'w34': (3, 4)})
def test_tree_get(self):
  """`tree_get` resolves '/'-separated paths to leaves and to subtrees."""
  tree = {'a': {'b': 0, 'x': 1}, 'b': {'x': 2, 'y': 3}}

  # Full paths address individual leaves.
  leaf_cases = (('a/b', 0), ('a/x', 1), ('b/x', 2), ('b/y', 3))
  for path, leaf in leaf_cases:
    self.assertEqual(utils.tree_get(tree, path), leaf)

  # A top-level key returns the whole subtree below it.
  for key in ('a', 'b'):
    self.assertEqual(utils.tree_get(tree, key), tree[key])
class CreateLearningRateScheduleTest(parameterized.TestCase, tf.test.TestCase):
  """Spot-checks `utils.create_learning_rate_schedule` at single steps."""

  @parameterized.named_parameters(
      ('linear', 'linear', {}, 13, .5),
      ('polynomial', 'polynomial', {'end': .1, 'power': 2}, 13, .325),
      ('cosine', 'cosine', {}, 13, .5),
      ('rsqrt', 'rsqrt', {'timescale': 1}, 13, 0.3333333),
      ('stair_5', 'stair', {'steps': [10], 'mults': [.5]}, 5, 1.),
      ('stair_10', 'stair', {'steps': [10], 'mults': [.5]}, 10, .5),
      ('warmup_before', 'rsqrt', {'timescale': 1}, 3, .6),
      ('cooldown_after', 'rsqrt', {'timescale': 1}, 20, .05),
  )
  def test_schedule(self, decay_type, extra_kwargs, step, expected_lr):
    # Settings shared by every case; the decay-specific ones come on top.
    common_kwargs = dict(
        total_steps=21,
        batch_size=512,
        base=.5,
        decay_type=decay_type,
        scale_with_batchsize=True,
        warmup_steps=5,
        cooldown_steps=5,
    )
    schedule = utils.create_learning_rate_schedule(
        **common_kwargs, **extra_kwargs)
    self.assertAlmostEqual(schedule(step), expected_lr)
def test_save_and_load(self):
  """Saves a checkpoint, then reloads it by prefix and by full step path."""
  manager, prefix, expected, shardings = self.setup()
  step = 100

  utils.save_checkpoint_ts(manager, expected, prefix, step, keep=True)
  manager.wait_until_finished()

  # Loading through the checkpoint prefix.
  restored = utils.load_checkpoint_ts(
      prefix, tree=expected, shardings=shardings)
  chex.assert_trees_all_equal(restored, expected)

  # Loading the step-specific path directly.
  restored_at_step = utils.tsload(f'{prefix}-{step:09d}', shardings=shardings)
  chex.assert_trees_all_equal(restored_at_step, expected)
SEQLEN = 128


class PaliGemmaManager:
    """Process-wide singleton that wraps a PaliGemma model for box prediction.

    Constructing the class again returns the same object; `__init__` still runs
    on every construction and re-initialises that shared object.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Create the shared instance on first use, hand it out afterwards.
        if cls._instance is None:
            cls._instance = super(PaliGemmaManager, cls).__new__(cls)
        return cls._instance

    def __init__(self, model, params, tokenizer):
        """Stores model, parameters and tokenizer, then builds the decoder.

        Args:
            model: model object passed to `predict_fns.get_all`.
            params: parameter pytree, passed to the decoder as `{"params": params}`.
            tokenizer: object with `encode`, `decode` and `eos_id` methods.
        """
        self.model = model
        self.params = params
        self.tokenizer = tokenizer
        self.decode_fn = None
        self.decode = None
        self.mesh = None
        self.data_sharding = None
        self.params_sharding = None
        self.trainable_mask = None

        self.initialise_model()

    def initialise_model(self):
        """Builds the decode function, trainable mask, mesh and shardings."""
        self.decode_fn = predict_fns.get_all(self.model)['decode']
        self.decode = functools.partial(
            self.decode_fn,
            devices=jax.devices(),
            eos_token=self.tokenizer.eos_id())

        def is_trainable_param(name, param):
            del param  # Only the name decides.
            if name.startswith("llm/layers/attn/"):
                return True
            if name.startswith("llm/"):
                return False
            if name.startswith("img/"):
                return False
            raise ValueError(f"Unexpected param name {name}")

        self.trainable_mask = big_vision.utils.tree_map_with_names(
            is_trainable_param, self.params)

        # One mesh axis called "data" spanning all devices.
        self.mesh = jax.sharding.Mesh(jax.devices(), ("data",))

        self.data_sharding = jax.sharding.NamedSharding(
            self.mesh, jax.sharding.PartitionSpec("data"))

        self.params_sharding = big_vision.sharding.infer_sharding(
            self.params, strategy=[('.*', 'fsdp(axis="data")')], mesh=self.mesh)

    def preprocess_image(self, image, size=224):
        """Converts an image to a `size` x `size` 3-channel float array.

        Args:
            image: anything `np.asarray` accepts; 2-D input is stacked into
                three identical channels, a fourth channel is dropped.
            size: edge length of the square output.

        Returns:
            The resized image divided by 127.5 minus 1.0 (i.e. [0, 255] input
            values end up in [-1, 1]).

        Raises:
            AssertionError: if the image does not end up with three channels.
        """
        image = np.asarray(image)
        if image.ndim == 2:  # Replicate a single-channel image three times.
            image = np.stack((image,) * 3, axis=-1)

        image = image[..., :3]  # Remove alpha layer.
        assert image.shape[-1] == 3

        image = tf.constant(image)
        image = tf.image.resize(
            image, (size, size), method='bilinear', antialias=True)
        return image.numpy() / 127.5 - 1.0

    def preprocess_tokens(self, prefix, suffix=None, seqlen=None):
        """Tokenizes prefix (and optional suffix) and builds the masks.

        Args:
            prefix: prompt text; encoded with BOS and followed by "\n".
            suffix: optional target text; encoded with EOS.
            seqlen: if truthy, every sequence is cut or zero-padded to it.

        Returns:
            `(tokens, mask_ar, mask_loss, mask_input)` as produced by
            `jax.tree.map(np.array, ...)`. Python lists are pytrees, so each
            entry is a list of 0-d NumPy values; wrap it in `np.asarray` to get
            a 1-D array.
        """
        separator = "\n"
        tokens = (self.tokenizer.encode(prefix, add_bos=True)
                  + self.tokenizer.encode(separator))
        mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
        mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

        if suffix:
            suffix = self.tokenizer.encode(suffix, add_eos=True)
            tokens += suffix
            mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
            mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

        mask_input = [1] * len(tokens)  # 1 if it is a token, 0 if padding.
        if seqlen:
            padding = [0] * max(0, seqlen - len(tokens))
            tokens = tokens[:seqlen] + padding
            mask_ar = mask_ar[:seqlen] + padding
            mask_loss = mask_loss[:seqlen] + padding
            mask_input = mask_input[:seqlen] + padding

        return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

    def postprocess_tokens(self, tokens):
        """Decodes a token array to text, dropping EOS and everything after."""
        tokens = tokens.tolist()  # np.array to list[int]
        try:  # Remove tokens at and after EOS if any.
            eos_pos = tokens.index(self.tokenizer.eos_id())
            tokens = tokens[:eos_pos]
        except ValueError:
            pass
        return self.tokenizer.decode(tokens)

    @staticmethod
    def split_and_keep_second_part(s):
        """Returns the text after the first newline, or `s` if there is none.

        Declared static because it takes no `self`; without the decorator the
        helper could not be called on an instance.
        """
        parts = s.split('\n', 1)
        if len(parts) > 1:
            return parts[1]
        return s

    def data_iterator(self, image_bytes, caption):
        """Yields one model-ready example for an encoded image and a caption."""
        image = Image.open(io.BytesIO(image_bytes))
        image = self.preprocess_image(image)
        tokens, mask_ar, _, mask_input = self.preprocess_tokens(
            caption, seqlen=SEQLEN)

        yield {
            "image": np.asarray(image),
            "text": np.asarray(tokens),
            "mask_ar": np.asarray(mask_ar),
            "mask_input": np.asarray(mask_input),
        }

    def make_predictions(self, data_iterator, *, num_examples=None,
                         batch_size=4, seqlen=SEQLEN, sampler="greedy"):
        """Decodes every example of `data_iterator` in fixed-size batches.

        Args:
            data_iterator: iterator of example dicts (see `data_iterator`).
            num_examples: if truthy, stop once this many outputs exist.
            batch_size: number of examples per decode call; a short final
                batch is filled up with masked copies of its last example.
            seqlen: passed to the decoder as `max_decode_len`.
            sampler: passed to the decoder unchanged.

        Returns:
            A list of `(image, response_text)` tuples.
        """
        outputs = []
        while True:
            # Collect up to one batch of real examples.
            examples = []
            try:
                for _ in range(batch_size):
                    examples.append(next(data_iterator))
                    examples[-1]["_mask"] = np.array(True)  # Indicates true example.
            except StopIteration:
                if not examples:
                    return outputs

            # Fill an incomplete batch with masked-out padding examples.
            while len(examples) % batch_size:
                examples.append(dict(examples[-1]))
                examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

            batch = jax.tree.map(lambda *x: np.stack(x), *examples)
            batch = big_vision.utils.reshard(batch, self.data_sharding)
            tokens = self.decode({"params": self.params}, batch=batch,
                                 max_decode_len=seqlen, sampler=sampler)

            # Fetch model predictions to the host and detokenize.
            tokens, mask = jax.device_get((tokens, batch["_mask"]))
            tokens = tokens[mask]  # Remove padding examples.
            responses = [self.postprocess_tokens(t) for t in tokens]

            # `zip` stops at the shorter `responses`, which skips the padding.
            for example, response in zip(examples, responses):
                outputs.append((example["image"], response))
                if num_examples and len(outputs) >= num_examples:
                    return outputs

    def process_result_to_bbox(self, image, caption, classes, w, h):
        """Parses the first detected box out of a model response.

        Args:
            image: unused; kept for interface compatibility.
            caption: model response text handed to `sv.Detections.from_lmm`.
            classes: unused; kept for interface compatibility.
            w: width used for `resolution_wh`.
            h: height used for `resolution_wh`.

        Returns:
            `[x1, y1, width, height]` of the first detection, or `[0, 0, 0, 0]`
            when parsing fails or nothing was detected.
        """
        try:
            # NOTE(review): the response text itself is passed as `classes`
            # and the `classes` argument is ignored - confirm this is intended.
            detections = sv.Detections.from_lmm(
                lmm='paligemma',
                result=caption,
                resolution_wh=(w, h),
                classes=caption)

            x1, y1, x2, y2 = list(detections.xyxy[0])[:4]
            output = [x1, y1, x2 - x1, y2 - y1]
        except Exception as e:  # Best effort: any failure yields an empty box.
            print('Error detection')
            print(e)
            output = [0, 0, 0, 0]

        return output

    def predict(self, image: bytes, caption: str) -> Tuple[List[float], str]:
        """Runs detection for one encoded image and a prompt.

        Args:
            image: encoded image bytes readable by PIL.
            caption: prompt; "detect " is prepended unless the prompt already
                contains "detect".

        Returns:
            `(bbox, response)` where `bbox` is `[x, y, width, height]` scaled to
            the original image size and `response` is the decoded model text.
        """
        image_original = Image.open(io.BytesIO(image))
        original_width, original_height = image_original.size
        if "detect" not in caption:
            caption = f"detect {caption}"

        # Defaults keep the return value defined even without any prediction.
        output, response = [0, 0, 0, 0], ""
        predictions = self.make_predictions(
            self.data_iterator(image, caption), num_examples=1)
        for model_image, response in predictions:
            classes = response.replace("detect ", "")
            output = self.process_result_to_bbox(
                model_image, response, classes, original_width, original_height)

        return (output, response)
# Locations of the tokenizer and of the model checkpoint on the local machine.
TOKENIZER_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_tokenizer.model'
MODEL_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_segmentation.npz'


model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)
paligemma_manager = PaliGemmaManager(model, params, tokenizer)

with open(INFERENCE_IMAGE, 'rb') as f:
    image_bytes = f.read()

output, response = paligemma_manager.predict(image_bytes,
                                             INFERENCE_PROMPT)

# Only the size is needed, so close the image file right away.
with Image.open(INFERENCE_IMAGE) as image:
    image_size = image.size

detections = sv.Detections.from_lmm(
    lmm='paligemma',
    result=response,
    resolution_wh=image_size,
    classes=response)

# Indexing an empty result would raise IndexError, so report that case instead.
if len(detections.xyxy) == 0:
    print('No detection found in response:', response)
else:
    coordinates = detections.xyxy[0]  # assuming we want the first detection
    x1, y1, x2, y2 = coordinates

    print('x1,y1,x2,y2:',coordinates)
sha256:8986bb4f423f07f8c7f70d0dbe3526fb2316056c17bae71b1ea975e77a168fc6 +size 4264023 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..25683d98daa71b8ac8c624aa890d851a17d91ae4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,261 @@ +absl-py==2.1.0 +anyio==4.4.0 +aqtp==0.7.5 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +array_record==0.5.1 +arrow==1.3.0 +asttokens==2.4.1 +astunparse==1.6.3 +async-lru==2.0.4 +attrs==23.2.0 +Automat==20.2.0 +Babel==2.15.0 +bcrypt==3.2.0 +beautifulsoup4==4.12.3 +bleach==6.1.0 +blinker==1.4 +certifi==2020.6.20 +cffi==1.16.0 +chardet==4.0.0 +charset-normalizer==3.3.2 +chex==0.1.86 +click==8.0.3 +cloud-init==23.3.3 +cloudpickle==3.0.0 +clu @ git+https://github.com/google/CommonLoopUtils@b64aa2985d47a9d362131f60b34600c204e51fbd +colorama==0.4.4 +comm==0.2.2 +command-not-found==0.3 +configobj==5.0.6 +constantly==15.1.0 +contextlib2==21.6.0 +contourpy==1.2.1 +cryptography==3.4.8 +cycler==0.12.1 +dbus-python==1.2.18 +debugpy==1.8.2 +decorator==5.1.1 +defusedxml==0.7.1 +distrax==0.1.5 +distro==1.7.0 +distro-info==1.1+ubuntu0.2 +dm-tree==0.1.8 +docstring_parser==0.16 +editdistance==0.8.1 +einops==0.8.0 +et-xmlfile==1.1.0 +etils==1.7.0 +exceptiongroup==1.2.2 +executing==2.0.1 +fastjsonschema==2.20.0 +filelock==3.15.4 +filetype==1.2.0 +flatbuffers==24.3.25 +flax==0.8.5 +flaxformer @ git+https://github.com/google/flaxformer@399ea3a85e9807ada653fd0de1a9de627eb0acde +fonttools==4.53.1 +fqdn==1.5.1 +fsspec==2024.6.1 +gast==0.6.0 +google-pasta==0.2.0 +grpcio==1.65.1 +h11==0.14.0 +h5py==3.11.0 +httpcore==1.0.5 +httplib2==0.20.2 +httpx==0.27.0 +huggingface-hub==0.23.5 +hyperlink==21.0.0 +idna==3.7 +immutabledict==4.2.0 +importlib-metadata==4.6.4 +importlib_resources==6.4.0 +incremental==21.3.0 +ipykernel==6.29.5 +ipython==8.26.0 +ipython_genutils==0.2.0 +isoduration==20.11.0 +jax==0.4.30 +jax-cuda12-pjrt==0.4.30 +jax-cuda12-plugin==0.4.30 +jaxlib==0.4.30 +jedi==0.19.1 
+jeepney==0.7.1 +Jinja2==3.0.3 +json5==0.9.25 +jsonpatch==1.32 +jsonpointer==2.0 +jsonschema==4.23.0 +jsonschema-specifications==2023.12.1 +jupyter-events==0.10.0 +jupyter-lsp==2.2.5 +jupyter_client==8.6.2 +jupyter_core==5.7.2 +jupyter_server==2.14.2 +jupyter_server_terminals==0.5.3 +jupyterlab==4.2.4 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.3 +keras==3.4.1 +keyring==23.5.0 +kiwisolver==1.4.5 +launchpadlib==1.10.16 +lazr.restfulclient==0.14.4 +lazr.uri==1.0.6 +libclang==18.1.1 +Markdown==3.6 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.9.1 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +mistune==3.0.2 +ml-collections==0.1.1 +ml-dtypes==0.4.0 +more-itertools==8.10.0 +mpmath==1.3.0 +msgpack==1.0.8 +namex==0.0.8 +nbclient==0.10.0 +nbconvert==7.16.4 +nbformat==5.10.4 +nest-asyncio==1.6.0 +netifaces==0.11.0 +networkx==3.3 +notebook==7.2.1 +notebook_shim==0.2.4 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvcc-cu12==12.5.82 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.2.1.18 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.5.82 +nvidia-nvtx-cu12==12.1.105 +oauthlib==3.2.0 +opencv-python-headless==4.10.0.84 +openpyxl==3.1.5 +opt-einsum==3.3.0 +optax==0.2.3 +optree==0.12.1 +orbax-checkpoint==0.5.22 +overrides==7.7.0 +packaging==24.1 +pandas==2.2.2 +pandocfilters==1.5.1 +panopticapi @ git+https://github.com/akolesnikoff/panopticapi.git@a698a12deb21e4cf0f99ef0581b2c30c466bf355 +parso==0.8.4 +pexpect==4.8.0 +pillow==10.4.0 +platformdirs==4.2.2 +prometheus_client==0.20.0 +promise==2.3 +prompt_toolkit==3.0.47 +protobuf==3.20.3 +psutil==6.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pyarrow==17.0.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.1 +pycocoevalcap==1.2 +pycocotools==2.0.8 +pycparser==2.22 +Pygments==2.18.0 +PyGObject==3.42.1 
+PyHamcrest==2.0.2 +PyJWT==2.3.0 +pyOpenSSL==21.0.0 +pyparsing==2.4.7 +pyrsistent==0.18.1 +pyserial==3.5 +python-apt==2.4.0+ubuntu3 +python-dateutil==2.9.0.post0 +python-debian==0.1.43+ubuntu1.1 +python-dotenv==1.0.1 +python-json-logger==2.0.7 +python-magic==0.4.24 +pytz==2022.1 +PyYAML==5.4.1 +pyzmq==26.0.3 +referencing==0.35.1 +regex==2024.5.15 +requests==2.32.3 +requests-toolbelt==1.0.0 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rich==13.7.1 +roboflow==1.1.36 +rpds-py==0.19.1 +safetensors==0.4.3 +scipy==1.14.0 +screen-resolution-extra==0.0.0 +SecretStorage==3.3.1 +Send2Trash==1.8.3 +sentencepiece==0.2.0 +service-identity==18.1.0 +simple_parsing==0.1.5 +six==1.16.0 +sniffio==1.3.1 +sos==4.5.6 +soupsieve==2.5 +ssh-import-id==5.11 +stack-data==0.6.3 +supervision @ git+https://github.com/roboflow/supervision.git@ad2220bc1da2e018d1ce08685359eb02ab3c5bd4 +sympy==1.13.0 +systemd-python==234 +tensorboard==2.17.0 +tensorboard-data-server==0.7.2 +tensorflow==2.17.0 +tensorflow-cpu==2.17.0 +tensorflow-gan==2.1.0 +tensorflow-hub==0.16.1 +tensorflow-io-gcs-filesystem==0.37.1 +tensorflow-metadata==1.15.0 +tensorflow-probability==0.24.0 +tensorflow-text==2.17.0 +tensorstore==0.1.63 +termcolor==2.4.0 +terminado==0.18.1 +tf_keras==2.17.0 +tfds-nightly==4.9.6.dev202407220044 +tinycss2==1.3.0 +tokenizers==0.19.1 +toml==0.10.2 +tomli==2.0.1 +toolz==0.12.1 +torch==2.3.1 +torchaudio==2.3.1 +torchvision==0.18.1 +tornado==6.4.1 +tqdm==4.66.4 +traitlets==5.14.3 +transformers==4.42.4 +triton==2.3.1 +Twisted==22.1.0 +types-python-dateutil==2.9.0.20240316 +typing_extensions==4.12.2 +tzdata==2024.1 +ubuntu-advantage-tools==8001 +ubuntu-drivers-common==0.0.0 +ufw==0.36.1 +unattended-upgrades==0.1 +uri-template==1.3.0 +urllib3==2.2.2 +wadllib==1.3.6 +wcwidth==0.2.13 +webcolors==24.6.0 +webencodings==0.5.1 +websocket-client==1.8.0 +Werkzeug==3.0.3 +wrapt==1.16.0 +xkit==0.0.0 +zipp==1.0.0 +zope.interface==5.4.0